diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index 19ef382..3d0c915 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -1,4 +1,5 @@
import os
+import asyncio
import traceback
from fastapi import APIRouter, WebSocket
import openai
@@ -14,17 +15,16 @@ from llm import (
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
-from typing import Dict, List, Union, cast, get_args
+from typing import Any, Coroutine, Dict, List, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime
import json
from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack
-from utils import pprint_prompt
# from utils import pprint_prompt
-from video.utils import extract_tag_content, assemble_claude_prompt_video
+from video.utils import assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
@@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
+# Generate images for a single variant and push the updated code and status to the client
+async def process_completion(
+ websocket: WebSocket,
+ completion: str,
+ index: int,
+ should_generate_images: bool,
+ openai_api_key: str | None,
+ openai_base_url: str | None,
+ image_cache: dict[str, str],
+):
+ try:
+ if should_generate_images and openai_api_key:
+ await websocket.send_json(
+ {
+ "type": "status",
+                    "value": "Generating images...",
+ "variantIndex": index,
+ }
+ )
+ updated_html = await generate_images(
+ completion,
+ api_key=openai_api_key,
+ base_url=openai_base_url,
+ image_cache=image_cache,
+ )
+ else:
+ updated_html = completion
+
+ await websocket.send_json(
+ {"type": "setCode", "value": updated_html, "variantIndex": index}
+ )
+ await websocket.send_json(
+ {
+ "type": "status",
+                "value": "Code generation complete.",
+ "variantIndex": index,
+ }
+ )
+ except Exception as e:
+ traceback.print_exc()
+ print(f"Image generation failed for variant {index}", e)
+ await websocket.send_json(
+ {"type": "setCode", "value": completion, "variantIndex": index}
+ )
+ await websocket.send_json(
+ {
+ "type": "status",
+                "value": "Image generation failed but code is complete.",
+ "variantIndex": index,
+ }
+ )
+
+
@router.websocket("/generate-code")
async def stream_code(websocket: WebSocket):
await websocket.accept()
@@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket):
print("Received params")
- # Read the code config settings from the request. Fall back to default if not provided.
+ # Read the code config settings (stack) from the request. Fall back to default if not provided.
generated_code_config = ""
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
generated_code_config = params["generatedCodeConfig"]
@@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket):
)
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
- exact_llm_version = None
-
print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
)
@@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket):
print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided.
- should_generate_images = (
- params["isImageGenerationEnabled"]
- if "isImageGenerationEnabled" in params
- else True
- )
+ should_generate_images = bool(params.get("isImageGenerationEnabled", True))
print("generating code...")
- await websocket.send_json({"type": "status", "value": "Generating code..."})
+ await websocket.send_json(
+ {"type": "status", "value": "Generating code...", "variantIndex": 0}
+ )
+ await websocket.send_json(
+ {"type": "status", "value": "Generating code...", "variantIndex": 1}
+ )
- async def process_chunk(content: str):
- await websocket.send_json({"type": "chunk", "value": content})
+    async def process_chunk(content: str, variant_index: int = 0):
+        await websocket.send_json(
+            {"type": "chunk", "value": content, "variantIndex": variant_index}
+        )
# Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {}
@@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket):
# pprint_prompt(prompt_messages) # type: ignore
if SHOULD_MOCK_AI_RESPONSE:
- completion = await mock_completion(
- process_chunk, input_mode=validated_input_mode
- )
+ completions = [
+ await mock_completion(process_chunk, input_mode=validated_input_mode)
+ ]
else:
try:
if validated_input_mode == "video":
@@ -251,41 +305,66 @@ async def stream_code(websocket: WebSocket):
)
raise Exception("No Anthropic key")
- completion = await stream_claude_response_native(
- system_prompt=VIDEO_PROMPT,
- messages=prompt_messages, # type: ignore
- api_key=anthropic_api_key,
- callback=lambda x: process_chunk(x),
- model=Llm.CLAUDE_3_OPUS,
- include_thinking=True,
- )
- exact_llm_version = Llm.CLAUDE_3_OPUS
- elif (
- code_generation_model == Llm.CLAUDE_3_SONNET
- or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
- ):
- if not anthropic_api_key:
- await throw_error(
- "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
+ completions = [
+ await stream_claude_response_native(
+ system_prompt=VIDEO_PROMPT,
+ messages=prompt_messages, # type: ignore
+ api_key=anthropic_api_key,
+ callback=lambda x: process_chunk(x),
+ model=Llm.CLAUDE_3_OPUS,
+ include_thinking=True,
)
- raise Exception("No Anthropic key")
-
- completion = await stream_claude_response(
- prompt_messages, # type: ignore
- api_key=anthropic_api_key,
- callback=lambda x: process_chunk(x),
- model=code_generation_model,
- )
- exact_llm_version = code_generation_model
+ ]
else:
- completion = await stream_openai_response(
- prompt_messages, # type: ignore
- api_key=openai_api_key,
- base_url=openai_base_url,
- callback=lambda x: process_chunk(x),
- model=code_generation_model,
- )
- exact_llm_version = code_generation_model
+
+                    # Depending on which API keys are available,
+                    # decide which models to run for the two variants
+ variant_models = []
+ if openai_api_key and anthropic_api_key:
+ variant_models = ["openai", "anthropic"]
+ elif openai_api_key:
+ variant_models = ["openai", "openai"]
+ elif anthropic_api_key:
+ variant_models = ["anthropic", "anthropic"]
+ else:
+ await throw_error(
+ "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
+ )
+ raise Exception("No OpenAI or Anthropic key")
+
+ tasks: List[Coroutine[Any, Any, str]] = []
+ for index, model in enumerate(variant_models):
+ if model == "openai":
+ if openai_api_key is None:
+ await throw_error("OpenAI API key is missing.")
+ raise Exception("OpenAI API key is missing.")
+
+ tasks.append(
+ stream_openai_response(
+ prompt_messages,
+ api_key=openai_api_key,
+ base_url=openai_base_url,
+ callback=lambda x, i=index: process_chunk(x, i),
+ model=Llm.GPT_4O_2024_05_13,
+ )
+ )
+ elif model == "anthropic":
+ if anthropic_api_key is None:
+ await throw_error("Anthropic API key is missing.")
+ raise Exception("Anthropic API key is missing.")
+
+ tasks.append(
+ stream_claude_response(
+ prompt_messages,
+ api_key=anthropic_api_key,
+ callback=lambda x, i=index: process_chunk(x, i),
+ model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
+ )
+ )
+
+ completions = await asyncio.gather(*tasks)
+ print("Models used for generation: ", variant_models)
+
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket):
)
return await throw_error(error_message)
- if validated_input_mode == "video":
- completion = extract_tag_content("html", completion)
-
- print("Exact used model for generation: ", exact_llm_version)
+ # if validated_input_mode == "video":
+ # completion = extract_tag_content("html", completions[0])
# Strip the completion of everything except the HTML content
- completion = extract_html_content(completion)
+ completions = [extract_html_content(completion) for completion in completions]
# Write the messages dict into a log so that we can debug later
- write_logs(prompt_messages, completion) # type: ignore
+ write_logs(prompt_messages, completions[0]) # type: ignore
try:
- if should_generate_images:
- await websocket.send_json(
- {"type": "status", "value": "Generating images..."}
- )
- updated_html = await generate_images(
+ image_generation_tasks = [
+ process_completion(
+ websocket,
completion,
- api_key=openai_api_key,
- base_url=openai_base_url,
- image_cache=image_cache,
+ index,
+ should_generate_images,
+ openai_api_key,
+ openai_base_url,
+ image_cache,
)
- else:
- updated_html = completion
- await websocket.send_json({"type": "setCode", "value": updated_html})
- await websocket.send_json(
- {"type": "status", "value": "Code generation complete."}
- )
+ for index, completion in enumerate(completions)
+ ]
+ await asyncio.gather(*image_generation_tasks)
except Exception as e:
traceback.print_exc()
- print("Image generation failed", e)
- # Send set code even if image generation fails since that triggers
- # the frontend to update history
- await websocket.send_json({"type": "setCode", "value": completion})
+ print("An error occurred during image generation and processing", e)
await websocket.send_json(
- {"type": "status", "value": "Image generation failed but code is complete."}
+ {
+ "type": "status",
+ "value": "An error occurred during image generation and processing.",
+ "variantIndex": 0,
+ }
)
await websocket.close()
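
For context, a minimal sketch of how the frontend's websocket dispatcher might route these new per-variant messages to the callbacks wired up in App.tsx below. The dispatcher itself is not part of this diff, so the function and parameter names (handleMessage, onChunk, onSetCode, onStatusUpdate) are illustrative assumptions; only the message shape ({ type, value, variantIndex }) comes from the backend code above.

// Hypothetical sketch, not part of this diff: route per-variant messages
// from the backend to the per-variant callbacks used in App.tsx.
type WsMessage = {
  type: "chunk" | "setCode" | "status" | "error";
  value: string;
  variantIndex?: number;
};

function handleMessage(
  raw: MessageEvent<string>,
  onChunk: (token: string, variant: number) => void,
  onSetCode: (code: string, variant: number) => void,
  onStatusUpdate: (line: string, variant: number) => void
) {
  const message: WsMessage = JSON.parse(raw.data);
  // Messages without a variantIndex (e.g. from older backends) default to variant 0.
  const variant = message.variantIndex ?? 0;

  switch (message.type) {
    case "chunk":
      onChunk(message.value, variant);
      break;
    case "setCode":
      onSetCode(message.value, variant);
      break;
    case "status":
      onStatusUpdate(message.value, variant);
      break;
  }
}
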
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 9c20df3..e88fa28 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -36,7 +36,12 @@ function App() {
// Outputs
setGeneratedCode,
- setExecutionConsole,
+ currentVariantIndex,
+ setVariant,
+ appendToVariant,
+ resetVariants,
+ appendExecutionConsole,
+ resetExecutionConsoles,
currentVersion,
setCurrentVersion,
appHistory,
@@ -106,10 +111,14 @@ function App() {
const reset = () => {
setAppState(AppState.INITIAL);
setGeneratedCode("");
+ resetVariants();
+ resetExecutionConsoles();
+
+ // Inputs
setReferenceImages([]);
- setExecutionConsole([]);
setUpdateInstruction("");
setIsImportedFromCode(false);
+
setAppHistory([]);
setCurrentVersion(null);
setShouldIncludeResultImage(false);
@@ -159,7 +168,7 @@ function App() {
parentVersion: number | null
) {
// Reset the execution console
- setExecutionConsole([]);
+ resetExecutionConsoles();
// Set the app state
setAppState(AppState.CODING);
@@ -171,10 +180,19 @@ function App() {
wsRef,
updatedParams,
// On change
- (token) => setGeneratedCode((prev) => prev + token),
+ (token, variant) => {
+ if (variant === currentVariantIndex) {
+ setGeneratedCode((prev) => prev + token);
+ }
+
+ appendToVariant(token, variant);
+ },
// On set code
- (code) => {
+ (code, variant) => {
+ setVariant(code, variant);
setGeneratedCode(code);
+
+ // TODO: How to deal with variants?
if (params.generationType === "create") {
setAppHistory([
{
@@ -214,7 +232,7 @@ function App() {
}
},
// On status update
- (line) => setExecutionConsole((prev) => [...prev, line]),
+ (line, variant) => appendExecutionConsole(variant, line),
// On cancel
() => {
cancelCodeGenerationAndReset();
@@ -314,6 +332,7 @@ function App() {
}
setGeneratedCode("");
+ resetVariants();
setUpdateInstruction("");
}
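
The store members referenced above (setVariant, appendToVariant, resetVariants, appendExecutionConsole, resetExecutionConsoles) and the executionConsoles / currentVariantIndex fields read in Sidebar.tsx below are not shown in this diff. Here is a minimal sketch of what that slice of useProjectStore might look like, assuming a zustand store and two variants; only the member names come from the call sites, and the real store also holds other state (generatedCode, referenceImages, inputMode, and so on).

// Hypothetical sketch of the variant-related slice of the project store.
// Shapes and defaults are assumptions; only the member names come from this diff.
import { create } from "zustand";

interface ProjectStore {
  variants: string[]; // generated code per variant
  currentVariantIndex: number;
  executionConsoles: { [variantIndex: number]: string[] };
  setVariant: (code: string, variantIndex: number) => void;
  appendToVariant: (token: string, variantIndex: number) => void;
  resetVariants: () => void;
  appendExecutionConsole: (variantIndex: number, line: string) => void;
  resetExecutionConsoles: () => void;
}

export const useProjectStore = create<ProjectStore>((set) => ({
  variants: ["", ""],
  currentVariantIndex: 0,
  executionConsoles: {},
  setVariant: (code, variantIndex) =>
    set((state) => {
      const variants = [...state.variants];
      variants[variantIndex] = code;
      return { variants };
    }),
  appendToVariant: (token, variantIndex) =>
    set((state) => {
      const variants = [...state.variants];
      variants[variantIndex] = (variants[variantIndex] || "") + token;
      return { variants };
    }),
  resetVariants: () => set({ variants: ["", ""] }),
  appendExecutionConsole: (variantIndex, line) =>
    set((state) => ({
      executionConsoles: {
        ...state.executionConsoles,
        [variantIndex]: [...(state.executionConsoles[variantIndex] || []), line],
      },
    })),
  resetExecutionConsoles: () => set({ executionConsoles: {} }),
}));
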
diff --git a/frontend/src/components/sidebar/Sidebar.tsx b/frontend/src/components/sidebar/Sidebar.tsx
index 5246637..b557456 100644
--- a/frontend/src/components/sidebar/Sidebar.tsx
+++ b/frontend/src/components/sidebar/Sidebar.tsx
@@ -12,6 +12,7 @@ import { Button } from "../ui/button";
import { Textarea } from "../ui/textarea";
import { useEffect, useRef } from "react";
import HistoryDisplay from "../history/HistoryDisplay";
+import Variants from "../variants/Variants";
interface SidebarProps {
showSelectAndEditFeature: boolean;
@@ -35,8 +36,16 @@ function Sidebar({
shouldIncludeResultImage,
setShouldIncludeResultImage,
} = useAppStore();
- const { inputMode, generatedCode, referenceImages, executionConsole } =
- useProjectStore();
+
+ const {
+ inputMode,
+ generatedCode,
+ referenceImages,
+ executionConsoles,
+ currentVariantIndex,
+ } = useProjectStore();
+
+ const executionConsole = executionConsoles[currentVariantIndex] || [];
// When coding is complete, focus on the update instruction textarea
useEffect(() => {
@@ -47,6 +56,8 @@ function Sidebar({
return (
<>
+