set up multiple generations
This commit is contained in:
parent
d7ab620e0b
commit
aff9352dc0
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import asyncio
|
||||
import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
@ -14,17 +15,16 @@ from llm import (
|
||||
)
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from mock_llm import mock_completion
|
||||
from typing import Dict, List, Union, cast, get_args
|
||||
from typing import Any, Coroutine, Dict, List, Union, cast, get_args
|
||||
from image_generation import create_alt_url_mapping, generate_images
|
||||
from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||
from datetime import datetime
|
||||
import json
|
||||
from prompts.claude_prompts import VIDEO_PROMPT
|
||||
from prompts.types import Stack
|
||||
from utils import pprint_prompt
|
||||
|
||||
# from utils import pprint_prompt
|
||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
||||
from video.utils import assemble_claude_prompt_video
|
||||
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
||||
|
||||
|
||||
@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
|
||||
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
||||
|
||||
|
||||
# Generate images and return updated completions
|
||||
async def process_completion(
|
||||
websocket: WebSocket,
|
||||
completion: str,
|
||||
index: int,
|
||||
should_generate_images: bool,
|
||||
openai_api_key: str | None,
|
||||
openai_base_url: str | None,
|
||||
image_cache: dict[str, str],
|
||||
):
|
||||
try:
|
||||
if should_generate_images and openai_api_key:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": f"Generating images...",
|
||||
"variantIndex": index,
|
||||
}
|
||||
)
|
||||
updated_html = await generate_images(
|
||||
completion,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
)
|
||||
else:
|
||||
updated_html = completion
|
||||
|
||||
await websocket.send_json(
|
||||
{"type": "setCode", "value": updated_html, "variantIndex": index}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": f"Code generation complete.",
|
||||
"variantIndex": index,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Image generation failed for variant {index}", e)
|
||||
await websocket.send_json(
|
||||
{"type": "setCode", "value": completion, "variantIndex": index}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": f"Image generation failed but code is complete.",
|
||||
"variantIndex": index,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/generate-code")
|
||||
async def stream_code(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket):
|
||||
|
||||
print("Received params")
|
||||
|
||||
# Read the code config settings from the request. Fall back to default if not provided.
|
||||
# Read the code config settings (stack) from the request. Fall back to default if not provided.
|
||||
generated_code_config = ""
|
||||
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
|
||||
generated_code_config = params["generatedCodeConfig"]
|
||||
@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
|
||||
exact_llm_version = None
|
||||
|
||||
print(
|
||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
||||
)
|
||||
@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket):
|
||||
print("Using official OpenAI URL")
|
||||
|
||||
# Get the image generation flag from the request. Fall back to True if not provided.
|
||||
should_generate_images = (
|
||||
params["isImageGenerationEnabled"]
|
||||
if "isImageGenerationEnabled" in params
|
||||
else True
|
||||
)
|
||||
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||
|
||||
print("generating code...")
|
||||
await websocket.send_json({"type": "status", "value": "Generating code..."})
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Generating code...", "variantIndex": 0}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Generating code...", "variantIndex": 1}
|
||||
)
|
||||
|
||||
async def process_chunk(content: str):
|
||||
await websocket.send_json({"type": "chunk", "value": content})
|
||||
async def process_chunk(content: str, variantIndex: int = 0):
|
||||
await websocket.send_json(
|
||||
{"type": "chunk", "value": content, "variantIndex": variantIndex}
|
||||
)
|
||||
|
||||
# Image cache for updates so that we don't have to regenerate images
|
||||
image_cache: Dict[str, str] = {}
|
||||
@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket):
|
||||
# pprint_prompt(prompt_messages) # type: ignore
|
||||
|
||||
if SHOULD_MOCK_AI_RESPONSE:
|
||||
completion = await mock_completion(
|
||||
process_chunk, input_mode=validated_input_mode
|
||||
)
|
||||
completions = [
|
||||
await mock_completion(process_chunk, input_mode=validated_input_mode)
|
||||
]
|
||||
else:
|
||||
try:
|
||||
if validated_input_mode == "video":
|
||||
@ -251,41 +305,66 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
raise Exception("No Anthropic key")
|
||||
|
||||
completion = await stream_claude_response_native(
|
||||
system_prompt=VIDEO_PROMPT,
|
||||
messages=prompt_messages, # type: ignore
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=Llm.CLAUDE_3_OPUS,
|
||||
include_thinking=True,
|
||||
)
|
||||
exact_llm_version = Llm.CLAUDE_3_OPUS
|
||||
elif (
|
||||
code_generation_model == Llm.CLAUDE_3_SONNET
|
||||
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
):
|
||||
if not anthropic_api_key:
|
||||
await throw_error(
|
||||
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
|
||||
completions = [
|
||||
await stream_claude_response_native(
|
||||
system_prompt=VIDEO_PROMPT,
|
||||
messages=prompt_messages, # type: ignore
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=Llm.CLAUDE_3_OPUS,
|
||||
include_thinking=True,
|
||||
)
|
||||
raise Exception("No Anthropic key")
|
||||
|
||||
completion = await stream_claude_response(
|
||||
prompt_messages, # type: ignore
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=code_generation_model,
|
||||
)
|
||||
exact_llm_version = code_generation_model
|
||||
]
|
||||
else:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages, # type: ignore
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=code_generation_model,
|
||||
)
|
||||
exact_llm_version = code_generation_model
|
||||
|
||||
# Depending on the presence and absence of various keys,
|
||||
# we decide which models to run
|
||||
variant_models = []
|
||||
if openai_api_key and anthropic_api_key:
|
||||
variant_models = ["openai", "anthropic"]
|
||||
elif openai_api_key:
|
||||
variant_models = ["openai", "openai"]
|
||||
elif anthropic_api_key:
|
||||
variant_models = ["anthropic", "anthropic"]
|
||||
else:
|
||||
await throw_error(
|
||||
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
|
||||
)
|
||||
raise Exception("No OpenAI or Anthropic key")
|
||||
|
||||
tasks: List[Coroutine[Any, Any, str]] = []
|
||||
for index, model in enumerate(variant_models):
|
||||
if model == "openai":
|
||||
if openai_api_key is None:
|
||||
await throw_error("OpenAI API key is missing.")
|
||||
raise Exception("OpenAI API key is missing.")
|
||||
|
||||
tasks.append(
|
||||
stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x, i=index: process_chunk(x, i),
|
||||
model=Llm.GPT_4O_2024_05_13,
|
||||
)
|
||||
)
|
||||
elif model == "anthropic":
|
||||
if anthropic_api_key is None:
|
||||
await throw_error("Anthropic API key is missing.")
|
||||
raise Exception("Anthropic API key is missing.")
|
||||
|
||||
tasks.append(
|
||||
stream_claude_response(
|
||||
prompt_messages,
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x, i=index: process_chunk(x, i),
|
||||
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||
)
|
||||
)
|
||||
|
||||
completions = await asyncio.gather(*tasks)
|
||||
print("Models used for generation: ", variant_models)
|
||||
|
||||
except openai.AuthenticationError as e:
|
||||
print("[GENERATE_CODE] Authentication failed", e)
|
||||
error_message = (
|
||||
@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
return await throw_error(error_message)
|
||||
|
||||
if validated_input_mode == "video":
|
||||
completion = extract_tag_content("html", completion)
|
||||
|
||||
print("Exact used model for generation: ", exact_llm_version)
|
||||
# if validated_input_mode == "video":
|
||||
# completion = extract_tag_content("html", completions[0])
|
||||
|
||||
# Strip the completion of everything except the HTML content
|
||||
completion = extract_html_content(completion)
|
||||
completions = [extract_html_content(completion) for completion in completions]
|
||||
|
||||
# Write the messages dict into a log so that we can debug later
|
||||
write_logs(prompt_messages, completion) # type: ignore
|
||||
write_logs(prompt_messages, completions[0]) # type: ignore
|
||||
|
||||
try:
|
||||
if should_generate_images:
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Generating images..."}
|
||||
)
|
||||
updated_html = await generate_images(
|
||||
image_generation_tasks = [
|
||||
process_completion(
|
||||
websocket,
|
||||
completion,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
index,
|
||||
should_generate_images,
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
image_cache,
|
||||
)
|
||||
else:
|
||||
updated_html = completion
|
||||
await websocket.send_json({"type": "setCode", "value": updated_html})
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Code generation complete."}
|
||||
)
|
||||
for index, completion in enumerate(completions)
|
||||
]
|
||||
await asyncio.gather(*image_generation_tasks)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print("Image generation failed", e)
|
||||
# Send set code even if image generation fails since that triggers
|
||||
# the frontend to update history
|
||||
await websocket.send_json({"type": "setCode", "value": completion})
|
||||
print("An error occurred during image generation and processing", e)
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Image generation failed but code is complete."}
|
||||
{
|
||||
"type": "status",
|
||||
"value": "An error occurred during image generation and processing.",
|
||||
"variantIndex": 0,
|
||||
}
|
||||
)
|
||||
|
||||
await websocket.close()
|
||||
|
||||
@ -36,7 +36,12 @@ function App() {
|
||||
|
||||
// Outputs
|
||||
setGeneratedCode,
|
||||
setExecutionConsole,
|
||||
currentVariantIndex,
|
||||
setVariant,
|
||||
appendToVariant,
|
||||
resetVariants,
|
||||
appendExecutionConsole,
|
||||
resetExecutionConsoles,
|
||||
currentVersion,
|
||||
setCurrentVersion,
|
||||
appHistory,
|
||||
@ -106,10 +111,14 @@ function App() {
|
||||
const reset = () => {
|
||||
setAppState(AppState.INITIAL);
|
||||
setGeneratedCode("");
|
||||
resetVariants();
|
||||
resetExecutionConsoles();
|
||||
|
||||
// Inputs
|
||||
setReferenceImages([]);
|
||||
setExecutionConsole([]);
|
||||
setUpdateInstruction("");
|
||||
setIsImportedFromCode(false);
|
||||
|
||||
setAppHistory([]);
|
||||
setCurrentVersion(null);
|
||||
setShouldIncludeResultImage(false);
|
||||
@ -159,7 +168,7 @@ function App() {
|
||||
parentVersion: number | null
|
||||
) {
|
||||
// Reset the execution console
|
||||
setExecutionConsole([]);
|
||||
resetExecutionConsoles();
|
||||
|
||||
// Set the app state
|
||||
setAppState(AppState.CODING);
|
||||
@ -171,10 +180,19 @@ function App() {
|
||||
wsRef,
|
||||
updatedParams,
|
||||
// On change
|
||||
(token) => setGeneratedCode((prev) => prev + token),
|
||||
(token, variant) => {
|
||||
if (variant === currentVariantIndex) {
|
||||
setGeneratedCode((prev) => prev + token);
|
||||
}
|
||||
|
||||
appendToVariant(token, variant);
|
||||
},
|
||||
// On set code
|
||||
(code) => {
|
||||
(code, variant) => {
|
||||
setVariant(code, variant);
|
||||
setGeneratedCode(code);
|
||||
|
||||
// TODO: How to deal with variants?
|
||||
if (params.generationType === "create") {
|
||||
setAppHistory([
|
||||
{
|
||||
@ -214,7 +232,7 @@ function App() {
|
||||
}
|
||||
},
|
||||
// On status update
|
||||
(line) => setExecutionConsole((prev) => [...prev, line]),
|
||||
(line, variant) => appendExecutionConsole(variant, line),
|
||||
// On cancel
|
||||
() => {
|
||||
cancelCodeGenerationAndReset();
|
||||
@ -314,6 +332,7 @@ function App() {
|
||||
}
|
||||
|
||||
setGeneratedCode("");
|
||||
resetVariants();
|
||||
setUpdateInstruction("");
|
||||
}
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ import { Button } from "../ui/button";
|
||||
import { Textarea } from "../ui/textarea";
|
||||
import { useEffect, useRef } from "react";
|
||||
import HistoryDisplay from "../history/HistoryDisplay";
|
||||
import Variants from "../variants/Variants";
|
||||
|
||||
interface SidebarProps {
|
||||
showSelectAndEditFeature: boolean;
|
||||
@ -35,8 +36,16 @@ function Sidebar({
|
||||
shouldIncludeResultImage,
|
||||
setShouldIncludeResultImage,
|
||||
} = useAppStore();
|
||||
const { inputMode, generatedCode, referenceImages, executionConsole } =
|
||||
useProjectStore();
|
||||
|
||||
const {
|
||||
inputMode,
|
||||
generatedCode,
|
||||
referenceImages,
|
||||
executionConsoles,
|
||||
currentVariantIndex,
|
||||
} = useProjectStore();
|
||||
|
||||
const executionConsole = executionConsoles[currentVariantIndex] || [];
|
||||
|
||||
// When coding is complete, focus on the update instruction textarea
|
||||
useEffect(() => {
|
||||
@ -47,6 +56,8 @@ function Sidebar({
|
||||
|
||||
return (
|
||||
<>
|
||||
<Variants />
|
||||
|
||||
{/* Show code preview only when coding */}
|
||||
{appState === AppState.CODING && (
|
||||
<div className="flex flex-col">
|
||||
|
||||
64
frontend/src/components/variants/Variants.tsx
Normal file
64
frontend/src/components/variants/Variants.tsx
Normal file
@ -0,0 +1,64 @@
|
||||
import { useProjectStore } from "../../store/project-store";
|
||||
|
||||
function Variants() {
|
||||
const {
|
||||
// Inputs
|
||||
referenceImages,
|
||||
|
||||
// Outputs
|
||||
variants,
|
||||
currentVariantIndex,
|
||||
setCurrentVariantIndex,
|
||||
setGeneratedCode,
|
||||
appHistory,
|
||||
setAppHistory,
|
||||
} = useProjectStore();
|
||||
|
||||
function switchVariant(index: number) {
|
||||
const variant = variants[index];
|
||||
setCurrentVariantIndex(index);
|
||||
setGeneratedCode(variant);
|
||||
if (appHistory.length === 1) {
|
||||
setAppHistory([
|
||||
{
|
||||
type: "ai_create",
|
||||
parentIndex: null,
|
||||
code: variant,
|
||||
inputs: { image_url: referenceImages[0] },
|
||||
},
|
||||
]);
|
||||
} else {
|
||||
setAppHistory((prev) => {
|
||||
const newHistory = [...prev];
|
||||
newHistory[newHistory.length - 1].code = variant;
|
||||
return newHistory;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (variants.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-4 mb-4">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{variants.map((_, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`p-2 border rounded-md cursor-pointer ${
|
||||
index === currentVariantIndex
|
||||
? "bg-blue-100 dark:bg-blue-900"
|
||||
: "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
}`}
|
||||
onClick={() => switchVariant(index)}
|
||||
>
|
||||
<h3 className="font-medium mb-1">Option {index + 1}</h3>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default Variants;
|
||||
@ -11,12 +11,18 @@ const ERROR_MESSAGE =
|
||||
|
||||
const CANCEL_MESSAGE = "Code generation cancelled";
|
||||
|
||||
type WebSocketResponse = {
|
||||
type: "chunk" | "status" | "setCode" | "error";
|
||||
value: string;
|
||||
variantIndex: number;
|
||||
};
|
||||
|
||||
export function generateCode(
|
||||
wsRef: React.MutableRefObject<WebSocket | null>,
|
||||
params: FullGenerationSettings,
|
||||
onChange: (chunk: string) => void,
|
||||
onSetCode: (code: string) => void,
|
||||
onStatusUpdate: (status: string) => void,
|
||||
onChange: (chunk: string, variantIndex: number) => void,
|
||||
onSetCode: (code: string, variantIndex: number) => void,
|
||||
onStatusUpdate: (status: string, variantIndex: number) => void,
|
||||
onCancel: () => void,
|
||||
onComplete: () => void
|
||||
) {
|
||||
@ -31,13 +37,13 @@ export function generateCode(
|
||||
});
|
||||
|
||||
ws.addEventListener("message", async (event: MessageEvent) => {
|
||||
const response = JSON.parse(event.data);
|
||||
const response = JSON.parse(event.data) as WebSocketResponse;
|
||||
if (response.type === "chunk") {
|
||||
onChange(response.value);
|
||||
onChange(response.value, response.variantIndex);
|
||||
} else if (response.type === "status") {
|
||||
onStatusUpdate(response.value);
|
||||
onStatusUpdate(response.value, response.variantIndex);
|
||||
} else if (response.type === "setCode") {
|
||||
onSetCode(response.value);
|
||||
onSetCode(response.value, response.variantIndex);
|
||||
} else if (response.type === "error") {
|
||||
console.error("Error generating code", response.value);
|
||||
toast.error(response.value);
|
||||
|
||||
@ -16,10 +16,17 @@ interface ProjectStore {
|
||||
setGeneratedCode: (
|
||||
updater: string | ((currentCode: string) => string)
|
||||
) => void;
|
||||
executionConsole: string[];
|
||||
setExecutionConsole: (
|
||||
updater: string[] | ((currentConsole: string[]) => string[])
|
||||
) => void;
|
||||
|
||||
variants: string[];
|
||||
currentVariantIndex: number;
|
||||
setCurrentVariantIndex: (index: number) => void;
|
||||
setVariant: (code: string, index: number) => void;
|
||||
appendToVariant: (newTokens: string, index: number) => void;
|
||||
resetVariants: () => void;
|
||||
|
||||
executionConsoles: { [key: number]: string[] };
|
||||
appendExecutionConsole: (variantIndex: number, line: string) => void;
|
||||
resetExecutionConsoles: () => void;
|
||||
|
||||
// Tracks the currently shown version from app history
|
||||
// TODO: might want to move to appStore
|
||||
@ -48,14 +55,41 @@ export const useProjectStore = create<ProjectStore>((set) => ({
|
||||
generatedCode:
|
||||
typeof updater === "function" ? updater(state.generatedCode) : updater,
|
||||
})),
|
||||
executionConsole: [],
|
||||
setExecutionConsole: (updater) =>
|
||||
|
||||
variants: [],
|
||||
currentVariantIndex: 0,
|
||||
|
||||
setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }),
|
||||
setVariant: (code: string, index: number) =>
|
||||
set((state) => {
|
||||
const newVariants = [...state.variants];
|
||||
while (newVariants.length <= index) {
|
||||
newVariants.push("");
|
||||
}
|
||||
newVariants[index] = code;
|
||||
return { variants: newVariants };
|
||||
}),
|
||||
appendToVariant: (newTokens: string, index: number) =>
|
||||
set((state) => {
|
||||
const newVariants = [...state.variants];
|
||||
newVariants[index] += newTokens;
|
||||
return { variants: newVariants };
|
||||
}),
|
||||
resetVariants: () => set({ variants: [], currentVariantIndex: 0 }),
|
||||
|
||||
executionConsoles: {},
|
||||
|
||||
appendExecutionConsole: (variantIndex: number, line: string) =>
|
||||
set((state) => ({
|
||||
executionConsole:
|
||||
typeof updater === "function"
|
||||
? updater(state.executionConsole)
|
||||
: updater,
|
||||
executionConsoles: {
|
||||
...state.executionConsoles,
|
||||
[variantIndex]: [
|
||||
...(state.executionConsoles[variantIndex] || []),
|
||||
line,
|
||||
],
|
||||
},
|
||||
})),
|
||||
resetExecutionConsoles: () => set({ executionConsoles: {} }),
|
||||
|
||||
currentVersion: null,
|
||||
setCurrentVersion: (version) => set({ currentVersion: version }),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user