From aff9352dc04b9a4a77d1bb8f14eac81adc73dc12 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 15:44:48 -0400 Subject: [PATCH] set up multiple generations --- backend/routes/generate_code.py | 225 ++++++++++++------ frontend/src/App.tsx | 31 ++- frontend/src/components/sidebar/Sidebar.tsx | 15 +- frontend/src/components/variants/Variants.tsx | 64 +++++ frontend/src/generateCode.ts | 20 +- frontend/src/store/project-store.ts | 54 ++++- 6 files changed, 309 insertions(+), 100 deletions(-) create mode 100644 frontend/src/components/variants/Variants.tsx diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 19ef382..3d0c915 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -1,4 +1,5 @@ import os +import asyncio import traceback from fastapi import APIRouter, WebSocket import openai @@ -14,17 +15,16 @@ from llm import ( ) from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Dict, List, Union, cast, get_args +from typing import Any, Coroutine, Dict, List, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from datetime import datetime import json from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack -from utils import pprint_prompt # from utils import pprint_prompt -from video.utils import extract_tag_content, assemble_claude_prompt_video +from video.utils import assemble_claude_prompt_video from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) +# Generate images and return updated completions +async def process_completion( + websocket: WebSocket, + completion: str, + index: int, + should_generate_images: bool, + openai_api_key: str | None, + openai_base_url: str | None, + image_cache: dict[str, str], +): + try: + if should_generate_images and openai_api_key: + await websocket.send_json( + { + "type": "status", + "value": f"Generating images...", + "variantIndex": index, + } + ) + updated_html = await generate_images( + completion, + api_key=openai_api_key, + base_url=openai_base_url, + image_cache=image_cache, + ) + else: + updated_html = completion + + await websocket.send_json( + {"type": "setCode", "value": updated_html, "variantIndex": index} + ) + await websocket.send_json( + { + "type": "status", + "value": f"Code generation complete.", + "variantIndex": index, + } + ) + except Exception as e: + traceback.print_exc() + print(f"Image generation failed for variant {index}", e) + await websocket.send_json( + {"type": "setCode", "value": completion, "variantIndex": index} + ) + await websocket.send_json( + { + "type": "status", + "value": f"Image generation failed but code is complete.", + "variantIndex": index, + } + ) + + @router.websocket("/generate-code") async def stream_code(websocket: WebSocket): await websocket.accept() @@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket): print("Received params") - # Read the code config settings from the request. Fall back to default if not provided. + # Read the code config settings (stack) from the request. Fall back to default if not provided. generated_code_config = "" if "generatedCodeConfig" in params and params["generatedCodeConfig"]: generated_code_config = params["generatedCodeConfig"] @@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket): ) code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20 - exact_llm_version = None - print( f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." ) @@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket): print("Using official OpenAI URL") # Get the image generation flag from the request. Fall back to True if not provided. - should_generate_images = ( - params["isImageGenerationEnabled"] - if "isImageGenerationEnabled" in params - else True - ) + should_generate_images = bool(params.get("isImageGenerationEnabled", True)) print("generating code...") - await websocket.send_json({"type": "status", "value": "Generating code..."}) + await websocket.send_json( + {"type": "status", "value": "Generating code...", "variantIndex": 0} + ) + await websocket.send_json( + {"type": "status", "value": "Generating code...", "variantIndex": 1} + ) - async def process_chunk(content: str): - await websocket.send_json({"type": "chunk", "value": content}) + async def process_chunk(content: str, variantIndex: int = 0): + await websocket.send_json( + {"type": "chunk", "value": content, "variantIndex": variantIndex} + ) # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} @@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket): # pprint_prompt(prompt_messages) # type: ignore if SHOULD_MOCK_AI_RESPONSE: - completion = await mock_completion( - process_chunk, input_mode=validated_input_mode - ) + completions = [ + await mock_completion(process_chunk, input_mode=validated_input_mode) + ] else: try: if validated_input_mode == "video": @@ -251,41 +305,66 @@ async def stream_code(websocket: WebSocket): ) raise Exception("No Anthropic key") - completion = await stream_claude_response_native( - system_prompt=VIDEO_PROMPT, - messages=prompt_messages, # type: ignore - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - model=Llm.CLAUDE_3_OPUS, - include_thinking=True, - ) - exact_llm_version = Llm.CLAUDE_3_OPUS - elif ( - code_generation_model == Llm.CLAUDE_3_SONNET - or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20 - ): - if not anthropic_api_key: - await throw_error( - "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog" + completions = [ + await stream_claude_response_native( + system_prompt=VIDEO_PROMPT, + messages=prompt_messages, # type: ignore + api_key=anthropic_api_key, + callback=lambda x: process_chunk(x), + model=Llm.CLAUDE_3_OPUS, + include_thinking=True, ) - raise Exception("No Anthropic key") - - completion = await stream_claude_response( - prompt_messages, # type: ignore - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - model=code_generation_model, - ) - exact_llm_version = code_generation_model + ] else: - completion = await stream_openai_response( - prompt_messages, # type: ignore - api_key=openai_api_key, - base_url=openai_base_url, - callback=lambda x: process_chunk(x), - model=code_generation_model, - ) - exact_llm_version = code_generation_model + + # Depending on the presence and absence of various keys, + # we decide which models to run + variant_models = [] + if openai_api_key and anthropic_api_key: + variant_models = ["openai", "anthropic"] + elif openai_api_key: + variant_models = ["openai", "openai"] + elif anthropic_api_key: + variant_models = ["anthropic", "anthropic"] + else: + await throw_error( + "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog" + ) + raise Exception("No OpenAI or Anthropic key") + + tasks: List[Coroutine[Any, Any, str]] = [] + for index, model in enumerate(variant_models): + if model == "openai": + if openai_api_key is None: + await throw_error("OpenAI API key is missing.") + raise Exception("OpenAI API key is missing.") + + tasks.append( + stream_openai_response( + prompt_messages, + api_key=openai_api_key, + base_url=openai_base_url, + callback=lambda x, i=index: process_chunk(x, i), + model=Llm.GPT_4O_2024_05_13, + ) + ) + elif model == "anthropic": + if anthropic_api_key is None: + await throw_error("Anthropic API key is missing.") + raise Exception("Anthropic API key is missing.") + + tasks.append( + stream_claude_response( + prompt_messages, + api_key=anthropic_api_key, + callback=lambda x, i=index: process_chunk(x, i), + model=Llm.CLAUDE_3_5_SONNET_2024_06_20, + ) + ) + + completions = await asyncio.gather(*tasks) + print("Models used for generation: ", variant_models) + except openai.AuthenticationError as e: print("[GENERATE_CODE] Authentication failed", e) error_message = ( @@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket): ) return await throw_error(error_message) - if validated_input_mode == "video": - completion = extract_tag_content("html", completion) - - print("Exact used model for generation: ", exact_llm_version) + # if validated_input_mode == "video": + # completion = extract_tag_content("html", completions[0]) # Strip the completion of everything except the HTML content - completion = extract_html_content(completion) + completions = [extract_html_content(completion) for completion in completions] # Write the messages dict into a log so that we can debug later - write_logs(prompt_messages, completion) # type: ignore + write_logs(prompt_messages, completions[0]) # type: ignore try: - if should_generate_images: - await websocket.send_json( - {"type": "status", "value": "Generating images..."} - ) - updated_html = await generate_images( + image_generation_tasks = [ + process_completion( + websocket, completion, - api_key=openai_api_key, - base_url=openai_base_url, - image_cache=image_cache, + index, + should_generate_images, + openai_api_key, + openai_base_url, + image_cache, ) - else: - updated_html = completion - await websocket.send_json({"type": "setCode", "value": updated_html}) - await websocket.send_json( - {"type": "status", "value": "Code generation complete."} - ) + for index, completion in enumerate(completions) + ] + await asyncio.gather(*image_generation_tasks) except Exception as e: traceback.print_exc() - print("Image generation failed", e) - # Send set code even if image generation fails since that triggers - # the frontend to update history - await websocket.send_json({"type": "setCode", "value": completion}) + print("An error occurred during image generation and processing", e) await websocket.send_json( - {"type": "status", "value": "Image generation failed but code is complete."} + { + "type": "status", + "value": "An error occurred during image generation and processing.", + "variantIndex": 0, + } ) await websocket.close() diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 9c20df3..e88fa28 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -36,7 +36,12 @@ function App() { // Outputs setGeneratedCode, - setExecutionConsole, + currentVariantIndex, + setVariant, + appendToVariant, + resetVariants, + appendExecutionConsole, + resetExecutionConsoles, currentVersion, setCurrentVersion, appHistory, @@ -106,10 +111,14 @@ function App() { const reset = () => { setAppState(AppState.INITIAL); setGeneratedCode(""); + resetVariants(); + resetExecutionConsoles(); + + // Inputs setReferenceImages([]); - setExecutionConsole([]); setUpdateInstruction(""); setIsImportedFromCode(false); + setAppHistory([]); setCurrentVersion(null); setShouldIncludeResultImage(false); @@ -159,7 +168,7 @@ function App() { parentVersion: number | null ) { // Reset the execution console - setExecutionConsole([]); + resetExecutionConsoles(); // Set the app state setAppState(AppState.CODING); @@ -171,10 +180,19 @@ function App() { wsRef, updatedParams, // On change - (token) => setGeneratedCode((prev) => prev + token), + (token, variant) => { + if (variant === currentVariantIndex) { + setGeneratedCode((prev) => prev + token); + } + + appendToVariant(token, variant); + }, // On set code - (code) => { + (code, variant) => { + setVariant(code, variant); setGeneratedCode(code); + + // TODO: How to deal with variants? if (params.generationType === "create") { setAppHistory([ { @@ -214,7 +232,7 @@ function App() { } }, // On status update - (line) => setExecutionConsole((prev) => [...prev, line]), + (line, variant) => appendExecutionConsole(variant, line), // On cancel () => { cancelCodeGenerationAndReset(); @@ -314,6 +332,7 @@ function App() { } setGeneratedCode(""); + resetVariants(); setUpdateInstruction(""); } diff --git a/frontend/src/components/sidebar/Sidebar.tsx b/frontend/src/components/sidebar/Sidebar.tsx index 5246637..b557456 100644 --- a/frontend/src/components/sidebar/Sidebar.tsx +++ b/frontend/src/components/sidebar/Sidebar.tsx @@ -12,6 +12,7 @@ import { Button } from "../ui/button"; import { Textarea } from "../ui/textarea"; import { useEffect, useRef } from "react"; import HistoryDisplay from "../history/HistoryDisplay"; +import Variants from "../variants/Variants"; interface SidebarProps { showSelectAndEditFeature: boolean; @@ -35,8 +36,16 @@ function Sidebar({ shouldIncludeResultImage, setShouldIncludeResultImage, } = useAppStore(); - const { inputMode, generatedCode, referenceImages, executionConsole } = - useProjectStore(); + + const { + inputMode, + generatedCode, + referenceImages, + executionConsoles, + currentVariantIndex, + } = useProjectStore(); + + const executionConsole = executionConsoles[currentVariantIndex] || []; // When coding is complete, focus on the update instruction textarea useEffect(() => { @@ -47,6 +56,8 @@ function Sidebar({ return ( <> + + {/* Show code preview only when coding */} {appState === AppState.CODING && (
diff --git a/frontend/src/components/variants/Variants.tsx b/frontend/src/components/variants/Variants.tsx new file mode 100644 index 0000000..219841c --- /dev/null +++ b/frontend/src/components/variants/Variants.tsx @@ -0,0 +1,64 @@ +import { useProjectStore } from "../../store/project-store"; + +function Variants() { + const { + // Inputs + referenceImages, + + // Outputs + variants, + currentVariantIndex, + setCurrentVariantIndex, + setGeneratedCode, + appHistory, + setAppHistory, + } = useProjectStore(); + + function switchVariant(index: number) { + const variant = variants[index]; + setCurrentVariantIndex(index); + setGeneratedCode(variant); + if (appHistory.length === 1) { + setAppHistory([ + { + type: "ai_create", + parentIndex: null, + code: variant, + inputs: { image_url: referenceImages[0] }, + }, + ]); + } else { + setAppHistory((prev) => { + const newHistory = [...prev]; + newHistory[newHistory.length - 1].code = variant; + return newHistory; + }); + } + } + + if (variants.length === 0) { + return null; + } + + return ( +
+
+ {variants.map((_, index) => ( +
switchVariant(index)} + > +

Option {index + 1}

+
+ ))} +
+
+ ); +} + +export default Variants; diff --git a/frontend/src/generateCode.ts b/frontend/src/generateCode.ts index 7fc34e7..b373702 100644 --- a/frontend/src/generateCode.ts +++ b/frontend/src/generateCode.ts @@ -11,12 +11,18 @@ const ERROR_MESSAGE = const CANCEL_MESSAGE = "Code generation cancelled"; +type WebSocketResponse = { + type: "chunk" | "status" | "setCode" | "error"; + value: string; + variantIndex: number; +}; + export function generateCode( wsRef: React.MutableRefObject, params: FullGenerationSettings, - onChange: (chunk: string) => void, - onSetCode: (code: string) => void, - onStatusUpdate: (status: string) => void, + onChange: (chunk: string, variantIndex: number) => void, + onSetCode: (code: string, variantIndex: number) => void, + onStatusUpdate: (status: string, variantIndex: number) => void, onCancel: () => void, onComplete: () => void ) { @@ -31,13 +37,13 @@ export function generateCode( }); ws.addEventListener("message", async (event: MessageEvent) => { - const response = JSON.parse(event.data); + const response = JSON.parse(event.data) as WebSocketResponse; if (response.type === "chunk") { - onChange(response.value); + onChange(response.value, response.variantIndex); } else if (response.type === "status") { - onStatusUpdate(response.value); + onStatusUpdate(response.value, response.variantIndex); } else if (response.type === "setCode") { - onSetCode(response.value); + onSetCode(response.value, response.variantIndex); } else if (response.type === "error") { console.error("Error generating code", response.value); toast.error(response.value); diff --git a/frontend/src/store/project-store.ts b/frontend/src/store/project-store.ts index 8a416f0..20516ae 100644 --- a/frontend/src/store/project-store.ts +++ b/frontend/src/store/project-store.ts @@ -16,10 +16,17 @@ interface ProjectStore { setGeneratedCode: ( updater: string | ((currentCode: string) => string) ) => void; - executionConsole: string[]; - setExecutionConsole: ( - updater: string[] | ((currentConsole: string[]) => string[]) - ) => void; + + variants: string[]; + currentVariantIndex: number; + setCurrentVariantIndex: (index: number) => void; + setVariant: (code: string, index: number) => void; + appendToVariant: (newTokens: string, index: number) => void; + resetVariants: () => void; + + executionConsoles: { [key: number]: string[] }; + appendExecutionConsole: (variantIndex: number, line: string) => void; + resetExecutionConsoles: () => void; // Tracks the currently shown version from app history // TODO: might want to move to appStore @@ -48,14 +55,41 @@ export const useProjectStore = create((set) => ({ generatedCode: typeof updater === "function" ? updater(state.generatedCode) : updater, })), - executionConsole: [], - setExecutionConsole: (updater) => + + variants: [], + currentVariantIndex: 0, + + setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }), + setVariant: (code: string, index: number) => + set((state) => { + const newVariants = [...state.variants]; + while (newVariants.length <= index) { + newVariants.push(""); + } + newVariants[index] = code; + return { variants: newVariants }; + }), + appendToVariant: (newTokens: string, index: number) => + set((state) => { + const newVariants = [...state.variants]; + newVariants[index] += newTokens; + return { variants: newVariants }; + }), + resetVariants: () => set({ variants: [], currentVariantIndex: 0 }), + + executionConsoles: {}, + + appendExecutionConsole: (variantIndex: number, line: string) => set((state) => ({ - executionConsole: - typeof updater === "function" - ? updater(state.executionConsole) - : updater, + executionConsoles: { + ...state.executionConsoles, + [variantIndex]: [ + ...(state.executionConsoles[variantIndex] || []), + line, + ], + }, })), + resetExecutionConsoles: () => set({ executionConsoles: {} }), currentVersion: null, setCurrentVersion: (version) => set({ currentVersion: version }),