support updating the initially generated version

This commit is contained in:
Abi Raja 2023-11-15 15:08:59 -05:00
parent 6f3c668c2f
commit db56dbd3e6
4 changed files with 75 additions and 25 deletions

View File

@ -48,6 +48,13 @@ async def stream_code_test(websocket: WebSocket):
prompt_messages = assemble_prompt(params["image"]) prompt_messages = assemble_prompt(params["image"])
if params["generationType"] == "update":
# Transform into message format
for index, text in enumerate(params["history"]):
prompt_messages += [
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
]
if SHOULD_MOCK_AI_RESPONSE: if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(process_chunk) completion = await mock_completion(process_chunk)
else: else:

20
backend/utils.py Normal file
View File

@ -0,0 +1,20 @@
import copy
def truncate_data_strings(data):
# Deep clone the data to avoid modifying the original object
cloned_data = copy.deepcopy(data)
if isinstance(cloned_data, dict):
for key, value in cloned_data.items():
# Recursively call the function if the value is a dictionary or a list
if isinstance(value, (dict, list)):
cloned_data[key] = truncate_data_strings(value)
# Truncate the string if it starts with 'data:'
elif isinstance(value, str) and value.startswith("data:"):
cloned_data[key] = value[:20]
elif isinstance(cloned_data, list):
# Process each item in the list
cloned_data = [truncate_data_strings(item) for item in cloned_data]
return cloned_data

View File

@ -2,7 +2,7 @@ import { useState } from "react";
import ImageUpload from "./components/ImageUpload"; import ImageUpload from "./components/ImageUpload";
import CodePreview from "./components/CodePreview"; import CodePreview from "./components/CodePreview";
import Preview from "./components/Preview"; import Preview from "./components/Preview";
import { generateCode } from "./generateCode"; import { CodeGenerationParams, generateCode } from "./generateCode";
import Spinner from "./components/Spinner"; import Spinner from "./components/Spinner";
import classNames from "classnames"; import classNames from "classnames";
import { FaDownload, FaUndo } from "react-icons/fa"; import { FaDownload, FaUndo } from "react-icons/fa";
@ -17,6 +17,8 @@ function App() {
const [referenceImages, setReferenceImages] = useState<string[]>([]); const [referenceImages, setReferenceImages] = useState<string[]>([]);
const [executionConsole, setExecutionConsole] = useState<string[]>([]); const [executionConsole, setExecutionConsole] = useState<string[]>([]);
const [blobUrl, setBlobUrl] = useState(""); const [blobUrl, setBlobUrl] = useState("");
const [updateInstruction, setUpdateInstruction] = useState("");
const [history, setHistory] = useState<string[]>([]);
const createBlobUrl = () => { const createBlobUrl = () => {
const blob = new Blob([generatedCode], { type: "text/html" }); const blob = new Blob([generatedCode], { type: "text/html" });
@ -32,26 +34,41 @@ function App() {
setBlobUrl(""); setBlobUrl("");
}; };
function startCodeGeneration(referenceImages: string[]) { function doGenerateCode(params: CodeGenerationParams) {
setAppState("CODING"); setAppState("CODING");
setReferenceImages(referenceImages);
generateCode( generateCode(
referenceImages[0], params,
function (token) { (token) => setGeneratedCode((prev) => prev + token),
setGeneratedCode((prev) => prev + token); (code) => setGeneratedCode(code),
}, (line) => setExecutionConsole((prev) => [...prev, line]),
function (code) { () => setAppState("CODE_READY")
setGeneratedCode(code);
},
function (line) {
setExecutionConsole((prev) => [...prev, line]);
},
function () {
setAppState("CODE_READY");
}
); );
} }
// Initial version creation
function doCreate(referenceImages: string[]) {
setReferenceImages(referenceImages);
doGenerateCode({
generationType: "create",
image: referenceImages[0],
});
}
// Subsequent updates
function doUpdate() {
const updatedHistory = [...history, generatedCode, updateInstruction];
doGenerateCode({
generationType: "update",
image: referenceImages[0],
history: updatedHistory,
});
setHistory(updatedHistory);
setGeneratedCode("");
setUpdateInstruction("");
}
return ( return (
<div className="mt-6"> <div className="mt-6">
<div className="hidden lg:fixed lg:inset-y-0 lg:z-50 lg:flex lg:w-96 lg:flex-col"> <div className="hidden lg:fixed lg:inset-y-0 lg:z-50 lg:flex lg:w-96 lg:flex-col">
@ -122,8 +139,12 @@ function App() {
</button> </button>
</div> </div>
<div className="grid w-full gap-2"> <div className="grid w-full gap-2">
<Textarea placeholder="Describe what the AI missed the first time around" /> <Textarea
<Button>Update</Button> placeholder="Describe what the AI missed the first time around"
onChange={(e) => setUpdateInstruction(e.target.value)}
value={updateInstruction}
/>
<Button onClick={doUpdate}>Update</Button>
</div> </div>
</div> </div>
)} )}
@ -135,7 +156,7 @@ function App() {
<main className="py-2 lg:pl-96"> <main className="py-2 lg:pl-96">
{appState === "INITIAL" && ( {appState === "INITIAL" && (
<> <>
<ImageUpload setReferenceImages={startCodeGeneration} /> <ImageUpload setReferenceImages={doCreate} />
</> </>
)} )}

View File

@ -5,8 +5,14 @@ const WS_BACKEND_URL =
const ERROR_MESSAGE = const ERROR_MESSAGE =
"Error generating code. Check the Developer Console for details. Feel free to open a Github ticket"; "Error generating code. Check the Developer Console for details. Feel free to open a Github ticket";
export interface CodeGenerationParams {
generationType: "create" | "update";
image: string;
history?: string[];
}
export function generateCode( export function generateCode(
imageUrl: string, params: CodeGenerationParams,
onChange: (chunk: string) => void, onChange: (chunk: string) => void,
onSetCode: (code: string) => void, onSetCode: (code: string) => void,
onStatusUpdate: (status: string) => void, onStatusUpdate: (status: string) => void,
@ -18,11 +24,7 @@ export function generateCode(
const ws = new WebSocket(wsUrl); const ws = new WebSocket(wsUrl);
ws.addEventListener("open", () => { ws.addEventListener("open", () => {
ws.send( ws.send(JSON.stringify(params));
JSON.stringify({
image: imageUrl,
})
);
}); });
ws.addEventListener("message", async (event: MessageEvent) => { ws.addEventListener("message", async (event: MessageEvent) => {