diff --git a/backend/main.py b/backend/main.py index f93284a..a80684e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -48,6 +48,13 @@ async def stream_code_test(websocket: WebSocket): prompt_messages = assemble_prompt(params["image"]) + if params["generationType"] == "update": + # Transform into message format + for index, text in enumerate(params["history"]): + prompt_messages += [ + {"role": "assistant" if index % 2 == 0 else "user", "content": text} + ] + if SHOULD_MOCK_AI_RESPONSE: completion = await mock_completion(process_chunk) else: diff --git a/backend/utils.py b/backend/utils.py new file mode 100644 index 0000000..88c3a67 --- /dev/null +++ b/backend/utils.py @@ -0,0 +1,20 @@ +import copy + + +def truncate_data_strings(data): + # Deep clone the data to avoid modifying the original object + cloned_data = copy.deepcopy(data) + + if isinstance(cloned_data, dict): + for key, value in cloned_data.items(): + # Recursively call the function if the value is a dictionary or a list + if isinstance(value, (dict, list)): + cloned_data[key] = truncate_data_strings(value) + # Truncate the string if it starts with 'data:' + elif isinstance(value, str) and value.startswith("data:"): + cloned_data[key] = value[:20] + elif isinstance(cloned_data, list): + # Process each item in the list + cloned_data = [truncate_data_strings(item) for item in cloned_data] + + return cloned_data diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index ce8d481..06bbec4 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -2,7 +2,7 @@ import { useState } from "react"; import ImageUpload from "./components/ImageUpload"; import CodePreview from "./components/CodePreview"; import Preview from "./components/Preview"; -import { generateCode } from "./generateCode"; +import { CodeGenerationParams, generateCode } from "./generateCode"; import Spinner from "./components/Spinner"; import classNames from "classnames"; import { FaDownload, FaUndo } from "react-icons/fa"; @@ -17,6 +17,8 @@ function App() { const [referenceImages, setReferenceImages] = useState([]); const [executionConsole, setExecutionConsole] = useState([]); const [blobUrl, setBlobUrl] = useState(""); + const [updateInstruction, setUpdateInstruction] = useState(""); + const [history, setHistory] = useState([]); const createBlobUrl = () => { const blob = new Blob([generatedCode], { type: "text/html" }); @@ -32,26 +34,41 @@ function App() { setBlobUrl(""); }; - function startCodeGeneration(referenceImages: string[]) { + function doGenerateCode(params: CodeGenerationParams) { setAppState("CODING"); - setReferenceImages(referenceImages); generateCode( - referenceImages[0], - function (token) { - setGeneratedCode((prev) => prev + token); - }, - function (code) { - setGeneratedCode(code); - }, - function (line) { - setExecutionConsole((prev) => [...prev, line]); - }, - function () { - setAppState("CODE_READY"); - } + params, + (token) => setGeneratedCode((prev) => prev + token), + (code) => setGeneratedCode(code), + (line) => setExecutionConsole((prev) => [...prev, line]), + () => setAppState("CODE_READY") ); } + // Initial version creation + function doCreate(referenceImages: string[]) { + setReferenceImages(referenceImages); + doGenerateCode({ + generationType: "create", + image: referenceImages[0], + }); + } + + // Subsequent updates + function doUpdate() { + const updatedHistory = [...history, generatedCode, updateInstruction]; + + doGenerateCode({ + generationType: "update", + image: referenceImages[0], + history: updatedHistory, + }); + + setHistory(updatedHistory); + setGeneratedCode(""); + setUpdateInstruction(""); + } + return (
@@ -122,8 +139,12 @@ function App() {
-