From d7ab620e0b03436f1e69a35c34641ec05760eca1 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Fri, 26 Jul 2024 11:56:21 -0400 Subject: [PATCH 01/47] deep copy messages to avoid modifying the original list in the Claude LLM call --- backend/llm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/llm.py b/backend/llm.py index 2b71102..450ec2f 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -1,4 +1,5 @@ import base64 +import copy from enum import Enum from typing import Any, Awaitable, Callable, List, cast from anthropic import AsyncAnthropic @@ -92,8 +93,12 @@ async def stream_claude_response( temperature = 0.0 # Translate OpenAI messages to Claude messages - system_prompt = cast(str, messages[0].get("content")) - claude_messages = [dict(message) for message in messages[1:]] + + # Deep copy messages to avoid modifying the original list + cloned_messages = copy.deepcopy(messages) + + system_prompt = cast(str, cloned_messages[0].get("content")) + claude_messages = [dict(message) for message in cloned_messages[1:]] for message in claude_messages: if not isinstance(message["content"], list): continue From aff9352dc04b9a4a77d1bb8f14eac81adc73dc12 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 15:44:48 -0400 Subject: [PATCH 02/47] set up multiple generations --- backend/routes/generate_code.py | 225 ++++++++++++------ frontend/src/App.tsx | 31 ++- frontend/src/components/sidebar/Sidebar.tsx | 15 +- frontend/src/components/variants/Variants.tsx | 64 +++++ frontend/src/generateCode.ts | 20 +- frontend/src/store/project-store.ts | 54 ++++- 6 files changed, 309 insertions(+), 100 deletions(-) create mode 100644 frontend/src/components/variants/Variants.tsx diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 19ef382..3d0c915 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -1,4 +1,5 @@ import os +import asyncio import traceback from fastapi import APIRouter, WebSocket import openai @@ -14,17 +15,16 @@ from llm import ( ) from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Dict, List, Union, cast, get_args +from typing import Any, Coroutine, Dict, List, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from datetime import datetime import json from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack -from utils import pprint_prompt # from utils import pprint_prompt -from video.utils import extract_tag_content, assemble_claude_prompt_video +from video.utils import assemble_claude_prompt_video from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) +# Generate images and return updated completions +async def process_completion( + websocket: WebSocket, + completion: str, + index: int, + should_generate_images: bool, + openai_api_key: str | None, + openai_base_url: str | None, + image_cache: dict[str, str], +): + try: + if should_generate_images and openai_api_key: + await websocket.send_json( + { + "type": "status", + "value": f"Generating images...", + "variantIndex": index, + } + ) + updated_html = await generate_images( + completion, + api_key=openai_api_key, + base_url=openai_base_url, + image_cache=image_cache, 
+ ) + else: + updated_html = completion + + await websocket.send_json( + {"type": "setCode", "value": updated_html, "variantIndex": index} + ) + await websocket.send_json( + { + "type": "status", + "value": f"Code generation complete.", + "variantIndex": index, + } + ) + except Exception as e: + traceback.print_exc() + print(f"Image generation failed for variant {index}", e) + await websocket.send_json( + {"type": "setCode", "value": completion, "variantIndex": index} + ) + await websocket.send_json( + { + "type": "status", + "value": f"Image generation failed but code is complete.", + "variantIndex": index, + } + ) + + @router.websocket("/generate-code") async def stream_code(websocket: WebSocket): await websocket.accept() @@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket): print("Received params") - # Read the code config settings from the request. Fall back to default if not provided. + # Read the code config settings (stack) from the request. Fall back to default if not provided. generated_code_config = "" if "generatedCodeConfig" in params and params["generatedCodeConfig"]: generated_code_config = params["generatedCodeConfig"] @@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket): ) code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20 - exact_llm_version = None - print( f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." ) @@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket): print("Using official OpenAI URL") # Get the image generation flag from the request. Fall back to True if not provided. - should_generate_images = ( - params["isImageGenerationEnabled"] - if "isImageGenerationEnabled" in params - else True - ) + should_generate_images = bool(params.get("isImageGenerationEnabled", True)) print("generating code...") - await websocket.send_json({"type": "status", "value": "Generating code..."}) + await websocket.send_json( + {"type": "status", "value": "Generating code...", "variantIndex": 0} + ) + await websocket.send_json( + {"type": "status", "value": "Generating code...", "variantIndex": 1} + ) - async def process_chunk(content: str): - await websocket.send_json({"type": "chunk", "value": content}) + async def process_chunk(content: str, variantIndex: int = 0): + await websocket.send_json( + {"type": "chunk", "value": content, "variantIndex": variantIndex} + ) # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} @@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket): # pprint_prompt(prompt_messages) # type: ignore if SHOULD_MOCK_AI_RESPONSE: - completion = await mock_completion( - process_chunk, input_mode=validated_input_mode - ) + completions = [ + await mock_completion(process_chunk, input_mode=validated_input_mode) + ] else: try: if validated_input_mode == "video": @@ -251,41 +305,66 @@ async def stream_code(websocket: WebSocket): ) raise Exception("No Anthropic key") - completion = await stream_claude_response_native( - system_prompt=VIDEO_PROMPT, - messages=prompt_messages, # type: ignore - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - model=Llm.CLAUDE_3_OPUS, - include_thinking=True, - ) - exact_llm_version = Llm.CLAUDE_3_OPUS - elif ( - code_generation_model == Llm.CLAUDE_3_SONNET - or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20 - ): - if not anthropic_api_key: - await throw_error( - "No Anthropic API key found. 
Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog" + completions = [ + await stream_claude_response_native( + system_prompt=VIDEO_PROMPT, + messages=prompt_messages, # type: ignore + api_key=anthropic_api_key, + callback=lambda x: process_chunk(x), + model=Llm.CLAUDE_3_OPUS, + include_thinking=True, ) - raise Exception("No Anthropic key") - - completion = await stream_claude_response( - prompt_messages, # type: ignore - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - model=code_generation_model, - ) - exact_llm_version = code_generation_model + ] else: - completion = await stream_openai_response( - prompt_messages, # type: ignore - api_key=openai_api_key, - base_url=openai_base_url, - callback=lambda x: process_chunk(x), - model=code_generation_model, - ) - exact_llm_version = code_generation_model + + # Depending on the presence and absence of various keys, + # we decide which models to run + variant_models = [] + if openai_api_key and anthropic_api_key: + variant_models = ["openai", "anthropic"] + elif openai_api_key: + variant_models = ["openai", "openai"] + elif anthropic_api_key: + variant_models = ["anthropic", "anthropic"] + else: + await throw_error( + "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog" + ) + raise Exception("No OpenAI or Anthropic key") + + tasks: List[Coroutine[Any, Any, str]] = [] + for index, model in enumerate(variant_models): + if model == "openai": + if openai_api_key is None: + await throw_error("OpenAI API key is missing.") + raise Exception("OpenAI API key is missing.") + + tasks.append( + stream_openai_response( + prompt_messages, + api_key=openai_api_key, + base_url=openai_base_url, + callback=lambda x, i=index: process_chunk(x, i), + model=Llm.GPT_4O_2024_05_13, + ) + ) + elif model == "anthropic": + if anthropic_api_key is None: + await throw_error("Anthropic API key is missing.") + raise Exception("Anthropic API key is missing.") + + tasks.append( + stream_claude_response( + prompt_messages, + api_key=anthropic_api_key, + callback=lambda x, i=index: process_chunk(x, i), + model=Llm.CLAUDE_3_5_SONNET_2024_06_20, + ) + ) + + completions = await asyncio.gather(*tasks) + print("Models used for generation: ", variant_models) + except openai.AuthenticationError as e: print("[GENERATE_CODE] Authentication failed", e) error_message = ( @@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket): ) return await throw_error(error_message) - if validated_input_mode == "video": - completion = extract_tag_content("html", completion) - - print("Exact used model for generation: ", exact_llm_version) + # if validated_input_mode == "video": + # completion = extract_tag_content("html", completions[0]) # Strip the completion of everything except the HTML content - completion = extract_html_content(completion) + completions = [extract_html_content(completion) for completion in completions] # Write the messages dict into a log so that we can debug later - write_logs(prompt_messages, completion) # type: ignore + write_logs(prompt_messages, completions[0]) # type: ignore try: - if should_generate_images: - await websocket.send_json( - {"type": "status", "value": "Generating images..."} - ) - updated_html = await generate_images( + image_generation_tasks = [ + process_completion( + websocket, completion, - api_key=openai_api_key, - base_url=openai_base_url, - image_cache=image_cache, + index, + 
should_generate_images, + openai_api_key, + openai_base_url, + image_cache, ) - else: - updated_html = completion - await websocket.send_json({"type": "setCode", "value": updated_html}) - await websocket.send_json( - {"type": "status", "value": "Code generation complete."} - ) + for index, completion in enumerate(completions) + ] + await asyncio.gather(*image_generation_tasks) except Exception as e: traceback.print_exc() - print("Image generation failed", e) - # Send set code even if image generation fails since that triggers - # the frontend to update history - await websocket.send_json({"type": "setCode", "value": completion}) + print("An error occurred during image generation and processing", e) await websocket.send_json( - {"type": "status", "value": "Image generation failed but code is complete."} + { + "type": "status", + "value": "An error occurred during image generation and processing.", + "variantIndex": 0, + } ) await websocket.close() diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 9c20df3..e88fa28 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -36,7 +36,12 @@ function App() { // Outputs setGeneratedCode, - setExecutionConsole, + currentVariantIndex, + setVariant, + appendToVariant, + resetVariants, + appendExecutionConsole, + resetExecutionConsoles, currentVersion, setCurrentVersion, appHistory, @@ -106,10 +111,14 @@ function App() { const reset = () => { setAppState(AppState.INITIAL); setGeneratedCode(""); + resetVariants(); + resetExecutionConsoles(); + + // Inputs setReferenceImages([]); - setExecutionConsole([]); setUpdateInstruction(""); setIsImportedFromCode(false); + setAppHistory([]); setCurrentVersion(null); setShouldIncludeResultImage(false); @@ -159,7 +168,7 @@ function App() { parentVersion: number | null ) { // Reset the execution console - setExecutionConsole([]); + resetExecutionConsoles(); // Set the app state setAppState(AppState.CODING); @@ -171,10 +180,19 @@ function App() { wsRef, updatedParams, // On change - (token) => setGeneratedCode((prev) => prev + token), + (token, variant) => { + if (variant === currentVariantIndex) { + setGeneratedCode((prev) => prev + token); + } + + appendToVariant(token, variant); + }, // On set code - (code) => { + (code, variant) => { + setVariant(code, variant); setGeneratedCode(code); + + // TODO: How to deal with variants? 
if (params.generationType === "create") { setAppHistory([ { @@ -214,7 +232,7 @@ function App() { } }, // On status update - (line) => setExecutionConsole((prev) => [...prev, line]), + (line, variant) => appendExecutionConsole(variant, line), // On cancel () => { cancelCodeGenerationAndReset(); @@ -314,6 +332,7 @@ function App() { } setGeneratedCode(""); + resetVariants(); setUpdateInstruction(""); } diff --git a/frontend/src/components/sidebar/Sidebar.tsx b/frontend/src/components/sidebar/Sidebar.tsx index 5246637..b557456 100644 --- a/frontend/src/components/sidebar/Sidebar.tsx +++ b/frontend/src/components/sidebar/Sidebar.tsx @@ -12,6 +12,7 @@ import { Button } from "../ui/button"; import { Textarea } from "../ui/textarea"; import { useEffect, useRef } from "react"; import HistoryDisplay from "../history/HistoryDisplay"; +import Variants from "../variants/Variants"; interface SidebarProps { showSelectAndEditFeature: boolean; @@ -35,8 +36,16 @@ function Sidebar({ shouldIncludeResultImage, setShouldIncludeResultImage, } = useAppStore(); - const { inputMode, generatedCode, referenceImages, executionConsole } = - useProjectStore(); + + const { + inputMode, + generatedCode, + referenceImages, + executionConsoles, + currentVariantIndex, + } = useProjectStore(); + + const executionConsole = executionConsoles[currentVariantIndex] || []; // When coding is complete, focus on the update instruction textarea useEffect(() => { @@ -47,6 +56,8 @@ function Sidebar({ return ( <> + + {/* Show code preview only when coding */} {appState === AppState.CODING && (
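The new Variants component added below consumes the variant state that this patch introduces in the project store. As a rough sketch of that contract (illustrative only, not part of the diff; it assumes the variants, currentVariantIndex, setCurrentVariantIndex, and setGeneratedCode members shown in the project-store changes later in this patch):

    // Sketch: the minimal way a consumer could switch the active variant.
    import { useProjectStore } from "../../store/project-store";

    function VariantPickerSketch() {
      const { variants, currentVariantIndex, setCurrentVariantIndex, setGeneratedCode } =
        useProjectStore();

      return (
        <>
          {variants.map((_, index) => (
            <button
              key={index}
              disabled={index === currentVariantIndex}
              onClick={() => {
                setCurrentVariantIndex(index);
                // Keep the visible editor in sync with the chosen variant.
                setGeneratedCode(variants[index]);
              }}
            >
              Option {index + 1}
            </button>
          ))}
        </>
      );
    }

The real component below additionally rewrites the latest history entry so that switching variants keeps the edit history pointing at the selected code.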
diff --git a/frontend/src/components/variants/Variants.tsx b/frontend/src/components/variants/Variants.tsx new file mode 100644 index 0000000..219841c --- /dev/null +++ b/frontend/src/components/variants/Variants.tsx @@ -0,0 +1,64 @@ +import { useProjectStore } from "../../store/project-store"; + +function Variants() { + const { + // Inputs + referenceImages, + + // Outputs + variants, + currentVariantIndex, + setCurrentVariantIndex, + setGeneratedCode, + appHistory, + setAppHistory, + } = useProjectStore(); + + function switchVariant(index: number) { + const variant = variants[index]; + setCurrentVariantIndex(index); + setGeneratedCode(variant); + if (appHistory.length === 1) { + setAppHistory([ + { + type: "ai_create", + parentIndex: null, + code: variant, + inputs: { image_url: referenceImages[0] }, + }, + ]); + } else { + setAppHistory((prev) => { + const newHistory = [...prev]; + newHistory[newHistory.length - 1].code = variant; + return newHistory; + }); + } + } + + if (variants.length === 0) { + return null; + } + + return ( +
+    <div className="flex justify-center mt-2">
+      <div className="grid grid-cols-2 gap-2 w-full">
+        {variants.map((_, index) => (
+          <div
+            key={index}
+            className={`rounded-md border p-2 text-center cursor-pointer ${
+              index === currentVariantIndex
+                ? "border-blue-500"
+                : "border-gray-200"
+            }`}
+            onClick={() => switchVariant(index)}
+          >
+            <span className="text-sm">Option {index + 1}</span>
+          </div>
+        ))}
+      </div>
+    </div>
+ ); +} + +export default Variants; diff --git a/frontend/src/generateCode.ts b/frontend/src/generateCode.ts index 7fc34e7..b373702 100644 --- a/frontend/src/generateCode.ts +++ b/frontend/src/generateCode.ts @@ -11,12 +11,18 @@ const ERROR_MESSAGE = const CANCEL_MESSAGE = "Code generation cancelled"; +type WebSocketResponse = { + type: "chunk" | "status" | "setCode" | "error"; + value: string; + variantIndex: number; +}; + export function generateCode( wsRef: React.MutableRefObject, params: FullGenerationSettings, - onChange: (chunk: string) => void, - onSetCode: (code: string) => void, - onStatusUpdate: (status: string) => void, + onChange: (chunk: string, variantIndex: number) => void, + onSetCode: (code: string, variantIndex: number) => void, + onStatusUpdate: (status: string, variantIndex: number) => void, onCancel: () => void, onComplete: () => void ) { @@ -31,13 +37,13 @@ export function generateCode( }); ws.addEventListener("message", async (event: MessageEvent) => { - const response = JSON.parse(event.data); + const response = JSON.parse(event.data) as WebSocketResponse; if (response.type === "chunk") { - onChange(response.value); + onChange(response.value, response.variantIndex); } else if (response.type === "status") { - onStatusUpdate(response.value); + onStatusUpdate(response.value, response.variantIndex); } else if (response.type === "setCode") { - onSetCode(response.value); + onSetCode(response.value, response.variantIndex); } else if (response.type === "error") { console.error("Error generating code", response.value); toast.error(response.value); diff --git a/frontend/src/store/project-store.ts b/frontend/src/store/project-store.ts index 8a416f0..20516ae 100644 --- a/frontend/src/store/project-store.ts +++ b/frontend/src/store/project-store.ts @@ -16,10 +16,17 @@ interface ProjectStore { setGeneratedCode: ( updater: string | ((currentCode: string) => string) ) => void; - executionConsole: string[]; - setExecutionConsole: ( - updater: string[] | ((currentConsole: string[]) => string[]) - ) => void; + + variants: string[]; + currentVariantIndex: number; + setCurrentVariantIndex: (index: number) => void; + setVariant: (code: string, index: number) => void; + appendToVariant: (newTokens: string, index: number) => void; + resetVariants: () => void; + + executionConsoles: { [key: number]: string[] }; + appendExecutionConsole: (variantIndex: number, line: string) => void; + resetExecutionConsoles: () => void; // Tracks the currently shown version from app history // TODO: might want to move to appStore @@ -48,14 +55,41 @@ export const useProjectStore = create((set) => ({ generatedCode: typeof updater === "function" ? 
updater(state.generatedCode) : updater, })), - executionConsole: [], - setExecutionConsole: (updater) => + + variants: [], + currentVariantIndex: 0, + + setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }), + setVariant: (code: string, index: number) => + set((state) => { + const newVariants = [...state.variants]; + while (newVariants.length <= index) { + newVariants.push(""); + } + newVariants[index] = code; + return { variants: newVariants }; + }), + appendToVariant: (newTokens: string, index: number) => + set((state) => { + const newVariants = [...state.variants]; + newVariants[index] += newTokens; + return { variants: newVariants }; + }), + resetVariants: () => set({ variants: [], currentVariantIndex: 0 }), + + executionConsoles: {}, + + appendExecutionConsole: (variantIndex: number, line: string) => set((state) => ({ - executionConsole: - typeof updater === "function" - ? updater(state.executionConsole) - : updater, + executionConsoles: { + ...state.executionConsoles, + [variantIndex]: [ + ...(state.executionConsoles[variantIndex] || []), + line, + ], + }, })), + resetExecutionConsoles: () => set({ executionConsoles: {} }), currentVersion: null, setCurrentVersion: (version) => set({ currentVersion: version }), From f52ca306a57a76a9daf00fd7077765944367e394 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 16:09:02 -0400 Subject: [PATCH 03/47] fix bug with handling setCode --- frontend/src/App.tsx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e88fa28..789b905 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -189,8 +189,11 @@ function App() { }, // On set code (code, variant) => { + if (variant === currentVariantIndex) { + setGeneratedCode(code); + } + setVariant(code, variant); - setGeneratedCode(code); // TODO: How to deal with variants? if (params.generationType === "create") { From 46c480931a6402bfb9b30e14af7582ed06217b34 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 16:23:53 -0400 Subject: [PATCH 04/47] make execution console show logs from both variants --- frontend/src/components/sidebar/Sidebar.tsx | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/frontend/src/components/sidebar/Sidebar.tsx b/frontend/src/components/sidebar/Sidebar.tsx index b557456..288b3f8 100644 --- a/frontend/src/components/sidebar/Sidebar.tsx +++ b/frontend/src/components/sidebar/Sidebar.tsx @@ -167,14 +167,21 @@ function Sidebar({
       )}
-      <div className="flex flex-col">
+      <div className="flex flex-col gap-y-1">
         <div className="font-bold">Console</div>
-        {executionConsole.map((line, index) => (
-          <div key={index}>
-            {line}
-          </div>
-        ))}
+        {Object.entries(executionConsoles).map(([index, lines]) => (
+          <div key={index}>
+            {lines.map((line, lineIndex) => (
+              <div key={lineIndex}>
+                <span className="font-bold">{`${index}:${
+                  lineIndex + 1
+                }`}</span>{" "}
+                {line}
+              </div>
+            ))}
+          </div>
+        ))}
       </div>
From 0700de7767790aa53f621b9a6fff1d99f3a8f554 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 16:27:04 -0400 Subject: [PATCH 05/47] standardize to using typed send_message --- backend/mock_llm.py | 4 +- backend/routes/generate_code.py | 75 +++++++++++++-------------------- 2 files changed, 31 insertions(+), 48 deletions(-) diff --git a/backend/mock_llm.py b/backend/mock_llm.py index b85b1b1..a76b906 100644 --- a/backend/mock_llm.py +++ b/backend/mock_llm.py @@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20 async def mock_completion( - process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode + process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode ) -> str: code_to_return = ( TALLY_FORM_VIDEO_PROMPT_MOCK @@ -17,7 +17,7 @@ async def mock_completion( ) for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE): - await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE]) + await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0) await asyncio.sleep(0.01) if input_mode == "video": diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 3d0c915..ed02b0e 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -15,7 +15,7 @@ from llm import ( ) from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Any, Coroutine, Dict, List, Union, cast, get_args +from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from datetime import datetime @@ -52,23 +52,20 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st # Generate images and return updated completions async def process_completion( - websocket: WebSocket, completion: str, index: int, should_generate_images: bool, openai_api_key: str | None, openai_base_url: str | None, image_cache: dict[str, str], + send_message: Callable[ + [Literal["chunk", "status", "setCode", "error"], str, int], + Coroutine[Any, Any, None], + ], ): try: if should_generate_images and openai_api_key: - await websocket.send_json( - { - "type": "status", - "value": f"Generating images...", - "variantIndex": index, - } - ) + await send_message("status", "Generating images...", index) updated_html = await generate_images( completion, api_key=openai_api_key, @@ -78,28 +75,14 @@ async def process_completion( else: updated_html = completion - await websocket.send_json( - {"type": "setCode", "value": updated_html, "variantIndex": index} - ) - await websocket.send_json( - { - "type": "status", - "value": f"Code generation complete.", - "variantIndex": index, - } - ) + await send_message("setCode", updated_html, index) + await send_message("status", "Code generation complete.", index) except Exception as e: traceback.print_exc() print(f"Image generation failed for variant {index}", e) - await websocket.send_json( - {"type": "setCode", "value": completion, "variantIndex": index} - ) - await websocket.send_json( - { - "type": "status", - "value": f"Image generation failed but code is complete.", - "variantIndex": index, - } + await send_message("setCode", completion, index) + await send_message( + "status", "Image generation failed but code is complete.", index ) @@ -115,6 +98,15 @@ async def stream_code(websocket: WebSocket): await websocket.send_json({"type": "error", "value": message}) await 
websocket.close(APP_ERROR_WEB_SOCKET_CODE) + async def send_message( + type: Literal["chunk", "status", "setCode", "error"], + value: str, + variantIndex: int, + ): + await websocket.send_json( + {"type": type, "value": value, "variantIndex": variantIndex} + ) + # TODO: Are the values always strings? params: Dict[str, str] = await websocket.receive_json() @@ -216,17 +208,11 @@ async def stream_code(websocket: WebSocket): should_generate_images = bool(params.get("isImageGenerationEnabled", True)) print("generating code...") - await websocket.send_json( - {"type": "status", "value": "Generating code...", "variantIndex": 0} - ) - await websocket.send_json( - {"type": "status", "value": "Generating code...", "variantIndex": 1} - ) + await send_message("status", "Generating code...", 0) + await send_message("status", "Generating code...", 1) - async def process_chunk(content: str, variantIndex: int = 0): - await websocket.send_json( - {"type": "chunk", "value": content, "variantIndex": variantIndex} - ) + async def process_chunk(content: str, variantIndex: int): + await send_message("chunk", content, variantIndex) # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} @@ -259,6 +245,7 @@ async def stream_code(websocket: WebSocket): else: prompt_messages = assemble_prompt(params["image"], valid_stack) except: + # TODO: This should use variantIndex await websocket.send_json( { "type": "error", @@ -310,7 +297,7 @@ async def stream_code(websocket: WebSocket): system_prompt=VIDEO_PROMPT, messages=prompt_messages, # type: ignore api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), + callback=lambda x: process_chunk(x, 0), model=Llm.CLAUDE_3_OPUS, include_thinking=True, ) @@ -412,13 +399,13 @@ async def stream_code(websocket: WebSocket): try: image_generation_tasks = [ process_completion( - websocket, completion, index, should_generate_images, openai_api_key, openai_base_url, image_cache, + send_message, ) for index, completion in enumerate(completions) ] @@ -426,12 +413,8 @@ async def stream_code(websocket: WebSocket): except Exception as e: traceback.print_exc() print("An error occurred during image generation and processing", e) - await websocket.send_json( - { - "type": "status", - "value": "An error occurred during image generation and processing.", - "variantIndex": 0, - } + await send_message( + "status", "An error occurred during image generation and processing.", 0 ) await websocket.close() From 64926408b0d193ce6653d2c0045b97297491862b Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 16:29:06 -0400 Subject: [PATCH 06/47] refactor --- backend/logging/__init__.py | 0 backend/logging/core.py | 23 +++++++++++++++++++++++ backend/routes/generate_code.py | 24 ++---------------------- 3 files changed, 25 insertions(+), 22 deletions(-) create mode 100644 backend/logging/__init__.py create mode 100644 backend/logging/core.py diff --git a/backend/logging/__init__.py b/backend/logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/logging/core.py b/backend/logging/core.py new file mode 100644 index 0000000..e89096f --- /dev/null +++ b/backend/logging/core.py @@ -0,0 +1,23 @@ +from datetime import datetime +import json +import os +from openai.types.chat import ChatCompletionMessageParam + + +def write_logs(prompt_messages: list[ChatCompletionMessageParam], completion: str): + # Get the logs path from environment, default to the current working directory + logs_path = os.environ.get("LOGS_PATH", 
os.getcwd()) + + # Create run_logs directory if it doesn't exist within the specified logs path + logs_directory = os.path.join(logs_path, "run_logs") + if not os.path.exists(logs_directory): + os.makedirs(logs_directory) + + print("Writing to logs directory:", logs_directory) + + # Generate a unique filename using the current timestamp within the logs directory + filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") + + # Write the messages dict into a new file for each run + with open(filename, "w") as f: + f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index ed02b0e..59bd42d 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -14,12 +14,11 @@ from llm import ( stream_openai_response, ) from openai.types.chat import ChatCompletionMessageParam +from logging.core import write_logs from mock_llm import mock_completion from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt -from datetime import datetime -import json from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack @@ -31,25 +30,6 @@ from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore router = APIRouter() -def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str): - # Get the logs path from environment, default to the current working directory - logs_path = os.environ.get("LOGS_PATH", os.getcwd()) - - # Create run_logs directory if it doesn't exist within the specified logs path - logs_directory = os.path.join(logs_path, "run_logs") - if not os.path.exists(logs_directory): - os.makedirs(logs_directory) - - print("Writing to logs directory:", logs_directory) - - # Generate a unique filename using the current timestamp within the logs directory - filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") - - # Write the messages dict into a new file for each run - with open(filename, "w") as f: - f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) - - # Generate images and return updated completions async def process_completion( completion: str, @@ -394,7 +374,7 @@ async def stream_code(websocket: WebSocket): completions = [extract_html_content(completion) for completion in completions] # Write the messages dict into a log so that we can debug later - write_logs(prompt_messages, completions[0]) # type: ignore + write_logs(prompt_messages, completions[0]) try: image_generation_tasks = [ From 24a123db3608ebd565a288f9114f8d4ad5a64dfd Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 16:37:36 -0400 Subject: [PATCH 07/47] refactors --- backend/{logging => fs_logging}/__init__.py | 0 backend/{logging => fs_logging}/core.py | 0 backend/routes/generate_code.py | 28 +++++++++++++-------- 3 files changed, 17 insertions(+), 11 deletions(-) rename backend/{logging => fs_logging}/__init__.py (100%) rename backend/{logging => fs_logging}/core.py (100%) diff --git a/backend/logging/__init__.py b/backend/fs_logging/__init__.py similarity index 100% rename from backend/logging/__init__.py rename to backend/fs_logging/__init__.py diff --git a/backend/logging/core.py b/backend/fs_logging/core.py similarity index 100% rename from backend/logging/core.py rename to backend/fs_logging/core.py diff --git 
a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 59bd42d..6d4f0e4 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -14,7 +14,7 @@ from llm import ( stream_openai_response, ) from openai.types.chat import ChatCompletionMessageParam -from logging.core import write_logs +from fs_logging.core import write_logs from mock_llm import mock_completion from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images @@ -30,6 +30,21 @@ from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore router = APIRouter() +# Auto-upgrade usage of older models +def auto_upgrade_model(code_generation_model: Llm) -> Llm: + if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}: + print( + f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13" + ) + return Llm.GPT_4O_2024_05_13 + elif code_generation_model == Llm.CLAUDE_3_SONNET: + print( + f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20" + ) + return Llm.CLAUDE_3_5_SONNET_2024_06_20 + return code_generation_model + + # Generate images and return updated completions async def process_completion( completion: str, @@ -121,16 +136,7 @@ async def stream_code(websocket: WebSocket): raise Exception(f"Invalid model: {code_generation_model_str}") # Auto-upgrade usage of older models - if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}: - print( - f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13" - ) - code_generation_model = Llm.GPT_4O_2024_05_13 - elif code_generation_model == Llm.CLAUDE_3_SONNET: - print( - f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20" - ) - code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20 + code_generation_model = auto_upgrade_model(code_generation_model) print( f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." From 96658819f359994f37a6429fb915f353b674c319 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 09:53:21 -0400 Subject: [PATCH 08/47] fix issue with loading variants --- frontend/src/store/project-store.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/src/store/project-store.ts b/frontend/src/store/project-store.ts index 20516ae..d35ba14 100644 --- a/frontend/src/store/project-store.ts +++ b/frontend/src/store/project-store.ts @@ -72,6 +72,9 @@ export const useProjectStore = create((set) => ({ appendToVariant: (newTokens: string, index: number) => set((state) => { const newVariants = [...state.variants]; + while (newVariants.length <= index) { + newVariants.push(""); + } newVariants[index] += newTokens; return { variants: newVariants }; }), From 64dd7d62792e115bcae33bf5245f55e2baa2800f Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 10:04:04 -0400 Subject: [PATCH 09/47] refactor --- frontend/src/store/project-store.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/frontend/src/store/project-store.ts b/frontend/src/store/project-store.ts index d35ba14..6bb2cdb 100644 --- a/frontend/src/store/project-store.ts +++ b/frontend/src/store/project-store.ts @@ -56,8 +56,9 @@ export const useProjectStore = create((set) => ({ typeof updater === "function" ? 
updater(state.generatedCode) : updater, })), - variants: [], currentVariantIndex: 0, + variants: [], + executionConsoles: {}, setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }), setVariant: (code: string, index: number) => @@ -80,8 +81,6 @@ export const useProjectStore = create((set) => ({ }), resetVariants: () => set({ variants: [], currentVariantIndex: 0 }), - executionConsoles: {}, - appendExecutionConsole: (variantIndex: number, line: string) => set((state) => ({ executionConsoles: { From 0f731598ddf69517c9d50b67016b692350decbb4 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 10:10:13 -0400 Subject: [PATCH 10/47] refactor to get .get() --- backend/routes/generate_code.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 6d4f0e4..055853f 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -107,13 +107,11 @@ async def stream_code(websocket: WebSocket): print("Received params") - # Read the code config settings (stack) from the request. Fall back to default if not provided. - generated_code_config = "" - if "generatedCodeConfig" in params and params["generatedCodeConfig"]: - generated_code_config = params["generatedCodeConfig"] + # Read the code config settings (stack) from the request. + generated_code_config = params.get("generatedCodeConfig", "") if not generated_code_config in get_args(Stack): await throw_error(f"Invalid generated code config: {generated_code_config}") - return + raise Exception(f"Invalid generated code config: {generated_code_config}") # Cast the variable to the Stack type valid_stack = cast(Stack, generated_code_config) From ff12790883e91954ed829a83982084050ab06201 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 10:14:23 -0400 Subject: [PATCH 11/47] refactor --- backend/llm.py | 8 ++++++++ backend/routes/generate_code.py | 20 +++++++------------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/backend/llm.py b/backend/llm.py index 450ec2f..d828b6d 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -23,6 +23,14 @@ class Llm(Enum): CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620" +def is_openai_model(model: Llm) -> bool: + return model in { + Llm.GPT_4_VISION, + Llm.GPT_4_TURBO_2024_04_09, + Llm.GPT_4O_2024_05_13, + } + + # Will throw errors if you send a garbage string def convert_frontend_str_to_llm(frontend_str: str) -> Llm: if frontend_str == "gpt_4_vision": diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 055853f..04f5fa8 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -4,11 +4,12 @@ import traceback from fastapi import APIRouter, WebSocket import openai from codegen.utils import extract_html_content -from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE +from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE from custom_types import InputMode from llm import ( Llm, convert_frontend_str_to_llm, + is_openai_model, stream_claude_response, stream_claude_response_native, stream_openai_response, @@ -90,6 +91,7 @@ async def stream_code(websocket: WebSocket): async def throw_error( message: str, ): + print(message) await websocket.send_json({"type": "error", "value": message}) await websocket.close(APP_ERROR_WEB_SOCKET_CODE) @@ -104,7 +106,6 @@ async def stream_code(websocket: WebSocket): # TODO: Are the values always strings? 
params: Dict[str, str] = await websocket.receive_json() - print("Received params") # Read the code config settings (stack) from the request. @@ -135,28 +136,21 @@ async def stream_code(websocket: WebSocket): # Auto-upgrade usage of older models code_generation_model = auto_upgrade_model(code_generation_model) - print( f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." ) # Get the OpenAI API key from the request. Fall back to environment variable if not provided. # If neither is provided, we throw an error. - openai_api_key = None - if params["openAiApiKey"]: - openai_api_key = params["openAiApiKey"] + openai_api_key = params.get("openAiApiKey") + if openai_api_key: print("Using OpenAI API key from client-side settings dialog") else: - openai_api_key = os.environ.get("OPENAI_API_KEY") + openai_api_key = OPENAI_API_KEY if openai_api_key: print("Using OpenAI API key from environment variable") - if not openai_api_key and ( - code_generation_model == Llm.GPT_4_VISION - or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 - or code_generation_model == Llm.GPT_4O_2024_05_13 - ): - print("OpenAI API key not found") + if not openai_api_key and is_openai_model(code_generation_model): await throw_error( "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." ) From 3fbc0f94589ed4d33736f7ca9c0b934155b4b1af Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 10:15:07 -0400 Subject: [PATCH 12/47] refactor --- backend/routes/generate_code.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 04f5fa8..76b1fbe 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -158,9 +158,8 @@ async def stream_code(websocket: WebSocket): # Get the Anthropic API key from the request. Fall back to environment variable if not provided. # If neither is provided, we throw an error later only if Claude is used. 
- anthropic_api_key = None - if "anthropicApiKey" in params and params["anthropicApiKey"]: - anthropic_api_key = params["anthropicApiKey"] + anthropic_api_key = params.get("anthropicApiKey") + if anthropic_api_key: print("Using Anthropic API key from client-side settings dialog") else: anthropic_api_key = ANTHROPIC_API_KEY From dd7a51dd3468e423c625f07c79ba90bacfae4183 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 10:16:03 -0400 Subject: [PATCH 13/47] refactor --- backend/routes/generate_code.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 76b1fbe..d436e83 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -170,8 +170,8 @@ async def stream_code(websocket: WebSocket): openai_base_url: Union[str, None] = None # Disable user-specified OpenAI Base URL in prod if not os.environ.get("IS_PROD"): - if "openAiBaseURL" in params and params["openAiBaseURL"]: - openai_base_url = params["openAiBaseURL"] + openai_base_url = params.get("openAiBaseURL") + if openai_base_url: print("Using OpenAI Base URL from client-side settings dialog") else: openai_base_url = os.environ.get("OPENAI_BASE_URL") From 3591588e2b7bf62e47cf4f5047e707ce137c9129 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 11:07:30 -0400 Subject: [PATCH 14/47] abstract out prompt assembly into a separate function --- backend/routes/generate_code.py | 136 ++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 60 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index d436e83..ced9655 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -23,7 +23,7 @@ from prompts import assemble_imported_code_prompt, assemble_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack -# from utils import pprint_prompt +from utils import pprint_prompt from video.utils import assemble_claude_prompt_video from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -31,6 +31,64 @@ from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore router = APIRouter() +async def create_prompt( + params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode +) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]: + + image_cache: Dict[str, str] = {} + + # If this generation started off with imported code, we need to assemble the prompt differently + if params.get("isImportedFromCode"): + original_imported_code = params["history"][0] + prompt_messages = assemble_imported_code_prompt( + original_imported_code, stack, model + ) + for index, text in enumerate(params["history"][1:]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + prompt_messages.append(message) + else: + # Assemble the prompt for non-imported code + if params.get("resultImage"): + prompt_messages = assemble_prompt( + params["image"], stack, params["resultImage"] + ) + else: + prompt_messages = assemble_prompt(params["image"], stack) + + if params["generationType"] == "update": + # Transform the history tree into message format + # TODO: Move this to frontend + for index, text in enumerate(params["history"]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: 
ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) + + image_cache = create_alt_url_mapping(params["history"][-2]) + + if input_mode == "video": + video_data_url = params["image"] + prompt_messages = await assemble_claude_prompt_video(video_data_url) + + return prompt_messages, image_cache + + # Auto-upgrade usage of older models def auto_upgrade_model(code_generation_model: Llm) -> Llm: if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}: @@ -185,76 +243,34 @@ async def stream_code(websocket: WebSocket): should_generate_images = bool(params.get("isImageGenerationEnabled", True)) print("generating code...") + + # TODO(*): Print with send_message instead of print statements await send_message("status", "Generating code...", 0) await send_message("status", "Generating code...", 1) + # TODO(*): Move down async def process_chunk(content: str, variantIndex: int): await send_message("chunk", content, variantIndex) # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} - # If this generation started off with imported code, we need to assemble the prompt differently - if params.get("isImportedFromCode") and params["isImportedFromCode"]: - original_imported_code = params["history"][0] - prompt_messages = assemble_imported_code_prompt( - original_imported_code, valid_stack, code_generation_model + try: + prompt_messages, image_cache = await create_prompt( + params, valid_stack, code_generation_model, validated_input_mode ) - for index, text in enumerate(params["history"][1:]): - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - prompt_messages.append(message) - else: - # Assemble the prompt - try: - if params.get("resultImage") and params["resultImage"]: - prompt_messages = assemble_prompt( - params["image"], valid_stack, params["resultImage"] - ) - else: - prompt_messages = assemble_prompt(params["image"], valid_stack) - except: - # TODO: This should use variantIndex - await websocket.send_json( - { - "type": "error", - "value": "Error assembling prompt. Contact support at support@picoapps.xyz", - } - ) - await websocket.close() - return + except: + # TODO(*): This should use variantIndex + await websocket.send_json( + { + "type": "error", + "value": "Error assembling prompt. 
Contact support at support@picoapps.xyz", + } + ) + await websocket.close() + raise - if params["generationType"] == "update": - # Transform the history tree into message format - # TODO: Move this to frontend - for index, text in enumerate(params["history"]): - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - prompt_messages.append(message) - - image_cache = create_alt_url_mapping(params["history"][-2]) - - if validated_input_mode == "video": - video_data_url = params["image"] - prompt_messages = await assemble_claude_prompt_video(video_data_url) - - # pprint_prompt(prompt_messages) # type: ignore + pprint_prompt(prompt_messages) # type: ignore if SHOULD_MOCK_AI_RESPONSE: completions = [ From bcb89a3c2315cbcab18864a056c6c46d30d26fe0 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 11:15:56 -0400 Subject: [PATCH 15/47] refactor --- backend/prompts/__init__.py | 69 +++++++++++++++++++++++++++++++-- backend/routes/generate_code.py | 65 +------------------------------ 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index dc96ab9..de884c5 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,11 +1,14 @@ -from typing import List, NoReturn, Union +from typing import Union from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam +from custom_types import InputMode +from image_generation import create_alt_url_mapping from llm import Llm from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS from prompts.types import Stack +from video.utils import assemble_claude_prompt_video USER_PROMPT = """ @@ -17,9 +20,67 @@ Generate code for a SVG that looks exactly like this. 
""" +async def create_prompt( + params: dict[str, str], stack: Stack, model: Llm, input_mode: InputMode +) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]: + + image_cache: dict[str, str] = {} + + # If this generation started off with imported code, we need to assemble the prompt differently + if params.get("isImportedFromCode"): + original_imported_code = params["history"][0] + prompt_messages = assemble_imported_code_prompt( + original_imported_code, stack, model + ) + for index, text in enumerate(params["history"][1:]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + prompt_messages.append(message) + else: + # Assemble the prompt for non-imported code + if params.get("resultImage"): + prompt_messages = assemble_prompt( + params["image"], stack, params["resultImage"] + ) + else: + prompt_messages = assemble_prompt(params["image"], stack) + + if params["generationType"] == "update": + # Transform the history tree into message format + # TODO: Move this to frontend + for index, text in enumerate(params["history"]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) + + image_cache = create_alt_url_mapping(params["history"][-2]) + + if input_mode == "video": + video_data_url = params["image"] + prompt_messages = await assemble_claude_prompt_video(video_data_url) + + return prompt_messages, image_cache + + def assemble_imported_code_prompt( code: str, stack: Stack, model: Llm -) -> List[ChatCompletionMessageParam]: +) -> list[ChatCompletionMessageParam]: system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] user_content = ( @@ -53,11 +114,11 @@ def assemble_prompt( image_data_url: str, stack: Stack, result_image_data_url: Union[str, None] = None, -) -> List[ChatCompletionMessageParam]: +) -> list[ChatCompletionMessageParam]: system_content = SYSTEM_PROMPTS[stack] user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT - user_content: List[ChatCompletionContentPartParam] = [ + user_content: list[ChatCompletionContentPartParam] = [ { "type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}, diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index ced9655..0c6e921 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -14,81 +14,20 @@ from llm import ( stream_claude_response_native, stream_openai_response, ) -from openai.types.chat import ChatCompletionMessageParam from fs_logging.core import write_logs from mock_llm import mock_completion from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args -from image_generation import create_alt_url_mapping, generate_images -from prompts import assemble_imported_code_prompt, assemble_prompt +from image_generation import generate_images +from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack - from utils import pprint_prompt -from video.utils import assemble_claude_prompt_video from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore router = APIRouter() -async def create_prompt( - params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode -) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]: - - image_cache: 
Dict[str, str] = {} - - # If this generation started off with imported code, we need to assemble the prompt differently - if params.get("isImportedFromCode"): - original_imported_code = params["history"][0] - prompt_messages = assemble_imported_code_prompt( - original_imported_code, stack, model - ) - for index, text in enumerate(params["history"][1:]): - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - prompt_messages.append(message) - else: - # Assemble the prompt for non-imported code - if params.get("resultImage"): - prompt_messages = assemble_prompt( - params["image"], stack, params["resultImage"] - ) - else: - prompt_messages = assemble_prompt(params["image"], stack) - - if params["generationType"] == "update": - # Transform the history tree into message format - # TODO: Move this to frontend - for index, text in enumerate(params["history"]): - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - prompt_messages.append(message) - - image_cache = create_alt_url_mapping(params["history"][-2]) - - if input_mode == "video": - video_data_url = params["image"] - prompt_messages = await assemble_claude_prompt_video(video_data_url) - - return prompt_messages, image_cache - - # Auto-upgrade usage of older models def auto_upgrade_model(code_generation_model: Llm) -> Llm: if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}: From c61a2ac7729f583daf2c44937f6806c65a8c0443 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 11:18:35 -0400 Subject: [PATCH 16/47] fix TODO --- backend/routes/generate_code.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 0c6e921..f58a89a 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -199,14 +199,9 @@ async def stream_code(websocket: WebSocket): params, valid_stack, code_generation_model, validated_input_mode ) except: - # TODO(*): This should use variantIndex - await websocket.send_json( - { - "type": "error", - "value": "Error assembling prompt. Contact support at support@picoapps.xyz", - } + await throw_error( + "Error assembling prompt. 
Contact support at support@picoapps.xyz" ) - await websocket.close() raise pprint_prompt(prompt_messages) # type: ignore From 637c1b4c1dddee5a3b1c76546d1e5a481ab4d6a8 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 11:22:25 -0400 Subject: [PATCH 17/47] fix TODO --- backend/routes/generate_code.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index f58a89a..c487c67 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -187,9 +187,7 @@ async def stream_code(websocket: WebSocket): await send_message("status", "Generating code...", 0) await send_message("status", "Generating code...", 1) - # TODO(*): Move down - async def process_chunk(content: str, variantIndex: int): - await send_message("chunk", content, variantIndex) + ### Prompt creation # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} @@ -206,6 +204,11 @@ async def stream_code(websocket: WebSocket): pprint_prompt(prompt_messages) # type: ignore + ### Code generation + + async def process_chunk(content: str, variantIndex: int): + await send_message("chunk", content, variantIndex) + if SHOULD_MOCK_AI_RESPONSE: completions = [ await mock_completion(process_chunk, input_mode=validated_input_mode) From 7b2e2963ad9b440353c4ec922bc8b3e000e9af4e Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 11:25:49 -0400 Subject: [PATCH 18/47] print for debugging --- backend/routes/generate_code.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index c487c67..e96e720 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -97,6 +97,12 @@ async def stream_code(websocket: WebSocket): value: str, variantIndex: int, ): + # Print for debugging on the backend + if type == "error": + print(f"Error (variant {variantIndex}): {value}") + elif type == "status": + print(f"Status (variant {variantIndex}): {value}") + await websocket.send_json( {"type": type, "value": value, "variantIndex": variantIndex} ) @@ -181,8 +187,6 @@ async def stream_code(websocket: WebSocket): # Get the image generation flag from the request. Fall back to True if not provided. should_generate_images = bool(params.get("isImageGenerationEnabled", True)) - print("generating code...") - # TODO(*): Print with send_message instead of print statements await send_message("status", "Generating code...", 0) await send_message("status", "Generating code...", 1) From 701d97ec74e5be47616fa5463a492f8502e65403 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 11:27:26 -0400 Subject: [PATCH 19/47] add comments --- backend/routes/generate_code.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index e96e720..bd4238b 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -85,6 +85,8 @@ async def stream_code(websocket: WebSocket): print("Incoming websocket connection...") + + ## Communication protocol setup async def throw_error( message: str, ): @@ -107,6 +109,8 @@ async def stream_code(websocket: WebSocket): {"type": type, "value": value, "variantIndex": variantIndex} ) + ## Parameter validation + # TODO: Are the values always strings? 
params: Dict[str, str] = await websocket.receive_json() print("Received params") From 5c3f915bce2be697ad027da93aebea1226e14762 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 13:36:22 -0400 Subject: [PATCH 20/47] parallelize just image generation --- backend/routes/generate_code.py | 87 ++++++++++++++------------------- 1 file changed, 38 insertions(+), 49 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index bd4238b..c33dc07 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -16,7 +16,7 @@ from llm import ( ) from fs_logging.core import write_logs from mock_llm import mock_completion -from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args +from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args from image_generation import generate_images from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT @@ -43,40 +43,27 @@ def auto_upgrade_model(code_generation_model: Llm) -> Llm: return code_generation_model -# Generate images and return updated completions -async def process_completion( +# Generate images, if needed +async def perform_image_generation( completion: str, - index: int, should_generate_images: bool, openai_api_key: str | None, openai_base_url: str | None, image_cache: dict[str, str], - send_message: Callable[ - [Literal["chunk", "status", "setCode", "error"], str, int], - Coroutine[Any, Any, None], - ], ): - try: - if should_generate_images and openai_api_key: - await send_message("status", "Generating images...", index) - updated_html = await generate_images( - completion, - api_key=openai_api_key, - base_url=openai_base_url, - image_cache=image_cache, - ) - else: - updated_html = completion + if not should_generate_images: + return completion - await send_message("setCode", updated_html, index) - await send_message("status", "Code generation complete.", index) - except Exception as e: - traceback.print_exc() - print(f"Image generation failed for variant {index}", e) - await send_message("setCode", completion, index) - await send_message( - "status", "Image generation failed but code is complete.", index - ) + if not openai_api_key: + print("No OpenAI API key found. Skipping image generation.") + return completion + + return await generate_images( + completion, + api_key=openai_api_key, + base_url=openai_base_url, + image_cache=image_cache, + ) @router.websocket("/generate-code") @@ -85,7 +72,6 @@ async def stream_code(websocket: WebSocket): print("Incoming websocket connection...") - ## Communication protocol setup async def throw_error( message: str, @@ -110,7 +96,7 @@ async def stream_code(websocket: WebSocket): ) ## Parameter validation - + # TODO: Are the values always strings? 
params: Dict[str, str] = await websocket.receive_json() print("Received params") @@ -328,31 +314,34 @@ async def stream_code(websocket: WebSocket): # if validated_input_mode == "video": # completion = extract_tag_content("html", completions[0]) + ## Post-processing + # Strip the completion of everything except the HTML content completions = [extract_html_content(completion) for completion in completions] # Write the messages dict into a log so that we can debug later write_logs(prompt_messages, completions[0]) - try: - image_generation_tasks = [ - process_completion( - completion, - index, - should_generate_images, - openai_api_key, - openai_base_url, - image_cache, - send_message, - ) - for index, completion in enumerate(completions) - ] - await asyncio.gather(*image_generation_tasks) - except Exception as e: - traceback.print_exc() - print("An error occurred during image generation and processing", e) - await send_message( - "status", "An error occurred during image generation and processing.", 0 + ## Image Generation + + for index, _ in enumerate(completions): + await send_message("status", "Generating images...", index) + + image_generation_tasks = [ + perform_image_generation( + completion, + should_generate_images, + openai_api_key, + openai_base_url, + image_cache, ) + for completion in completions + ] + + updated_completions = await asyncio.gather(*image_generation_tasks) + + for index, updated_html in enumerate(updated_completions): + await send_message("setCode", updated_html, index) + await send_message("status", "Code generation complete.", index) await websocket.close() From 823bd2e2498336078db7fcc18207f16cb956230b Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 14:44:33 -0400 Subject: [PATCH 21/47] hide execution console --- frontend/src/components/sidebar/Sidebar.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/components/sidebar/Sidebar.tsx b/frontend/src/components/sidebar/Sidebar.tsx index 288b3f8..84db6b6 100644 --- a/frontend/src/components/sidebar/Sidebar.tsx +++ b/frontend/src/components/sidebar/Sidebar.tsx @@ -167,7 +167,7 @@ function Sidebar({
)} -
+

Console

{Object.entries(executionConsoles).map(([index, lines]) => (
From c76c7c202a9c82b985920c3c757c846e93752904 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 15:46:53 -0400 Subject: [PATCH 22/47] move parameter extraction to separate fn --- backend/config.py | 1 + backend/routes/generate_code.py | 179 +++++++++++++++++++------------- 2 files changed, 108 insertions(+), 72 deletions(-) diff --git a/backend/config.py b/backend/config.py index 19bdca4..30da119 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,6 +5,7 @@ import os OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None) # Image generation (optional) REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 48709d5..8b8d45a 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -1,5 +1,5 @@ -import os import asyncio +from dataclasses import dataclass from fastapi import APIRouter, WebSocket import openai from codegen.utils import extract_html_content @@ -7,6 +7,7 @@ from config import ( ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, + OPENAI_BASE_URL, REPLICATE_API_KEY, SHOULD_MOCK_AI_RESPONSE, ) @@ -21,12 +22,13 @@ from llm import ( ) from fs_logging.core import write_logs from mock_llm import mock_completion -from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args +from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args from image_generation.core import generate_images from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack -from utils import pprint_prompt + +# from utils import pprint_prompt from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -83,10 +85,95 @@ async def perform_image_generation( ) +@dataclass +class ExtractedParams: + stack: Stack + input_mode: InputMode + code_generation_model: Llm + should_generate_images: bool + openai_api_key: str | None + anthropic_api_key: str | None + openai_base_url: str | None + + +async def extract_params( + params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]] +) -> ExtractedParams: + # Read the code config settings (stack) from the request. + generated_code_config = params.get("generatedCodeConfig", "") + if generated_code_config not in get_args(Stack): + await throw_error(f"Invalid generated code config: {generated_code_config}") + raise ValueError(f"Invalid generated code config: {generated_code_config}") + validated_stack = cast(Stack, generated_code_config) + + # Validate the input mode + input_mode = params.get("inputMode") + if input_mode not in get_args(InputMode): + await throw_error(f"Invalid input mode: {input_mode}") + raise ValueError(f"Invalid input mode: {input_mode}") + validated_input_mode = cast(InputMode, input_mode) + + # Read the model from the request. Fall back to default if not provided. + code_generation_model_str = params.get( + "codeGenerationModel", Llm.GPT_4O_2024_05_13.value + ) + try: + code_generation_model = convert_frontend_str_to_llm(code_generation_model_str) + except ValueError: + await throw_error(f"Invalid model: {code_generation_model_str}") + raise ValueError(f"Invalid model: {code_generation_model_str}") + + openai_api_key = get_from_settings_dialog_or_env( + params, "openAiApiKey", OPENAI_API_KEY + ) + + # If neither is provided, we throw an error later only if Claude is used. 
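+    # Like the OpenAI key above, this resolves from the client-side settings
+    # dialog first, then falls back to the environment variable.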
+
+    anthropic_api_key = get_from_settings_dialog_or_env(
+        params, "anthropicApiKey", ANTHROPIC_API_KEY
+    )
+
+    # Base URL for OpenAI API
+    openai_base_url: str | None = None
+    # Disable user-specified OpenAI Base URL in prod
+    if not IS_PROD:
+        openai_base_url = get_from_settings_dialog_or_env(
+            params, "openAiBaseURL", OPENAI_BASE_URL
+        )
+    if not openai_base_url:
+        print("Using official OpenAI URL")
+
+    # Get the image generation flag from the request. Fall back to True if not provided.
+    should_generate_images = bool(params.get("isImageGenerationEnabled", True))
+
+    return ExtractedParams(
+        stack=validated_stack,
+        input_mode=validated_input_mode,
+        code_generation_model=code_generation_model,
+        should_generate_images=should_generate_images,
+        openai_api_key=openai_api_key,
+        anthropic_api_key=anthropic_api_key,
+        openai_base_url=openai_base_url,
+    )
+
+
+def get_from_settings_dialog_or_env(
+    params: dict[str, str], key: str, env_var: str | None
+) -> str | None:
+    value = params.get(key)
+    if value:
+        print(f"Using {key} from client-side settings dialog")
+        return value
+
+    if env_var:
+        print(f"Using {key} from environment variable")
+        return env_var
+
+    return None
+
+
 @router.websocket("/generate-code")
 async def stream_code(websocket: WebSocket):
     await websocket.accept()
-    print("Incoming websocket connection...")
 
     ## Communication protocol setup
@@ -112,89 +199,37 @@ async def stream_code(websocket: WebSocket):
         {"type": type, "value": value, "variantIndex": variantIndex}
     )
 
-    ## Parameter validation
+    ## Parameter extraction and validation
 
     # TODO: Are the values always strings?
-    params: Dict[str, str] = await websocket.receive_json()
+    params: dict[str, str] = await websocket.receive_json()
     print("Received params")
 
-    # Read the code config settings (stack) from the request.
-    generated_code_config = params.get("generatedCodeConfig", "")
-    if not generated_code_config in get_args(Stack):
-        await throw_error(f"Invalid generated code config: {generated_code_config}")
-        raise Exception(f"Invalid generated code config: {generated_code_config}")
-    # Cast the variable to the Stack type
-    valid_stack = cast(Stack, generated_code_config)
-
-    # Validate the input mode
-    input_mode = params.get("inputMode")
-    if not input_mode in get_args(InputMode):
-        await throw_error(f"Invalid input mode: {input_mode}")
-        raise Exception(f"Invalid input mode: {input_mode}")
-    # Cast the variable to the right type
-    validated_input_mode = cast(InputMode, input_mode)
-
-    # Read the model from the request. Fall back to default if not provided. 
- code_generation_model_str = params.get( - "codeGenerationModel", Llm.GPT_4O_2024_05_13.value - ) - try: - code_generation_model = convert_frontend_str_to_llm(code_generation_model_str) - except: - await throw_error(f"Invalid model: {code_generation_model_str}") - raise Exception(f"Invalid model: {code_generation_model_str}") + extracted_params = await extract_params(params, throw_error) + # TODO(*): Rename to stack and input_mode + valid_stack = extracted_params.stack + validated_input_mode = extracted_params.input_mode + code_generation_model = extracted_params.code_generation_model + openai_api_key = extracted_params.openai_api_key + openai_base_url = extracted_params.openai_base_url + anthropic_api_key = extracted_params.anthropic_api_key + should_generate_images = extracted_params.should_generate_images # Auto-upgrade usage of older models code_generation_model = auto_upgrade_model(code_generation_model) + print( - f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." + f"Generating {valid_stack} code in {validated_input_mode} mode using {code_generation_model}..." ) - # Get the OpenAI API key from the request. Fall back to environment variable if not provided. - # If neither is provided, we throw an error. - openai_api_key = params.get("openAiApiKey") - if openai_api_key: - print("Using OpenAI API key from client-side settings dialog") - else: - openai_api_key = OPENAI_API_KEY - if openai_api_key: - print("Using OpenAI API key from environment variable") - + # TODO(*): Do I still need this? if not openai_api_key and is_openai_model(code_generation_model): await throw_error( "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." ) return - # Get the Anthropic API key from the request. Fall back to environment variable if not provided. - # If neither is provided, we throw an error later only if Claude is used. - anthropic_api_key = params.get("anthropicApiKey") - if anthropic_api_key: - print("Using Anthropic API key from client-side settings dialog") - else: - anthropic_api_key = ANTHROPIC_API_KEY - if anthropic_api_key: - print("Using Anthropic API key from environment variable") - - # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. - openai_base_url: Union[str, None] = None - # Disable user-specified OpenAI Base URL in prod - if not os.environ.get("IS_PROD"): - openai_base_url = params.get("openAiBaseURL") - if openai_base_url: - print("Using OpenAI Base URL from client-side settings dialog") - else: - openai_base_url = os.environ.get("OPENAI_BASE_URL") - if openai_base_url: - print("Using OpenAI Base URL from environment variable") - - if not openai_base_url: - print("Using official OpenAI URL") - - # Get the image generation flag from the request. Fall back to True if not provided. 
- should_generate_images = bool(params.get("isImageGenerationEnabled", True)) - - # TODO(*): Print with send_message instead of print statements + # TODO(*): Don't assume number of variants await send_message("status", "Generating code...", 0) await send_message("status", "Generating code...", 1) @@ -213,7 +248,7 @@ async def stream_code(websocket: WebSocket): ) raise - pprint_prompt(prompt_messages) # type: ignore + # pprint_prompt(prompt_messages) # type: ignore ### Code generation From fb5480b036d8835bd16d3609e1c5ea9b9528e17a Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 16:05:16 -0400 Subject: [PATCH 23/47] fix type error --- backend/image_generation/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/image_generation/core.py b/backend/image_generation/core.py index dfd3375..91a5535 100644 --- a/backend/image_generation/core.py +++ b/backend/image_generation/core.py @@ -27,7 +27,7 @@ async def process_tasks( processed_results: List[Union[str, None]] = [] for result in results: - if isinstance(result, Exception): + if isinstance(result, BaseException): print(f"An exception occurred: {result}") processed_results.append(None) else: From b158597d7e969de715a6bcc288530553487e666a Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Mon, 5 Aug 2024 14:11:01 -0400 Subject: [PATCH 24/47] fix bug with prompt assembly for imported code with Claude which disallows multiple user messages in a row --- backend/prompts/__init__.py | 36 ++++---------- backend/prompts/test_prompts.py | 84 ++++++++++++++++++++------------- backend/routes/generate_code.py | 7 ++- 3 files changed, 64 insertions(+), 63 deletions(-) diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index d7103f9..955502b 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,10 +1,8 @@ from typing import Union - from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam + from custom_types import InputMode from image_generation.core import create_alt_url_mapping -from llm import Llm - from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS from prompts.types import Stack @@ -21,7 +19,7 @@ Generate code for a SVG that looks exactly like this. 
async def create_prompt( - params: dict[str, str], stack: Stack, model: Llm, input_mode: InputMode + params: dict[str, str], stack: Stack, input_mode: InputMode ) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]: image_cache: dict[str, str] = {} @@ -29,9 +27,7 @@ async def create_prompt( # If this generation started off with imported code, we need to assemble the prompt differently if params.get("isImportedFromCode"): original_imported_code = params["history"][0] - prompt_messages = assemble_imported_code_prompt( - original_imported_code, stack, model - ) + prompt_messages = assemble_imported_code_prompt(original_imported_code, stack) for index, text in enumerate(params["history"][1:]): if index % 2 == 0: message: ChatCompletionMessageParam = { @@ -79,7 +75,7 @@ async def create_prompt( def assemble_imported_code_prompt( - code: str, stack: Stack, model: Llm + code: str, stack: Stack ) -> list[ChatCompletionMessageParam]: system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] @@ -89,24 +85,12 @@ def assemble_imported_code_prompt( else "Here is the code of the SVG: " + code ) - if model == Llm.CLAUDE_3_5_SONNET_2024_06_20: - return [ - { - "role": "system", - "content": system_content + "\n " + user_content, - } - ] - else: - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] + return [ + { + "role": "system", + "content": system_content + "\n " + user_content, + } + ] # TODO: Use result_image_data_url diff --git a/backend/prompts/test_prompts.py b/backend/prompts/test_prompts.py index 9175fd8..049f9db 100644 --- a/backend/prompts/test_prompts.py +++ b/backend/prompts/test_prompts.py @@ -391,63 +391,81 @@ def test_prompts(): def test_imported_code_prompts(): - tailwind_prompt = assemble_imported_code_prompt( - "code", "html_tailwind", Llm.GPT_4O_2024_05_13 - ) + code = "Sample code" + + tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind") expected_tailwind_prompt = [ - {"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert tailwind_prompt == expected_tailwind_prompt - html_css_prompt = assemble_imported_code_prompt( - "code", "html_css", Llm.GPT_4O_2024_05_13 - ) + html_css_prompt = assemble_imported_code_prompt(code, "html_css") expected_html_css_prompt = [ - {"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert html_css_prompt == expected_html_css_prompt - react_tailwind_prompt = assemble_imported_code_prompt( - "code", "react_tailwind", Llm.GPT_4O_2024_05_13 - ) + react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind") expected_react_tailwind_prompt = [ - {"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert react_tailwind_prompt == expected_react_tailwind_prompt - bootstrap_prompt = assemble_imported_code_prompt( - "code", "bootstrap", Llm.GPT_4O_2024_05_13 - ) + bootstrap_prompt = 
assemble_imported_code_prompt(code, "bootstrap") expected_bootstrap_prompt = [ - {"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert bootstrap_prompt == expected_bootstrap_prompt - ionic_tailwind = assemble_imported_code_prompt( - "code", "ionic_tailwind", Llm.GPT_4O_2024_05_13 - ) + ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind") expected_ionic_tailwind = [ - {"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert ionic_tailwind == expected_ionic_tailwind - vue_tailwind = assemble_imported_code_prompt( - "code", "vue_tailwind", Llm.GPT_4O_2024_05_13 - ) + vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind") expected_vue_tailwind = [ - {"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert vue_tailwind == expected_vue_tailwind - svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13) + svg = assemble_imported_code_prompt(code, "svg") expected_svg = [ - {"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the SVG: code"}, + { + "role": "system", + "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT + + "\n Here is the code of the SVG: " + + code, + } ] assert svg == expected_svg diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 8b8d45a..d7f46c5 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -27,8 +27,7 @@ from image_generation.core import generate_images from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack - -# from utils import pprint_prompt +from utils import pprint_prompt from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -240,7 +239,7 @@ async def stream_code(websocket: WebSocket): try: prompt_messages, image_cache = await create_prompt( - params, valid_stack, code_generation_model, validated_input_mode + params, valid_stack, validated_input_mode ) except: await throw_error( @@ -248,7 +247,7 @@ async def stream_code(websocket: WebSocket): ) raise - # pprint_prompt(prompt_messages) # type: ignore + pprint_prompt(prompt_messages) # type: ignore ### Code generation From 5f6dd08411941d69f8ea156277196a861feb44ff Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Mon, 5 Aug 2024 16:17:48 -0400 Subject: [PATCH 25/47] reset inputMode when resetting state --- frontend/src/App.tsx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 789b905..048426c 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -110,19 +110,20 @@ function App() { const reset = () => { setAppState(AppState.INITIAL); + setShouldIncludeResultImage(false); + setUpdateInstruction(""); + disableInSelectAndEditMode(); setGeneratedCode(""); resetVariants(); resetExecutionConsoles(); // Inputs + setInputMode("image"); setReferenceImages([]); - 
setUpdateInstruction(""); setIsImportedFromCode(false); setAppHistory([]); setCurrentVersion(null); - setShouldIncludeResultImage(false); - disableInSelectAndEditMode(); }; const regenerate = () => { From 8e8f0b4b644b36d1843a4cbfd412569702f60143 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Thu, 22 Aug 2024 13:26:42 -0400 Subject: [PATCH 26/47] intermediate changes towards multiple generations --- frontend/package.json | 1 + frontend/src/App.tsx | 244 +++++----- .../src/components/history/HistoryDisplay.tsx | 69 ++- .../src/components/history/history_types.ts | 55 ++- frontend/src/components/history/utils.test.ts | 428 +++++++++--------- frontend/src/components/history/utils.ts | 93 +--- .../src/components/preview/PreviewPane.tsx | 16 +- frontend/src/components/sidebar/Sidebar.tsx | 19 +- frontend/src/components/variants/Variants.tsx | 45 +- frontend/src/store/project-store.ts | 130 +++--- frontend/yarn.lock | 5 + 11 files changed, 538 insertions(+), 567 deletions(-) diff --git a/frontend/package.json b/frontend/package.json index 4652dc7..0d454b0 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -36,6 +36,7 @@ "codemirror": "^6.0.1", "copy-to-clipboard": "^3.3.3", "html2canvas": "^1.4.1", + "nanoid": "^5.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 048426c..4ec5c6f 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -8,8 +8,7 @@ import { OnboardingNote } from "./components/messages/OnboardingNote"; import { usePersistedState } from "./hooks/usePersistedState"; import TermsOfServiceDialog from "./components/TermsOfServiceDialog"; import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants"; -import { History } from "./components/history/history_types"; -import { extractHistoryTree } from "./components/history/utils"; +import { extractHistory } from "./components/history/utils"; import toast from "react-hot-toast"; import { Stack } from "./lib/stacks"; import { CodeGenerationModel } from "./lib/models"; @@ -23,6 +22,7 @@ import DeprecationMessage from "./components/messages/DeprecationMessage"; import { GenerationSettings } from "./components/settings/GenerationSettings"; import StartPane from "./components/start-pane/StartPane"; import { takeScreenshot } from "./lib/takeScreenshot"; +import { Commit, createCommit } from "./components/history/history_types"; function App() { const { @@ -34,18 +34,18 @@ function App() { referenceImages, setReferenceImages, + head, + commits, + addCommit, + removeCommit, + setHead, + appendCommitCode, + setCommitCode, + resetCommits, + // Outputs - setGeneratedCode, - currentVariantIndex, - setVariant, - appendToVariant, - resetVariants, appendExecutionConsole, resetExecutionConsoles, - currentVersion, - setCurrentVersion, - appHistory, - setAppHistory, } = useProjectStore(); const { @@ -113,34 +113,30 @@ function App() { setShouldIncludeResultImage(false); setUpdateInstruction(""); disableInSelectAndEditMode(); - setGeneratedCode(""); - resetVariants(); resetExecutionConsoles(); + resetCommits(); + // Inputs setInputMode("image"); setReferenceImages([]); setIsImportedFromCode(false); - - setAppHistory([]); - setCurrentVersion(null); }; const regenerate = () => { - if (currentVersion === null) { + // TODO: post to Sentry + if (head === null) { toast.error( "No current version set. Please open a Github issue as this shouldn't happen." 
); return; } - // Retrieve the previous command - const previousCommand = appHistory[currentVersion]; - if (previousCommand.type !== "ai_create") { + const currentCommit = commits[head]; + if (currentCommit.type !== "ai_create") { toast.error("Only the first version can be regenerated."); return; } - // Re-run the create doCreate(referenceImages, inputMode); }; @@ -149,25 +145,32 @@ function App() { const cancelCodeGeneration = () => { wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE); // make sure stop can correct the state even if the websocket is already closed - cancelCodeGenerationAndReset(); + // TODO: Look into this + // cancelCodeGenerationAndReset(); }; // Used for code generation failure as well - const cancelCodeGenerationAndReset = () => { - // When this is the first version, reset the entire app state - if (currentVersion === null) { + const cancelCodeGenerationAndReset = (commit: Commit) => { + // When the current commit is the first version, reset the entire app state + if (commit.type === "ai_create") { reset(); } else { - // Otherwise, revert to the last version - setGeneratedCode(appHistory[currentVersion].code); + // Otherwise, remove current commit from commits + removeCommit(commit.hash); + + // Revert to parent commit + const parentCommitHash = commit.parentHash; + if (parentCommitHash) { + setHead(parentCommitHash); + } else { + // TODO: Hit Sentry + } + setAppState(AppState.CODE_READY); } }; - function doGenerateCode( - params: CodeGenerationParams, - parentVersion: number | null - ) { + function doGenerateCode(params: CodeGenerationParams) { // Reset the execution console resetExecutionConsoles(); @@ -177,69 +180,51 @@ function App() { // Merge settings with params const updatedParams = { ...params, ...settings }; + const baseCommitObject = { + date_created: new Date(), + variants: [{ code: "" }, { code: "" }], + selectedVariantIndex: 0, + }; + + const commitInputObject = + params.generationType === "create" + ? { + ...baseCommitObject, + type: "ai_create" as const, + parentHash: null, + inputs: { image_url: referenceImages[0] }, + } + : { + ...baseCommitObject, + type: "ai_edit" as const, + parentHash: head, + inputs: { + prompt: params.history + ? params.history[params.history.length - 1] + : "", + }, + }; + + const commit = createCommit(commitInputObject); + addCommit(commit); + setHead(commit.hash); + generateCode( wsRef, updatedParams, // On change (token, variant) => { - if (variant === currentVariantIndex) { - setGeneratedCode((prev) => prev + token); - } - - appendToVariant(token, variant); + appendCommitCode(commit.hash, variant, token); }, // On set code (code, variant) => { - if (variant === currentVariantIndex) { - setGeneratedCode(code); - } - - setVariant(code, variant); - - // TODO: How to deal with variants? - if (params.generationType === "create") { - setAppHistory([ - { - type: "ai_create", - parentIndex: null, - code, - inputs: { image_url: referenceImages[0] }, - }, - ]); - setCurrentVersion(0); - } else { - setAppHistory((prev) => { - // Validate parent version - if (parentVersion === null) { - toast.error( - "No parent version set. Contact support or open a Github issue." - ); - return prev; - } - - const newHistory: History = [ - ...prev, - { - type: "ai_edit", - parentIndex: parentVersion, - code, - inputs: { - prompt: params.history - ? 
params.history[params.history.length - 1] - : "", // History should never be empty when performing an edit - }, - }, - ]; - setCurrentVersion(newHistory.length - 1); - return newHistory; - }); - } + setCommitCode(commit.hash, variant, code); }, // On status update (line, variant) => appendExecutionConsole(variant, line), // On cancel () => { - cancelCodeGenerationAndReset(); + cancelCodeGenerationAndReset(commit); }, // On complete () => { @@ -259,14 +244,11 @@ function App() { // Kick off the code generation if (referenceImages.length > 0) { - doGenerateCode( - { - generationType: "create", - image: referenceImages[0], - inputMode, - }, - currentVersion - ); + doGenerateCode({ + generationType: "create", + image: referenceImages[0], + inputMode, + }); } } @@ -280,16 +262,17 @@ function App() { return; } - if (currentVersion === null) { - toast.error( - "No current version set. Contact support or open a Github issue." - ); - return; - } + // if (currentVersion === null) { + // toast.error( + // "No current version set. Contact support or open a Github issue." + // ); + // return; + // } let historyTree; try { - historyTree = extractHistoryTree(appHistory, currentVersion); + // TODO: Fix head being null + historyTree = extractHistory(head || "", commits); } catch { toast.error( "Version history is invalid. This shouldn't happen. Please contact support or open a Github issue." @@ -309,34 +292,28 @@ function App() { const updatedHistory = [...historyTree, modifiedUpdateInstruction]; + console.log(updatedHistory); + if (shouldIncludeResultImage) { const resultImage = await takeScreenshot(); - doGenerateCode( - { - generationType: "update", - inputMode, - image: referenceImages[0], - resultImage: resultImage, - history: updatedHistory, - isImportedFromCode, - }, - currentVersion - ); + doGenerateCode({ + generationType: "update", + inputMode, + image: referenceImages[0], + resultImage: resultImage, + history: updatedHistory, + isImportedFromCode, + }); } else { - doGenerateCode( - { - generationType: "update", - inputMode, - image: referenceImages[0], - history: updatedHistory, - isImportedFromCode, - }, - currentVersion - ); + doGenerateCode({ + generationType: "update", + inputMode, + image: referenceImages[0], + history: updatedHistory, + isImportedFromCode, + }); } - setGeneratedCode(""); - resetVariants(); setUpdateInstruction(""); } @@ -358,18 +335,27 @@ function App() { // Set input state setIsImportedFromCode(true); + console.log(code); + // Set up this project - setGeneratedCode(code); + // TODO* + // setGeneratedCode(code); setStack(stack); - setAppHistory([ - { - type: "code_create", - parentIndex: null, - code, - inputs: { code }, - }, - ]); - setCurrentVersion(0); + // setAppHistory([ + // { + // type: "code_create", + // parentIndex: null, + // code, + // inputs: { code }, + // }, + // ]); + // setVariant(0, { + // type: "code_create", + // parentIndex: null, + // code, + // }); + // setCurrentVariantIndex(0); + // setCurrentVersion(0); // Set the app state setAppState(AppState.CODE_READY); diff --git a/frontend/src/components/history/HistoryDisplay.tsx b/frontend/src/components/history/HistoryDisplay.tsx index a3dcdb1..5b81f22 100644 --- a/frontend/src/components/history/HistoryDisplay.tsx +++ b/frontend/src/components/history/HistoryDisplay.tsx @@ -2,7 +2,7 @@ import toast from "react-hot-toast"; import classNames from "classnames"; import { Badge } from "../ui/badge"; -import { renderHistory } from "./utils"; +import { summarizeHistoryItem } from "./utils"; import { 
Collapsible, CollapsibleContent, @@ -17,25 +17,58 @@ interface Props { } export default function HistoryDisplay({ shouldDisableReverts }: Props) { - const { - appHistory: history, - currentVersion, - setCurrentVersion, - setGeneratedCode, - } = useProjectStore(); - const renderedHistory = renderHistory(history, currentVersion); + const { commits, head, setHead } = useProjectStore(); - const revertToVersion = (index: number) => { - if (index < 0 || index >= history.length || !history[index]) return; - setCurrentVersion(index); - setGeneratedCode(history[index].code); + // TODO: Clean this up + + const newHistory = Object.values(commits).flatMap((commit) => { + if (commit.type === "ai_create" || commit.type === "ai_edit") { + return { + type: commit.type, + hash: commit.hash, + summary: summarizeHistoryItem(commit), + parentHash: commit.parentHash, + code: commit.variants[commit.selectedVariantIndex].code, + inputs: commit.inputs, + date_created: commit.date_created, + }; + } + return []; + }); + + // Sort by date created + newHistory.sort( + (a, b) => + new Date(a.date_created).getTime() - new Date(b.date_created).getTime() + ); + + const setParentVersion = ( + parentHash: string | null, + currentHash: string | null + ) => { + if (!parentHash) return null; + const parentIndex = newHistory.findIndex( + (item) => item.hash === parentHash + ); + const currentIndex = newHistory.findIndex( + (item) => item.hash === currentHash + ); + return parentIndex !== -1 && parentIndex != currentIndex - 1 + ? parentIndex + 1 + : null; }; - return renderedHistory.length === 0 ? null : ( + // Update newHistory to include the parent version + const updatedHistory = newHistory.map((item) => ({ + ...item, + parentVersion: setParentVersion(item.parentHash, item.hash), + })); + + return updatedHistory.length === 0 ? null : (

Versions

    - {renderedHistory.map((item, index) => ( + {updatedHistory.map((item, index) => (
  • @@ -55,14 +88,14 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) { ? toast.error( "Please wait for code generation to complete before viewing an older version." ) - : revertToVersion(index) + : setHead(item.hash) } >

    {item.summary}

    {item.parentVersion !== null && (

    - (parent: {item.parentVersion}) + (parent: v{item.parentVersion})

    )}
    diff --git a/frontend/src/components/history/history_types.ts b/frontend/src/components/history/history_types.ts index 8dcd219..183b1b6 100644 --- a/frontend/src/components/history/history_types.ts +++ b/frontend/src/components/history/history_types.ts @@ -1,37 +1,44 @@ -export type HistoryItemType = "ai_create" | "ai_edit" | "code_create"; +export type CommitType = "ai_create" | "ai_edit" | "code_create"; -type CommonHistoryItem = { - parentIndex: null | number; +export type CommitHash = string; + +export type Variant = { code: string; }; -export type HistoryItem = - | ({ - type: "ai_create"; - inputs: AiCreateInputs; - } & CommonHistoryItem) - | ({ - type: "ai_edit"; - inputs: AiEditInputs; - } & CommonHistoryItem) - | ({ - type: "code_create"; - inputs: CodeCreateInputs; - } & CommonHistoryItem); - -export type AiCreateInputs = { - image_url: string; +export type BaseCommit = { + hash: CommitHash; + parentHash: CommitHash | null; + date_created: Date; + variants: Variant[]; + selectedVariantIndex: number; }; -export type AiEditInputs = { - prompt: string; +import { nanoid } from "nanoid"; + +// TODO: Move to a different file +export function createCommit( + commit: Omit | Omit +): Commit { + const hash = nanoid(); + return { ...commit, hash }; +} + +export type AiCreateCommit = BaseCommit & { + type: "ai_create"; + inputs: { + image_url: string; + }; }; -export type CodeCreateInputs = { - code: string; +export type AiEditCommit = BaseCommit & { + type: "ai_edit"; + inputs: { + prompt: string; + }; }; -export type History = HistoryItem[]; +export type Commit = AiCreateCommit | AiEditCommit; export type RenderedHistoryItem = { type: string; diff --git a/frontend/src/components/history/utils.test.ts b/frontend/src/components/history/utils.test.ts index e321bdc..f70d2bc 100644 --- a/frontend/src/components/history/utils.test.ts +++ b/frontend/src/components/history/utils.test.ts @@ -1,231 +1,231 @@ -import { extractHistoryTree, renderHistory } from "./utils"; -import type { History } from "./history_types"; +// import { extractHistoryTree, renderHistory } from "./utils"; +// import type { History } from "./history_types"; -const basicLinearHistory: History = [ - { - type: "ai_create", - parentIndex: null, - code: "1. create", - inputs: { - image_url: "", - }, - }, - { - type: "ai_edit", - parentIndex: 0, - code: "2. edit with better icons", - inputs: { - prompt: "use better icons", - }, - }, - { - type: "ai_edit", - parentIndex: 1, - code: "3. edit with better icons and red text", - inputs: { - prompt: "make text red", - }, - }, -]; +// const basicLinearHistory: History = [ +// { +// type: "ai_create", +// parentIndex: null, +// code: "1. create", +// inputs: { +// image_url: "", +// }, +// }, +// { +// type: "ai_edit", +// parentIndex: 0, +// code: "2. edit with better icons", +// inputs: { +// prompt: "use better icons", +// }, +// }, +// { +// type: "ai_edit", +// parentIndex: 1, +// code: "3. edit with better icons and red text", +// inputs: { +// prompt: "make text red", +// }, +// }, +// ]; -const basicLinearHistoryWithCode: History = [ - { - type: "code_create", - parentIndex: null, - code: "1. create", - inputs: { - code: "1. create", - }, - }, - ...basicLinearHistory.slice(1), -]; +// const basicLinearHistoryWithCode: History = [ +// { +// type: "code_create", +// parentIndex: null, +// code: "1. create", +// inputs: { +// code: "1. 
create", +// }, +// }, +// ...basicLinearHistory.slice(1), +// ]; -const basicBranchingHistory: History = [ - ...basicLinearHistory, - { - type: "ai_edit", - parentIndex: 1, - code: "4. edit with better icons and green text", - inputs: { - prompt: "make text green", - }, - }, -]; +// const basicBranchingHistory: History = [ +// ...basicLinearHistory, +// { +// type: "ai_edit", +// parentIndex: 1, +// code: "4. edit with better icons and green text", +// inputs: { +// prompt: "make text green", +// }, +// }, +// ]; -const longerBranchingHistory: History = [ - ...basicBranchingHistory, - { - type: "ai_edit", - parentIndex: 3, - code: "5. edit with better icons and green, bold text", - inputs: { - prompt: "make text bold", - }, - }, -]; +// const longerBranchingHistory: History = [ +// ...basicBranchingHistory, +// { +// type: "ai_edit", +// parentIndex: 3, +// code: "5. edit with better icons and green, bold text", +// inputs: { +// prompt: "make text bold", +// }, +// }, +// ]; -const basicBadHistory: History = [ - { - type: "ai_create", - parentIndex: null, - code: "1. create", - inputs: { - image_url: "", - }, - }, - { - type: "ai_edit", - parentIndex: 2, // <- Bad parent index - code: "2. edit with better icons", - inputs: { - prompt: "use better icons", - }, - }, -]; +// const basicBadHistory: History = [ +// { +// type: "ai_create", +// parentIndex: null, +// code: "1. create", +// inputs: { +// image_url: "", +// }, +// }, +// { +// type: "ai_edit", +// parentIndex: 2, // <- Bad parent index +// code: "2. edit with better icons", +// inputs: { +// prompt: "use better icons", +// }, +// }, +// ]; -describe("History Utils", () => { - test("should correctly extract the history tree", () => { - expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([ - "1. create", - "use better icons", - "2. edit with better icons", - "make text red", - "3. edit with better icons and red text", - ]); +// describe("History Utils", () => { +// test("should correctly extract the history tree", () => { +// expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([ +// "1. create", +// "use better icons", +// "2. edit with better icons", +// "make text red", +// "3. edit with better icons and red text", +// ]); - expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([ - "1. create", - ]); +// expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([ +// "1. create", +// ]); - // Test branching - expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([ - "1. create", - "use better icons", - "2. edit with better icons", - "make text green", - "4. edit with better icons and green text", - ]); +// // Test branching +// expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([ +// "1. create", +// "use better icons", +// "2. edit with better icons", +// "make text green", +// "4. edit with better icons and green text", +// ]); - expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([ - "1. create", - "use better icons", - "2. edit with better icons", - "make text green", - "4. edit with better icons and green text", - "make text bold", - "5. edit with better icons and green, bold text", - ]); +// expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([ +// "1. create", +// "use better icons", +// "2. edit with better icons", +// "make text green", +// "4. edit with better icons and green text", +// "make text bold", +// "5. edit with better icons and green, bold text", +// ]); - expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([ - "1. 
create", - "use better icons", - "2. edit with better icons", - "make text red", - "3. edit with better icons and red text", - ]); +// expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([ +// "1. create", +// "use better icons", +// "2. edit with better icons", +// "make text red", +// "3. edit with better icons and red text", +// ]); - // Errors +// // Errors - // Bad index - expect(() => extractHistoryTree(basicLinearHistory, 100)).toThrow(); - expect(() => extractHistoryTree(basicLinearHistory, -2)).toThrow(); +// // Bad index +// expect(() => extractHistoryTree(basicLinearHistory, 100)).toThrow(); +// expect(() => extractHistoryTree(basicLinearHistory, -2)).toThrow(); - // Bad tree - expect(() => extractHistoryTree(basicBadHistory, 1)).toThrow(); - }); +// // Bad tree +// expect(() => extractHistoryTree(basicBadHistory, 1)).toThrow(); +// }); - test("should correctly render the history tree", () => { - expect(renderHistory(basicLinearHistory, 2)).toEqual([ - { - isActive: false, - parentVersion: null, - summary: "Create", - type: "Create", - }, - { - isActive: false, - parentVersion: null, - summary: "use better icons", - type: "Edit", - }, - { - isActive: true, - parentVersion: null, - summary: "make text red", - type: "Edit", - }, - ]); +// test("should correctly render the history tree", () => { +// expect(renderHistory(basicLinearHistory, 2)).toEqual([ +// { +// isActive: false, +// parentVersion: null, +// summary: "Create", +// type: "Create", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "use better icons", +// type: "Edit", +// }, +// { +// isActive: true, +// parentVersion: null, +// summary: "make text red", +// type: "Edit", +// }, +// ]); - // Current version is the first version - expect(renderHistory(basicLinearHistory, 0)).toEqual([ - { - isActive: true, - parentVersion: null, - summary: "Create", - type: "Create", - }, - { - isActive: false, - parentVersion: null, - summary: "use better icons", - type: "Edit", - }, - { - isActive: false, - parentVersion: null, - summary: "make text red", - type: "Edit", - }, - ]); +// // Current version is the first version +// expect(renderHistory(basicLinearHistory, 0)).toEqual([ +// { +// isActive: true, +// parentVersion: null, +// summary: "Create", +// type: "Create", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "use better icons", +// type: "Edit", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "make text red", +// type: "Edit", +// }, +// ]); - // Render a history with code - expect(renderHistory(basicLinearHistoryWithCode, 0)).toEqual([ - { - isActive: true, - parentVersion: null, - summary: "Imported from code", - type: "Imported from code", - }, - { - isActive: false, - parentVersion: null, - summary: "use better icons", - type: "Edit", - }, - { - isActive: false, - parentVersion: null, - summary: "make text red", - type: "Edit", - }, - ]); +// // Render a history with code +// expect(renderHistory(basicLinearHistoryWithCode, 0)).toEqual([ +// { +// isActive: true, +// parentVersion: null, +// summary: "Imported from code", +// type: "Imported from code", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "use better icons", +// type: "Edit", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "make text red", +// type: "Edit", +// }, +// ]); - // Render a non-linear history - expect(renderHistory(basicBranchingHistory, 3)).toEqual([ - { - isActive: false, - parentVersion: null, - summary: 
"Create", - type: "Create", - }, - { - isActive: false, - parentVersion: null, - summary: "use better icons", - type: "Edit", - }, - { - isActive: false, - parentVersion: null, - summary: "make text red", - type: "Edit", - }, - { - isActive: true, - parentVersion: "v2", - summary: "make text green", - type: "Edit", - }, - ]); - }); -}); +// // Render a non-linear history +// expect(renderHistory(basicBranchingHistory, 3)).toEqual([ +// { +// isActive: false, +// parentVersion: null, +// summary: "Create", +// type: "Create", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "use better icons", +// type: "Edit", +// }, +// { +// isActive: false, +// parentVersion: null, +// summary: "make text red", +// type: "Edit", +// }, +// { +// isActive: true, +// parentVersion: "v2", +// summary: "make text green", +// type: "Edit", +// }, +// ]); +// }); +// }); diff --git a/frontend/src/components/history/utils.ts b/frontend/src/components/history/utils.ts index 785c20b..e442925 100644 --- a/frontend/src/components/history/utils.ts +++ b/frontend/src/components/history/utils.ts @@ -1,33 +1,29 @@ -import { - History, - HistoryItem, - HistoryItemType, - RenderedHistoryItem, -} from "./history_types"; +import { Commit, CommitHash } from "./history_types"; -export function extractHistoryTree( - history: History, - version: number +export function extractHistory( + hash: CommitHash, + commits: Record ): string[] { const flatHistory: string[] = []; - let currentIndex: number | null = version; - while (currentIndex !== null) { - const item: HistoryItem = history[currentIndex]; + let currentCommitHash: CommitHash | null = hash; + while (currentCommitHash !== null) { + const commit: Commit = commits[currentCommitHash]; - if (item) { - if (item.type === "ai_create") { + if (commit) { + if (commit.type === "ai_create") { // Don't include the image for ai_create - flatHistory.unshift(item.code); - } else if (item.type === "ai_edit") { - flatHistory.unshift(item.code); - flatHistory.unshift(item.inputs.prompt); - } else if (item.type === "code_create") { - flatHistory.unshift(item.code); + flatHistory.unshift(commit.variants[commit.selectedVariantIndex].code); + } else if (commit.type === "ai_edit") { + flatHistory.unshift(commit.variants[commit.selectedVariantIndex].code); + flatHistory.unshift(commit.inputs.prompt); } + // } else if (item.type === "code_create") { + // flatHistory.unshift(item.code); + // } // Move to the parent of the current item - currentIndex = item.parentIndex; + currentCommitHash = commit.parentHash; } else { throw new Error("Malformed history: missing parent index"); } @@ -36,61 +32,16 @@ export function extractHistoryTree( return flatHistory; } -function displayHistoryItemType(itemType: HistoryItemType) { - switch (itemType) { +export function summarizeHistoryItem(commit: Commit) { + const commitType = commit.type; + switch (commitType) { case "ai_create": return "Create"; case "ai_edit": - return "Edit"; - case "code_create": - return "Imported from code"; + return commit.inputs.prompt; default: { - const exhaustiveCheck: never = itemType; + const exhaustiveCheck: never = commitType; throw new Error(`Unhandled case: ${exhaustiveCheck}`); } } } - -function summarizeHistoryItem(item: HistoryItem) { - const itemType = item.type; - switch (itemType) { - case "ai_create": - return "Create"; - case "ai_edit": - return item.inputs.prompt; - case "code_create": - return "Imported from code"; - default: { - const exhaustiveCheck: never = itemType; - throw new 
Error(`Unhandled case: ${exhaustiveCheck}`); - } - } -} - -export const renderHistory = ( - history: History, - currentVersion: number | null -) => { - const renderedHistory: RenderedHistoryItem[] = []; - - for (let i = 0; i < history.length; i++) { - const item = history[i]; - // Only show the parent version if it's not the previous version - // (i.e. it's the branching point) and if it's not the first version - const parentVersion = - item.parentIndex !== null && item.parentIndex !== i - 1 - ? `v${(item.parentIndex || 0) + 1}` - : null; - const type = displayHistoryItemType(item.type); - const isActive = i === currentVersion; - const summary = summarizeHistoryItem(item); - renderedHistory.push({ - isActive, - summary: summary, - parentVersion, - type, - }); - } - - return renderedHistory; -}; diff --git a/frontend/src/components/preview/PreviewPane.tsx b/frontend/src/components/preview/PreviewPane.tsx index bc4c1cf..83ee182 100644 --- a/frontend/src/components/preview/PreviewPane.tsx +++ b/frontend/src/components/preview/PreviewPane.tsx @@ -23,12 +23,17 @@ interface Props { function PreviewPane({ doUpdate, reset, settings }: Props) { const { appState } = useAppStore(); - const { inputMode, generatedCode, setGeneratedCode } = useProjectStore(); + const { inputMode, head, commits } = useProjectStore(); + + const currentCommit = head && commits[head] ? commits[head] : ""; + const currentCode = currentCommit + ? currentCommit.variants[currentCommit.selectedVariantIndex].code + : ""; const previewCode = inputMode === "video" && appState === AppState.CODING - ? extractHtml(generatedCode) - : generatedCode; + ? extractHtml(currentCode) + : currentCode; return (
    @@ -45,7 +50,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) { Reset
    - +