set up multiple generations

This commit is contained in:
Abi Raja 2024-07-30 15:44:48 -04:00
parent d7ab620e0b
commit aff9352dc0
6 changed files with 309 additions and 100 deletions

View File

@ -1,4 +1,5 @@
import os
import asyncio
import traceback
from fastapi import APIRouter, WebSocket
import openai
@ -14,17 +15,16 @@ from llm import (
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from typing import Dict, List, Union, cast, get_args
from typing import Any, Coroutine, Dict, List, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime
import json
from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack
from utils import pprint_prompt
# from utils import pprint_prompt
from video.utils import extract_tag_content, assemble_claude_prompt_video
from video.utils import assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
# Post-process one variant's completion and push the result to the client.
async def process_completion(
    websocket: WebSocket,
    completion: str,
    index: int,
    should_generate_images: bool,
    openai_api_key: str | None,
    openai_base_url: str | None,
    image_cache: dict[str, str],
):
    """Finalize a single generated variant and stream it to the frontend.

    Optionally runs the image-generation pipeline over the completion, then
    sends ``setCode`` and a final ``status`` message. Every message carries
    ``variantIndex`` so the frontend can route it to the right variant.

    Args:
        websocket: Open client connection used for all outgoing messages.
        completion: The raw HTML completion for this variant.
        index: Zero-based variant index echoed back in every message.
        should_generate_images: Client-side flag; image generation also
            requires an OpenAI key to actually run.
        openai_api_key: Key for the image-generation API (may be ``None``).
        openai_base_url: Optional custom OpenAI endpoint (may be ``None``).
        image_cache: Maps previously generated images to their URLs so
            updates don't regenerate identical images.
    """
    try:
        if should_generate_images and openai_api_key:
            await websocket.send_json(
                {
                    "type": "status",
                    "value": "Generating images...",
                    "variantIndex": index,
                }
            )
            updated_html = await generate_images(
                completion,
                api_key=openai_api_key,
                base_url=openai_base_url,
                image_cache=image_cache,
            )
        else:
            updated_html = completion
        await websocket.send_json(
            {"type": "setCode", "value": updated_html, "variantIndex": index}
        )
        await websocket.send_json(
            {
                "type": "status",
                "value": "Code generation complete.",
                "variantIndex": index,
            }
        )
    except Exception as e:
        traceback.print_exc()
        print(f"Image generation failed for variant {index}", e)
        # Still send setCode with the un-processed completion so the frontend
        # records the result even though image generation failed.
        await websocket.send_json(
            {"type": "setCode", "value": completion, "variantIndex": index}
        )
        await websocket.send_json(
            {
                "type": "status",
                "value": "Image generation failed but code is complete.",
                "variantIndex": index,
            }
        )
@router.websocket("/generate-code")
async def stream_code(websocket: WebSocket):
await websocket.accept()
@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket):
print("Received params")
# Read the code config settings from the request. Fall back to default if not provided.
# Read the code config settings (stack) from the request. Fall back to default if not provided.
generated_code_config = ""
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
generated_code_config = params["generatedCodeConfig"]
@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket):
)
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
exact_llm_version = None
print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
)
@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket):
print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided.
should_generate_images = (
params["isImageGenerationEnabled"]
if "isImageGenerationEnabled" in params
else True
)
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
print("generating code...")
await websocket.send_json({"type": "status", "value": "Generating code..."})
await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 0}
)
await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 1}
)
async def process_chunk(content: str):
await websocket.send_json({"type": "chunk", "value": content})
async def process_chunk(content: str, variantIndex: int = 0):
    """Forward one streamed token to the client, tagged with its variant.

    ``variantIndex`` defaults to 0 so existing single-variant callers that
    pass only the content keep working.
    """
    await websocket.send_json(
        {"type": "chunk", "value": content, "variantIndex": variantIndex}
    )
# Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {}
@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket):
# pprint_prompt(prompt_messages) # type: ignore
if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(
process_chunk, input_mode=validated_input_mode
)
completions = [
await mock_completion(process_chunk, input_mode=validated_input_mode)
]
else:
try:
if validated_input_mode == "video":
@ -251,41 +305,66 @@ async def stream_code(websocket: WebSocket):
)
raise Exception("No Anthropic key")
completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
exact_llm_version = Llm.CLAUDE_3_OPUS
elif (
code_generation_model == Llm.CLAUDE_3_SONNET
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
):
if not anthropic_api_key:
await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
completions = [
await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
raise Exception("No Anthropic key")
completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=code_generation_model,
)
exact_llm_version = code_generation_model
]
else:
completion = await stream_openai_response(
prompt_messages, # type: ignore
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
model=code_generation_model,
)
exact_llm_version = code_generation_model
# Depending on the presence and absence of various keys,
# we decide which models to run
variant_models = []
if openai_api_key and anthropic_api_key:
variant_models = ["openai", "anthropic"]
elif openai_api_key:
variant_models = ["openai", "openai"]
elif anthropic_api_key:
variant_models = ["anthropic", "anthropic"]
else:
await throw_error(
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
)
raise Exception("No OpenAI or Anthropic key")
tasks: List[Coroutine[Any, Any, str]] = []
for index, model in enumerate(variant_models):
if model == "openai":
if openai_api_key is None:
await throw_error("OpenAI API key is missing.")
raise Exception("OpenAI API key is missing.")
tasks.append(
stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x, i=index: process_chunk(x, i),
model=Llm.GPT_4O_2024_05_13,
)
)
elif model == "anthropic":
if anthropic_api_key is None:
await throw_error("Anthropic API key is missing.")
raise Exception("Anthropic API key is missing.")
tasks.append(
stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x, i=index: process_chunk(x, i),
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
)
)
completions = await asyncio.gather(*tasks)
print("Models used for generation: ", variant_models)
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket):
)
return await throw_error(error_message)
if validated_input_mode == "video":
completion = extract_tag_content("html", completion)
print("Exact used model for generation: ", exact_llm_version)
# if validated_input_mode == "video":
# completion = extract_tag_content("html", completions[0])
# Strip the completion of everything except the HTML content
completion = extract_html_content(completion)
completions = [extract_html_content(completion) for completion in completions]
# Write the messages dict into a log so that we can debug later
write_logs(prompt_messages, completion) # type: ignore
write_logs(prompt_messages, completions[0]) # type: ignore
try:
if should_generate_images:
await websocket.send_json(
{"type": "status", "value": "Generating images..."}
)
updated_html = await generate_images(
image_generation_tasks = [
process_completion(
websocket,
completion,
api_key=openai_api_key,
base_url=openai_base_url,
image_cache=image_cache,
index,
should_generate_images,
openai_api_key,
openai_base_url,
image_cache,
)
else:
updated_html = completion
await websocket.send_json({"type": "setCode", "value": updated_html})
await websocket.send_json(
{"type": "status", "value": "Code generation complete."}
)
for index, completion in enumerate(completions)
]
await asyncio.gather(*image_generation_tasks)
except Exception as e:
traceback.print_exc()
print("Image generation failed", e)
# Send set code even if image generation fails since that triggers
# the frontend to update history
await websocket.send_json({"type": "setCode", "value": completion})
print("An error occurred during image generation and processing", e)
await websocket.send_json(
{"type": "status", "value": "Image generation failed but code is complete."}
{
"type": "status",
"value": "An error occurred during image generation and processing.",
"variantIndex": 0,
}
)
await websocket.close()

View File

@ -36,7 +36,12 @@ function App() {
// Outputs
setGeneratedCode,
setExecutionConsole,
currentVariantIndex,
setVariant,
appendToVariant,
resetVariants,
appendExecutionConsole,
resetExecutionConsoles,
currentVersion,
setCurrentVersion,
appHistory,
@ -106,10 +111,14 @@ function App() {
const reset = () => {
setAppState(AppState.INITIAL);
setGeneratedCode("");
resetVariants();
resetExecutionConsoles();
// Inputs
setReferenceImages([]);
setExecutionConsole([]);
setUpdateInstruction("");
setIsImportedFromCode(false);
setAppHistory([]);
setCurrentVersion(null);
setShouldIncludeResultImage(false);
@ -159,7 +168,7 @@ function App() {
parentVersion: number | null
) {
// Reset the execution console
setExecutionConsole([]);
resetExecutionConsoles();
// Set the app state
setAppState(AppState.CODING);
@ -171,10 +180,19 @@ function App() {
wsRef,
updatedParams,
// On change
(token) => setGeneratedCode((prev) => prev + token),
(token, variant) => {
if (variant === currentVariantIndex) {
setGeneratedCode((prev) => prev + token);
}
appendToVariant(token, variant);
},
// On set code
(code) => {
(code, variant) => {
setVariant(code, variant);
setGeneratedCode(code);
// TODO: How to deal with variants?
if (params.generationType === "create") {
setAppHistory([
{
@ -214,7 +232,7 @@ function App() {
}
},
// On status update
(line) => setExecutionConsole((prev) => [...prev, line]),
(line, variant) => appendExecutionConsole(variant, line),
// On cancel
() => {
cancelCodeGenerationAndReset();
@ -314,6 +332,7 @@ function App() {
}
setGeneratedCode("");
resetVariants();
setUpdateInstruction("");
}

View File

@ -12,6 +12,7 @@ import { Button } from "../ui/button";
import { Textarea } from "../ui/textarea";
import { useEffect, useRef } from "react";
import HistoryDisplay from "../history/HistoryDisplay";
import Variants from "../variants/Variants";
interface SidebarProps {
showSelectAndEditFeature: boolean;
@ -35,8 +36,16 @@ function Sidebar({
shouldIncludeResultImage,
setShouldIncludeResultImage,
} = useAppStore();
const { inputMode, generatedCode, referenceImages, executionConsole } =
useProjectStore();
const {
inputMode,
generatedCode,
referenceImages,
executionConsoles,
currentVariantIndex,
} = useProjectStore();
const executionConsole = executionConsoles[currentVariantIndex] || [];
// When coding is complete, focus on the update instruction textarea
useEffect(() => {
@ -47,6 +56,8 @@ function Sidebar({
return (
<>
<Variants />
{/* Show code preview only when coding */}
{appState === AppState.CODING && (
<div className="flex flex-col">

View File

@ -0,0 +1,64 @@
import { useProjectStore } from "../../store/project-store";
// Selector UI for the parallel code-generation variants. Clicking an option
// makes it the active variant and syncs the generated code + history.
function Variants() {
  const {
    // Inputs
    referenceImages,
    // Outputs
    variants,
    currentVariantIndex,
    setCurrentVariantIndex,
    setGeneratedCode,
    appHistory,
    setAppHistory,
  } = useProjectStore();

  // Make the clicked variant active and point the latest history entry at
  // its code so subsequent edits build on the selected variant.
  function switchVariant(index: number) {
    const variant = variants[index];
    setCurrentVariantIndex(index);
    setGeneratedCode(variant);

    if (appHistory.length === 1) {
      setAppHistory([
        {
          type: "ai_create",
          parentIndex: null,
          code: variant,
          inputs: { image_url: referenceImages[0] },
        },
      ]);
    } else {
      setAppHistory((prev) => {
        // Replace the last entry immutably. Mutating the existing entry
        // object in place (prev[last].code = ...) keeps the same object
        // reference and can bypass React/Zustand change detection.
        const newHistory = [...prev];
        const last = newHistory.length - 1;
        newHistory[last] = { ...newHistory[last], code: variant };
        return newHistory;
      });
    }
  }

  // Nothing to pick from until at least one variant has streamed in.
  if (variants.length === 0) {
    return null;
  }

  return (
    <div className="mt-4 mb-4">
      <div className="grid grid-cols-2 gap-2">
        {variants.map((_, index) => (
          <div
            key={index}
            className={`p-2 border rounded-md cursor-pointer ${
              index === currentVariantIndex
                ? "bg-blue-100 dark:bg-blue-900"
                : "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
            }`}
            onClick={() => switchVariant(index)}
          >
            <h3 className="font-medium mb-1">Option {index + 1}</h3>
          </div>
        ))}
      </div>
    </div>
  );
}

export default Variants;

View File

@ -11,12 +11,18 @@ const ERROR_MESSAGE =
const CANCEL_MESSAGE = "Code generation cancelled";
// Shape of every message the backend pushes over the code-generation
// websocket; variantIndex routes the message to one of the parallel variants.
type WebSocketResponse = {
  type: "chunk" | "status" | "setCode" | "error";
  value: string;
  variantIndex: number;
};
export function generateCode(
wsRef: React.MutableRefObject<WebSocket | null>,
params: FullGenerationSettings,
onChange: (chunk: string) => void,
onSetCode: (code: string) => void,
onStatusUpdate: (status: string) => void,
onChange: (chunk: string, variantIndex: number) => void,
onSetCode: (code: string, variantIndex: number) => void,
onStatusUpdate: (status: string, variantIndex: number) => void,
onCancel: () => void,
onComplete: () => void
) {
@ -31,13 +37,13 @@ export function generateCode(
});
ws.addEventListener("message", async (event: MessageEvent) => {
const response = JSON.parse(event.data);
const response = JSON.parse(event.data) as WebSocketResponse;
if (response.type === "chunk") {
onChange(response.value);
onChange(response.value, response.variantIndex);
} else if (response.type === "status") {
onStatusUpdate(response.value);
onStatusUpdate(response.value, response.variantIndex);
} else if (response.type === "setCode") {
onSetCode(response.value);
onSetCode(response.value, response.variantIndex);
} else if (response.type === "error") {
console.error("Error generating code", response.value);
toast.error(response.value);

View File

@ -16,10 +16,17 @@ interface ProjectStore {
setGeneratedCode: (
updater: string | ((currentCode: string) => string)
) => void;
executionConsole: string[];
setExecutionConsole: (
updater: string[] | ((currentConsole: string[]) => string[])
) => void;
variants: string[];
currentVariantIndex: number;
setCurrentVariantIndex: (index: number) => void;
setVariant: (code: string, index: number) => void;
appendToVariant: (newTokens: string, index: number) => void;
resetVariants: () => void;
executionConsoles: { [key: number]: string[] };
appendExecutionConsole: (variantIndex: number, line: string) => void;
resetExecutionConsoles: () => void;
// Tracks the currently shown version from app history
// TODO: might want to move to appStore
@ -48,14 +55,41 @@ export const useProjectStore = create<ProjectStore>((set) => ({
generatedCode:
typeof updater === "function" ? updater(state.generatedCode) : updater,
})),
executionConsole: [],
setExecutionConsole: (updater) =>
variants: [],
currentVariantIndex: 0,
setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }),
// Overwrite the code for one variant, growing the array with empty
// placeholders if the index is beyond the current length.
setVariant: (code: string, index: number) =>
  set((state) => {
    const padded = [...state.variants];
    for (let i = padded.length; i <= index; i++) {
      padded.push("");
    }
    padded[index] = code;
    return { variants: padded };
  }),
// Append streamed tokens to one variant's code. Pads the array with empty
// strings first: without the pad, a chunk arriving before setVariant has
// sized the array does `undefined + newTokens`, prefixing the code with
// the literal string "undefined".
appendToVariant: (newTokens: string, index: number) =>
  set((state) => {
    const newVariants = [...state.variants];
    while (newVariants.length <= index) {
      newVariants.push("");
    }
    newVariants[index] += newTokens;
    return { variants: newVariants };
  }),
resetVariants: () => set({ variants: [], currentVariantIndex: 0 }),
executionConsoles: {},
appendExecutionConsole: (variantIndex: number, line: string) =>
set((state) => ({
executionConsole:
typeof updater === "function"
? updater(state.executionConsole)
: updater,
executionConsoles: {
...state.executionConsoles,
[variantIndex]: [
...(state.executionConsoles[variantIndex] || []),
line,
],
},
})),
resetExecutionConsoles: () => set({ executionConsoles: {} }),
currentVersion: null,
setCurrentVersion: (version) => set({ currentVersion: version }),