set up multiple generations

This commit is contained in:
Abi Raja 2024-07-30 15:44:48 -04:00
parent d7ab620e0b
commit aff9352dc0
6 changed files with 309 additions and 100 deletions

View File

@ -1,4 +1,5 @@
import os import os
import asyncio
import traceback import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
@ -14,17 +15,16 @@ from llm import (
) )
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List, Union, cast, get_args from typing import Any, Coroutine, Dict, List, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime from datetime import datetime
import json import json
from prompts.claude_prompts import VIDEO_PROMPT from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack from prompts.types import Stack
from utils import pprint_prompt
# from utils import pprint_prompt # from utils import pprint_prompt
from video.utils import extract_tag_content, assemble_claude_prompt_video from video.utils import assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
# Generate images and return updated completions
async def process_completion(
    websocket: WebSocket,
    completion: str,
    index: int,
    should_generate_images: bool,
    openai_api_key: str | None,
    openai_base_url: str | None,
    image_cache: dict[str, str],
) -> None:
    """Post-process one completion variant and push the result to the client.

    Optionally runs image generation over the completion's HTML, then sends
    a "setCode" event with the final HTML and a terminal "status" event.
    Every message carries "variantIndex" so the frontend can route it to
    the correct variant.

    Args:
        websocket: Open client connection; receives JSON events.
        completion: Raw HTML completion for this variant.
        index: Zero-based variant index echoed back as "variantIndex".
        should_generate_images: Whether the client requested image generation.
        openai_api_key: API key for image generation; generation is skipped
            when this is missing even if requested.
        openai_base_url: Optional custom base URL for the OpenAI client.
        image_cache: Previously generated images so updates don't regenerate
            identical ones.
    """
    try:
        if should_generate_images and openai_api_key:
            await websocket.send_json(
                {
                    "type": "status",
                    "value": "Generating images...",
                    "variantIndex": index,
                }
            )
            updated_html = await generate_images(
                completion,
                api_key=openai_api_key,
                base_url=openai_base_url,
                image_cache=image_cache,
            )
        else:
            updated_html = completion
        await websocket.send_json(
            {"type": "setCode", "value": updated_html, "variantIndex": index}
        )
        await websocket.send_json(
            {
                "type": "status",
                "value": "Code generation complete.",
                "variantIndex": index,
            }
        )
    except Exception as e:
        traceback.print_exc()
        print(f"Image generation failed for variant {index}", e)
        # Still deliver the un-imaged code: the frontend relies on "setCode"
        # to update its history even when image generation fails.
        await websocket.send_json(
            {"type": "setCode", "value": completion, "variantIndex": index}
        )
        await websocket.send_json(
            {
                "type": "status",
                "value": "Image generation failed but code is complete.",
                "variantIndex": index,
            }
        )
@router.websocket("/generate-code") @router.websocket("/generate-code")
async def stream_code(websocket: WebSocket): async def stream_code(websocket: WebSocket):
await websocket.accept() await websocket.accept()
@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket):
print("Received params") print("Received params")
# Read the code config settings from the request. Fall back to default if not provided. # Read the code config settings (stack) from the request. Fall back to default if not provided.
generated_code_config = "" generated_code_config = ""
if "generatedCodeConfig" in params and params["generatedCodeConfig"]: if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
generated_code_config = params["generatedCodeConfig"] generated_code_config = params["generatedCodeConfig"]
@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket):
) )
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20 code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
exact_llm_version = None
print( print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
) )
@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket):
print("Using official OpenAI URL") print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided. # Get the image generation flag from the request. Fall back to True if not provided.
should_generate_images = ( should_generate_images = bool(params.get("isImageGenerationEnabled", True))
params["isImageGenerationEnabled"]
if "isImageGenerationEnabled" in params
else True
)
print("generating code...") print("generating code...")
await websocket.send_json({"type": "status", "value": "Generating code..."}) await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 0}
)
await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 1}
)
async def process_chunk(content: str): async def process_chunk(content: str, variantIndex: int = 0):
await websocket.send_json({"type": "chunk", "value": content}) await websocket.send_json(
{"type": "chunk", "value": content, "variantIndex": variantIndex}
)
# Image cache for updates so that we don't have to regenerate images # Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {} image_cache: Dict[str, str] = {}
@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket):
# pprint_prompt(prompt_messages) # type: ignore # pprint_prompt(prompt_messages) # type: ignore
if SHOULD_MOCK_AI_RESPONSE: if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion( completions = [
process_chunk, input_mode=validated_input_mode await mock_completion(process_chunk, input_mode=validated_input_mode)
) ]
else: else:
try: try:
if validated_input_mode == "video": if validated_input_mode == "video":
@ -251,7 +305,8 @@ async def stream_code(websocket: WebSocket):
) )
raise Exception("No Anthropic key") raise Exception("No Anthropic key")
completion = await stream_claude_response_native( completions = [
await stream_claude_response_native(
system_prompt=VIDEO_PROMPT, system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore messages=prompt_messages, # type: ignore
api_key=anthropic_api_key, api_key=anthropic_api_key,
@ -259,33 +314,57 @@ async def stream_code(websocket: WebSocket):
model=Llm.CLAUDE_3_OPUS, model=Llm.CLAUDE_3_OPUS,
include_thinking=True, include_thinking=True,
) )
exact_llm_version = Llm.CLAUDE_3_OPUS ]
elif (
code_generation_model == Llm.CLAUDE_3_SONNET
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
):
if not anthropic_api_key:
await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
)
raise Exception("No Anthropic key")
completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=code_generation_model,
)
exact_llm_version = code_generation_model
else: else:
completion = await stream_openai_response(
prompt_messages, # type: ignore # Depending on the presence and absence of various keys,
# we decide which models to run
variant_models = []
if openai_api_key and anthropic_api_key:
variant_models = ["openai", "anthropic"]
elif openai_api_key:
variant_models = ["openai", "openai"]
elif anthropic_api_key:
variant_models = ["anthropic", "anthropic"]
else:
await throw_error(
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
)
raise Exception("No OpenAI or Anthropic key")
tasks: List[Coroutine[Any, Any, str]] = []
for index, model in enumerate(variant_models):
if model == "openai":
if openai_api_key is None:
await throw_error("OpenAI API key is missing.")
raise Exception("OpenAI API key is missing.")
tasks.append(
stream_openai_response(
prompt_messages,
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=openai_base_url,
callback=lambda x: process_chunk(x), callback=lambda x, i=index: process_chunk(x, i),
model=code_generation_model, model=Llm.GPT_4O_2024_05_13,
) )
exact_llm_version = code_generation_model )
elif model == "anthropic":
if anthropic_api_key is None:
await throw_error("Anthropic API key is missing.")
raise Exception("Anthropic API key is missing.")
tasks.append(
stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x, i=index: process_chunk(x, i),
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
)
)
completions = await asyncio.gather(*tasks)
print("Models used for generation: ", variant_models)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e) print("[GENERATE_CODE] Authentication failed", e)
error_message = ( error_message = (
@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket):
) )
return await throw_error(error_message) return await throw_error(error_message)
if validated_input_mode == "video": # if validated_input_mode == "video":
completion = extract_tag_content("html", completion) # completion = extract_tag_content("html", completions[0])
print("Exact used model for generation: ", exact_llm_version)
# Strip the completion of everything except the HTML content # Strip the completion of everything except the HTML content
completion = extract_html_content(completion) completions = [extract_html_content(completion) for completion in completions]
# Write the messages dict into a log so that we can debug later # Write the messages dict into a log so that we can debug later
write_logs(prompt_messages, completion) # type: ignore write_logs(prompt_messages, completions[0]) # type: ignore
try: try:
if should_generate_images: image_generation_tasks = [
await websocket.send_json( process_completion(
{"type": "status", "value": "Generating images..."} websocket,
)
updated_html = await generate_images(
completion, completion,
api_key=openai_api_key, index,
base_url=openai_base_url, should_generate_images,
image_cache=image_cache, openai_api_key,
) openai_base_url,
else: image_cache,
updated_html = completion
await websocket.send_json({"type": "setCode", "value": updated_html})
await websocket.send_json(
{"type": "status", "value": "Code generation complete."}
) )
for index, completion in enumerate(completions)
]
await asyncio.gather(*image_generation_tasks)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print("Image generation failed", e) print("An error occurred during image generation and processing", e)
# Send set code even if image generation fails since that triggers
# the frontend to update history
await websocket.send_json({"type": "setCode", "value": completion})
await websocket.send_json( await websocket.send_json(
{"type": "status", "value": "Image generation failed but code is complete."} {
"type": "status",
"value": "An error occurred during image generation and processing.",
"variantIndex": 0,
}
) )
await websocket.close() await websocket.close()

View File

@ -36,7 +36,12 @@ function App() {
// Outputs // Outputs
setGeneratedCode, setGeneratedCode,
setExecutionConsole, currentVariantIndex,
setVariant,
appendToVariant,
resetVariants,
appendExecutionConsole,
resetExecutionConsoles,
currentVersion, currentVersion,
setCurrentVersion, setCurrentVersion,
appHistory, appHistory,
@ -106,10 +111,14 @@ function App() {
const reset = () => { const reset = () => {
setAppState(AppState.INITIAL); setAppState(AppState.INITIAL);
setGeneratedCode(""); setGeneratedCode("");
resetVariants();
resetExecutionConsoles();
// Inputs
setReferenceImages([]); setReferenceImages([]);
setExecutionConsole([]);
setUpdateInstruction(""); setUpdateInstruction("");
setIsImportedFromCode(false); setIsImportedFromCode(false);
setAppHistory([]); setAppHistory([]);
setCurrentVersion(null); setCurrentVersion(null);
setShouldIncludeResultImage(false); setShouldIncludeResultImage(false);
@ -159,7 +168,7 @@ function App() {
parentVersion: number | null parentVersion: number | null
) { ) {
// Reset the execution console // Reset the execution console
setExecutionConsole([]); resetExecutionConsoles();
// Set the app state // Set the app state
setAppState(AppState.CODING); setAppState(AppState.CODING);
@ -171,10 +180,19 @@ function App() {
wsRef, wsRef,
updatedParams, updatedParams,
// On change // On change
(token) => setGeneratedCode((prev) => prev + token), (token, variant) => {
if (variant === currentVariantIndex) {
setGeneratedCode((prev) => prev + token);
}
appendToVariant(token, variant);
},
// On set code // On set code
(code) => { (code, variant) => {
setVariant(code, variant);
setGeneratedCode(code); setGeneratedCode(code);
// TODO: How to deal with variants?
if (params.generationType === "create") { if (params.generationType === "create") {
setAppHistory([ setAppHistory([
{ {
@ -214,7 +232,7 @@ function App() {
} }
}, },
// On status update // On status update
(line) => setExecutionConsole((prev) => [...prev, line]), (line, variant) => appendExecutionConsole(variant, line),
// On cancel // On cancel
() => { () => {
cancelCodeGenerationAndReset(); cancelCodeGenerationAndReset();
@ -314,6 +332,7 @@ function App() {
} }
setGeneratedCode(""); setGeneratedCode("");
resetVariants();
setUpdateInstruction(""); setUpdateInstruction("");
} }

View File

@ -12,6 +12,7 @@ import { Button } from "../ui/button";
import { Textarea } from "../ui/textarea"; import { Textarea } from "../ui/textarea";
import { useEffect, useRef } from "react"; import { useEffect, useRef } from "react";
import HistoryDisplay from "../history/HistoryDisplay"; import HistoryDisplay from "../history/HistoryDisplay";
import Variants from "../variants/Variants";
interface SidebarProps { interface SidebarProps {
showSelectAndEditFeature: boolean; showSelectAndEditFeature: boolean;
@ -35,8 +36,16 @@ function Sidebar({
shouldIncludeResultImage, shouldIncludeResultImage,
setShouldIncludeResultImage, setShouldIncludeResultImage,
} = useAppStore(); } = useAppStore();
const { inputMode, generatedCode, referenceImages, executionConsole } =
useProjectStore(); const {
inputMode,
generatedCode,
referenceImages,
executionConsoles,
currentVariantIndex,
} = useProjectStore();
const executionConsole = executionConsoles[currentVariantIndex] || [];
// When coding is complete, focus on the update instruction textarea // When coding is complete, focus on the update instruction textarea
useEffect(() => { useEffect(() => {
@ -47,6 +56,8 @@ function Sidebar({
return ( return (
<> <>
<Variants />
{/* Show code preview only when coding */} {/* Show code preview only when coding */}
{appState === AppState.CODING && ( {appState === AppState.CODING && (
<div className="flex flex-col"> <div className="flex flex-col">

View File

@ -0,0 +1,64 @@
import { useProjectStore } from "../../store/project-store";
// Variant picker: renders one tile per generated variant and swaps the
// displayed code (and latest history entry) when the user selects one.
function Variants() {
  const {
    // Inputs
    referenceImages,
    // Outputs
    variants,
    currentVariantIndex,
    setCurrentVariantIndex,
    setGeneratedCode,
    appHistory,
    setAppHistory,
  } = useProjectStore();

  // Make the clicked variant the active one and keep app history in sync.
  function switchVariant(index: number) {
    const variant = variants[index];
    setCurrentVariantIndex(index);
    setGeneratedCode(variant);
    if (appHistory.length === 1) {
      // Initial creation: replace the single history entry wholesale.
      setAppHistory([
        {
          type: "ai_create",
          parentIndex: null,
          code: variant,
          inputs: { image_url: referenceImages[0] },
        },
      ]);
    } else {
      // Rewrite the latest entry immutably — mutating the previous state
      // array in place would bypass React/Zustand change detection.
      setAppHistory((prev) =>
        prev.map((entry, i) =>
          i === prev.length - 1 ? { ...entry, code: variant } : entry
        )
      );
    }
  }

  // Nothing to pick from until at least one variant has streamed in.
  if (variants.length === 0) {
    return null;
  }

  return (
    <div className="mt-4 mb-4">
      <div className="grid grid-cols-2 gap-2">
        {variants.map((_, index) => (
          <div
            key={index}
            className={`p-2 border rounded-md cursor-pointer ${
              index === currentVariantIndex
                ? "bg-blue-100 dark:bg-blue-900"
                : "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
            }`}
            onClick={() => switchVariant(index)}
          >
            <h3 className="font-medium mb-1">Option {index + 1}</h3>
          </div>
        ))}
      </div>
    </div>
  );
}

export default Variants;

View File

@ -11,12 +11,18 @@ const ERROR_MESSAGE =
const CANCEL_MESSAGE = "Code generation cancelled"; const CANCEL_MESSAGE = "Code generation cancelled";
// Shape of every JSON message the backend pushes over the code-generation
// websocket. `variantIndex` routes the message to one of the parallel
// generation variants on the frontend.
type WebSocketResponse = {
  type: "chunk" | "status" | "setCode" | "error";
  value: string;
  variantIndex: number;
};
export function generateCode( export function generateCode(
wsRef: React.MutableRefObject<WebSocket | null>, wsRef: React.MutableRefObject<WebSocket | null>,
params: FullGenerationSettings, params: FullGenerationSettings,
onChange: (chunk: string) => void, onChange: (chunk: string, variantIndex: number) => void,
onSetCode: (code: string) => void, onSetCode: (code: string, variantIndex: number) => void,
onStatusUpdate: (status: string) => void, onStatusUpdate: (status: string, variantIndex: number) => void,
onCancel: () => void, onCancel: () => void,
onComplete: () => void onComplete: () => void
) { ) {
@ -31,13 +37,13 @@ export function generateCode(
}); });
ws.addEventListener("message", async (event: MessageEvent) => { ws.addEventListener("message", async (event: MessageEvent) => {
const response = JSON.parse(event.data); const response = JSON.parse(event.data) as WebSocketResponse;
if (response.type === "chunk") { if (response.type === "chunk") {
onChange(response.value); onChange(response.value, response.variantIndex);
} else if (response.type === "status") { } else if (response.type === "status") {
onStatusUpdate(response.value); onStatusUpdate(response.value, response.variantIndex);
} else if (response.type === "setCode") { } else if (response.type === "setCode") {
onSetCode(response.value); onSetCode(response.value, response.variantIndex);
} else if (response.type === "error") { } else if (response.type === "error") {
console.error("Error generating code", response.value); console.error("Error generating code", response.value);
toast.error(response.value); toast.error(response.value);

View File

@ -16,10 +16,17 @@ interface ProjectStore {
setGeneratedCode: ( setGeneratedCode: (
updater: string | ((currentCode: string) => string) updater: string | ((currentCode: string) => string)
) => void; ) => void;
executionConsole: string[];
setExecutionConsole: ( variants: string[];
updater: string[] | ((currentConsole: string[]) => string[]) currentVariantIndex: number;
) => void; setCurrentVariantIndex: (index: number) => void;
setVariant: (code: string, index: number) => void;
appendToVariant: (newTokens: string, index: number) => void;
resetVariants: () => void;
executionConsoles: { [key: number]: string[] };
appendExecutionConsole: (variantIndex: number, line: string) => void;
resetExecutionConsoles: () => void;
// Tracks the currently shown version from app history // Tracks the currently shown version from app history
// TODO: might want to move to appStore // TODO: might want to move to appStore
@ -48,14 +55,41 @@ export const useProjectStore = create<ProjectStore>((set) => ({
generatedCode: generatedCode:
typeof updater === "function" ? updater(state.generatedCode) : updater, typeof updater === "function" ? updater(state.generatedCode) : updater,
})), })),
executionConsole: [],
setExecutionConsole: (updater) => variants: [],
currentVariantIndex: 0,
setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }),
setVariant: (code: string, index: number) =>
set((state) => {
const newVariants = [...state.variants];
while (newVariants.length <= index) {
newVariants.push("");
}
newVariants[index] = code;
return { variants: newVariants };
}),
appendToVariant: (newTokens: string, index: number) =>
set((state) => {
const newVariants = [...state.variants];
newVariants[index] += newTokens;
return { variants: newVariants };
}),
resetVariants: () => set({ variants: [], currentVariantIndex: 0 }),
executionConsoles: {},
appendExecutionConsole: (variantIndex: number, line: string) =>
set((state) => ({ set((state) => ({
executionConsole: executionConsoles: {
typeof updater === "function" ...state.executionConsoles,
? updater(state.executionConsole) [variantIndex]: [
: updater, ...(state.executionConsoles[variantIndex] || []),
line,
],
},
})), })),
resetExecutionConsoles: () => set({ executionConsoles: {} }),
currentVersion: null, currentVersion: null,
setCurrentVersion: (version) => set({ currentVersion: version }), setCurrentVersion: (version) => set({ currentVersion: version }),