set up multiple generations
This commit is contained in:
parent
d7ab620e0b
commit
aff9352dc0
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
@ -14,17 +15,16 @@ from llm import (
|
|||||||
)
|
)
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Dict, List, Union, cast, get_args
|
from typing import Any, Coroutine, Dict, List, Union, cast, get_args
|
||||||
from image_generation import create_alt_url_mapping, generate_images
|
from image_generation import create_alt_url_mapping, generate_images
|
||||||
from prompts import assemble_imported_code_prompt, assemble_prompt
|
from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
from prompts.claude_prompts import VIDEO_PROMPT
|
from prompts.claude_prompts import VIDEO_PROMPT
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
from utils import pprint_prompt
|
|
||||||
|
|
||||||
# from utils import pprint_prompt
|
# from utils import pprint_prompt
|
||||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
from video.utils import assemble_claude_prompt_video
|
||||||
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@ -50,6 +50,59 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
|
|||||||
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
||||||
|
|
||||||
|
|
||||||
|
# Generate images and return updated completions
|
||||||
|
async def process_completion(
|
||||||
|
websocket: WebSocket,
|
||||||
|
completion: str,
|
||||||
|
index: int,
|
||||||
|
should_generate_images: bool,
|
||||||
|
openai_api_key: str | None,
|
||||||
|
openai_base_url: str | None,
|
||||||
|
image_cache: dict[str, str],
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
if should_generate_images and openai_api_key:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "status",
|
||||||
|
"value": f"Generating images...",
|
||||||
|
"variantIndex": index,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
updated_html = await generate_images(
|
||||||
|
completion,
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_base_url,
|
||||||
|
image_cache=image_cache,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
updated_html = completion
|
||||||
|
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "setCode", "value": updated_html, "variantIndex": index}
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "status",
|
||||||
|
"value": f"Code generation complete.",
|
||||||
|
"variantIndex": index,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Image generation failed for variant {index}", e)
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "setCode", "value": completion, "variantIndex": index}
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "status",
|
||||||
|
"value": f"Image generation failed but code is complete.",
|
||||||
|
"variantIndex": index,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/generate-code")
|
@router.websocket("/generate-code")
|
||||||
async def stream_code(websocket: WebSocket):
|
async def stream_code(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
@ -67,7 +120,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
|
|
||||||
print("Received params")
|
print("Received params")
|
||||||
|
|
||||||
# Read the code config settings from the request. Fall back to default if not provided.
|
# Read the code config settings (stack) from the request. Fall back to default if not provided.
|
||||||
generated_code_config = ""
|
generated_code_config = ""
|
||||||
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
|
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
|
||||||
generated_code_config = params["generatedCodeConfig"]
|
generated_code_config = params["generatedCodeConfig"]
|
||||||
@ -107,8 +160,6 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||||
|
|
||||||
exact_llm_version = None
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
||||||
)
|
)
|
||||||
@ -162,17 +213,20 @@ async def stream_code(websocket: WebSocket):
|
|||||||
print("Using official OpenAI URL")
|
print("Using official OpenAI URL")
|
||||||
|
|
||||||
# Get the image generation flag from the request. Fall back to True if not provided.
|
# Get the image generation flag from the request. Fall back to True if not provided.
|
||||||
should_generate_images = (
|
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||||
params["isImageGenerationEnabled"]
|
|
||||||
if "isImageGenerationEnabled" in params
|
|
||||||
else True
|
|
||||||
)
|
|
||||||
|
|
||||||
print("generating code...")
|
print("generating code...")
|
||||||
await websocket.send_json({"type": "status", "value": "Generating code..."})
|
await websocket.send_json(
|
||||||
|
{"type": "status", "value": "Generating code...", "variantIndex": 0}
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "status", "value": "Generating code...", "variantIndex": 1}
|
||||||
|
)
|
||||||
|
|
||||||
async def process_chunk(content: str):
|
async def process_chunk(content: str, variantIndex: int = 0):
|
||||||
await websocket.send_json({"type": "chunk", "value": content})
|
await websocket.send_json(
|
||||||
|
{"type": "chunk", "value": content, "variantIndex": variantIndex}
|
||||||
|
)
|
||||||
|
|
||||||
# Image cache for updates so that we don't have to regenerate images
|
# Image cache for updates so that we don't have to regenerate images
|
||||||
image_cache: Dict[str, str] = {}
|
image_cache: Dict[str, str] = {}
|
||||||
@ -239,9 +293,9 @@ async def stream_code(websocket: WebSocket):
|
|||||||
# pprint_prompt(prompt_messages) # type: ignore
|
# pprint_prompt(prompt_messages) # type: ignore
|
||||||
|
|
||||||
if SHOULD_MOCK_AI_RESPONSE:
|
if SHOULD_MOCK_AI_RESPONSE:
|
||||||
completion = await mock_completion(
|
completions = [
|
||||||
process_chunk, input_mode=validated_input_mode
|
await mock_completion(process_chunk, input_mode=validated_input_mode)
|
||||||
)
|
]
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if validated_input_mode == "video":
|
if validated_input_mode == "video":
|
||||||
@ -251,7 +305,8 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
raise Exception("No Anthropic key")
|
raise Exception("No Anthropic key")
|
||||||
|
|
||||||
completion = await stream_claude_response_native(
|
completions = [
|
||||||
|
await stream_claude_response_native(
|
||||||
system_prompt=VIDEO_PROMPT,
|
system_prompt=VIDEO_PROMPT,
|
||||||
messages=prompt_messages, # type: ignore
|
messages=prompt_messages, # type: ignore
|
||||||
api_key=anthropic_api_key,
|
api_key=anthropic_api_key,
|
||||||
@ -259,33 +314,57 @@ async def stream_code(websocket: WebSocket):
|
|||||||
model=Llm.CLAUDE_3_OPUS,
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
include_thinking=True,
|
include_thinking=True,
|
||||||
)
|
)
|
||||||
exact_llm_version = Llm.CLAUDE_3_OPUS
|
]
|
||||||
elif (
|
|
||||||
code_generation_model == Llm.CLAUDE_3_SONNET
|
|
||||||
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
|
|
||||||
):
|
|
||||||
if not anthropic_api_key:
|
|
||||||
await throw_error(
|
|
||||||
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
|
|
||||||
)
|
|
||||||
raise Exception("No Anthropic key")
|
|
||||||
|
|
||||||
completion = await stream_claude_response(
|
|
||||||
prompt_messages, # type: ignore
|
|
||||||
api_key=anthropic_api_key,
|
|
||||||
callback=lambda x: process_chunk(x),
|
|
||||||
model=code_generation_model,
|
|
||||||
)
|
|
||||||
exact_llm_version = code_generation_model
|
|
||||||
else:
|
else:
|
||||||
completion = await stream_openai_response(
|
|
||||||
prompt_messages, # type: ignore
|
# Depending on the presence and absence of various keys,
|
||||||
|
# we decide which models to run
|
||||||
|
variant_models = []
|
||||||
|
if openai_api_key and anthropic_api_key:
|
||||||
|
variant_models = ["openai", "anthropic"]
|
||||||
|
elif openai_api_key:
|
||||||
|
variant_models = ["openai", "openai"]
|
||||||
|
elif anthropic_api_key:
|
||||||
|
variant_models = ["anthropic", "anthropic"]
|
||||||
|
else:
|
||||||
|
await throw_error(
|
||||||
|
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
|
||||||
|
)
|
||||||
|
raise Exception("No OpenAI or Anthropic key")
|
||||||
|
|
||||||
|
tasks: List[Coroutine[Any, Any, str]] = []
|
||||||
|
for index, model in enumerate(variant_models):
|
||||||
|
if model == "openai":
|
||||||
|
if openai_api_key is None:
|
||||||
|
await throw_error("OpenAI API key is missing.")
|
||||||
|
raise Exception("OpenAI API key is missing.")
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
stream_openai_response(
|
||||||
|
prompt_messages,
|
||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x, i=index: process_chunk(x, i),
|
||||||
model=code_generation_model,
|
model=Llm.GPT_4O_2024_05_13,
|
||||||
)
|
)
|
||||||
exact_llm_version = code_generation_model
|
)
|
||||||
|
elif model == "anthropic":
|
||||||
|
if anthropic_api_key is None:
|
||||||
|
await throw_error("Anthropic API key is missing.")
|
||||||
|
raise Exception("Anthropic API key is missing.")
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
stream_claude_response(
|
||||||
|
prompt_messages,
|
||||||
|
api_key=anthropic_api_key,
|
||||||
|
callback=lambda x, i=index: process_chunk(x, i),
|
||||||
|
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
completions = await asyncio.gather(*tasks)
|
||||||
|
print("Models used for generation: ", variant_models)
|
||||||
|
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
print("[GENERATE_CODE] Authentication failed", e)
|
print("[GENERATE_CODE] Authentication failed", e)
|
||||||
error_message = (
|
error_message = (
|
||||||
@ -321,42 +400,38 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
return await throw_error(error_message)
|
return await throw_error(error_message)
|
||||||
|
|
||||||
if validated_input_mode == "video":
|
# if validated_input_mode == "video":
|
||||||
completion = extract_tag_content("html", completion)
|
# completion = extract_tag_content("html", completions[0])
|
||||||
|
|
||||||
print("Exact used model for generation: ", exact_llm_version)
|
|
||||||
|
|
||||||
# Strip the completion of everything except the HTML content
|
# Strip the completion of everything except the HTML content
|
||||||
completion = extract_html_content(completion)
|
completions = [extract_html_content(completion) for completion in completions]
|
||||||
|
|
||||||
# Write the messages dict into a log so that we can debug later
|
# Write the messages dict into a log so that we can debug later
|
||||||
write_logs(prompt_messages, completion) # type: ignore
|
write_logs(prompt_messages, completions[0]) # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if should_generate_images:
|
image_generation_tasks = [
|
||||||
await websocket.send_json(
|
process_completion(
|
||||||
{"type": "status", "value": "Generating images..."}
|
websocket,
|
||||||
)
|
|
||||||
updated_html = await generate_images(
|
|
||||||
completion,
|
completion,
|
||||||
api_key=openai_api_key,
|
index,
|
||||||
base_url=openai_base_url,
|
should_generate_images,
|
||||||
image_cache=image_cache,
|
openai_api_key,
|
||||||
)
|
openai_base_url,
|
||||||
else:
|
image_cache,
|
||||||
updated_html = completion
|
|
||||||
await websocket.send_json({"type": "setCode", "value": updated_html})
|
|
||||||
await websocket.send_json(
|
|
||||||
{"type": "status", "value": "Code generation complete."}
|
|
||||||
)
|
)
|
||||||
|
for index, completion in enumerate(completions)
|
||||||
|
]
|
||||||
|
await asyncio.gather(*image_generation_tasks)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print("Image generation failed", e)
|
print("An error occurred during image generation and processing", e)
|
||||||
# Send set code even if image generation fails since that triggers
|
|
||||||
# the frontend to update history
|
|
||||||
await websocket.send_json({"type": "setCode", "value": completion})
|
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{"type": "status", "value": "Image generation failed but code is complete."}
|
{
|
||||||
|
"type": "status",
|
||||||
|
"value": "An error occurred during image generation and processing.",
|
||||||
|
"variantIndex": 0,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|||||||
@ -36,7 +36,12 @@ function App() {
|
|||||||
|
|
||||||
// Outputs
|
// Outputs
|
||||||
setGeneratedCode,
|
setGeneratedCode,
|
||||||
setExecutionConsole,
|
currentVariantIndex,
|
||||||
|
setVariant,
|
||||||
|
appendToVariant,
|
||||||
|
resetVariants,
|
||||||
|
appendExecutionConsole,
|
||||||
|
resetExecutionConsoles,
|
||||||
currentVersion,
|
currentVersion,
|
||||||
setCurrentVersion,
|
setCurrentVersion,
|
||||||
appHistory,
|
appHistory,
|
||||||
@ -106,10 +111,14 @@ function App() {
|
|||||||
const reset = () => {
|
const reset = () => {
|
||||||
setAppState(AppState.INITIAL);
|
setAppState(AppState.INITIAL);
|
||||||
setGeneratedCode("");
|
setGeneratedCode("");
|
||||||
|
resetVariants();
|
||||||
|
resetExecutionConsoles();
|
||||||
|
|
||||||
|
// Inputs
|
||||||
setReferenceImages([]);
|
setReferenceImages([]);
|
||||||
setExecutionConsole([]);
|
|
||||||
setUpdateInstruction("");
|
setUpdateInstruction("");
|
||||||
setIsImportedFromCode(false);
|
setIsImportedFromCode(false);
|
||||||
|
|
||||||
setAppHistory([]);
|
setAppHistory([]);
|
||||||
setCurrentVersion(null);
|
setCurrentVersion(null);
|
||||||
setShouldIncludeResultImage(false);
|
setShouldIncludeResultImage(false);
|
||||||
@ -159,7 +168,7 @@ function App() {
|
|||||||
parentVersion: number | null
|
parentVersion: number | null
|
||||||
) {
|
) {
|
||||||
// Reset the execution console
|
// Reset the execution console
|
||||||
setExecutionConsole([]);
|
resetExecutionConsoles();
|
||||||
|
|
||||||
// Set the app state
|
// Set the app state
|
||||||
setAppState(AppState.CODING);
|
setAppState(AppState.CODING);
|
||||||
@ -171,10 +180,19 @@ function App() {
|
|||||||
wsRef,
|
wsRef,
|
||||||
updatedParams,
|
updatedParams,
|
||||||
// On change
|
// On change
|
||||||
(token) => setGeneratedCode((prev) => prev + token),
|
(token, variant) => {
|
||||||
|
if (variant === currentVariantIndex) {
|
||||||
|
setGeneratedCode((prev) => prev + token);
|
||||||
|
}
|
||||||
|
|
||||||
|
appendToVariant(token, variant);
|
||||||
|
},
|
||||||
// On set code
|
// On set code
|
||||||
(code) => {
|
(code, variant) => {
|
||||||
|
setVariant(code, variant);
|
||||||
setGeneratedCode(code);
|
setGeneratedCode(code);
|
||||||
|
|
||||||
|
// TODO: How to deal with variants?
|
||||||
if (params.generationType === "create") {
|
if (params.generationType === "create") {
|
||||||
setAppHistory([
|
setAppHistory([
|
||||||
{
|
{
|
||||||
@ -214,7 +232,7 @@ function App() {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
// On status update
|
// On status update
|
||||||
(line) => setExecutionConsole((prev) => [...prev, line]),
|
(line, variant) => appendExecutionConsole(variant, line),
|
||||||
// On cancel
|
// On cancel
|
||||||
() => {
|
() => {
|
||||||
cancelCodeGenerationAndReset();
|
cancelCodeGenerationAndReset();
|
||||||
@ -314,6 +332,7 @@ function App() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
setGeneratedCode("");
|
setGeneratedCode("");
|
||||||
|
resetVariants();
|
||||||
setUpdateInstruction("");
|
setUpdateInstruction("");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import { Button } from "../ui/button";
|
|||||||
import { Textarea } from "../ui/textarea";
|
import { Textarea } from "../ui/textarea";
|
||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import HistoryDisplay from "../history/HistoryDisplay";
|
import HistoryDisplay from "../history/HistoryDisplay";
|
||||||
|
import Variants from "../variants/Variants";
|
||||||
|
|
||||||
interface SidebarProps {
|
interface SidebarProps {
|
||||||
showSelectAndEditFeature: boolean;
|
showSelectAndEditFeature: boolean;
|
||||||
@ -35,8 +36,16 @@ function Sidebar({
|
|||||||
shouldIncludeResultImage,
|
shouldIncludeResultImage,
|
||||||
setShouldIncludeResultImage,
|
setShouldIncludeResultImage,
|
||||||
} = useAppStore();
|
} = useAppStore();
|
||||||
const { inputMode, generatedCode, referenceImages, executionConsole } =
|
|
||||||
useProjectStore();
|
const {
|
||||||
|
inputMode,
|
||||||
|
generatedCode,
|
||||||
|
referenceImages,
|
||||||
|
executionConsoles,
|
||||||
|
currentVariantIndex,
|
||||||
|
} = useProjectStore();
|
||||||
|
|
||||||
|
const executionConsole = executionConsoles[currentVariantIndex] || [];
|
||||||
|
|
||||||
// When coding is complete, focus on the update instruction textarea
|
// When coding is complete, focus on the update instruction textarea
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -47,6 +56,8 @@ function Sidebar({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
<Variants />
|
||||||
|
|
||||||
{/* Show code preview only when coding */}
|
{/* Show code preview only when coding */}
|
||||||
{appState === AppState.CODING && (
|
{appState === AppState.CODING && (
|
||||||
<div className="flex flex-col">
|
<div className="flex flex-col">
|
||||||
|
|||||||
64
frontend/src/components/variants/Variants.tsx
Normal file
64
frontend/src/components/variants/Variants.tsx
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import { useProjectStore } from "../../store/project-store";
|
||||||
|
|
||||||
|
function Variants() {
|
||||||
|
const {
|
||||||
|
// Inputs
|
||||||
|
referenceImages,
|
||||||
|
|
||||||
|
// Outputs
|
||||||
|
variants,
|
||||||
|
currentVariantIndex,
|
||||||
|
setCurrentVariantIndex,
|
||||||
|
setGeneratedCode,
|
||||||
|
appHistory,
|
||||||
|
setAppHistory,
|
||||||
|
} = useProjectStore();
|
||||||
|
|
||||||
|
function switchVariant(index: number) {
|
||||||
|
const variant = variants[index];
|
||||||
|
setCurrentVariantIndex(index);
|
||||||
|
setGeneratedCode(variant);
|
||||||
|
if (appHistory.length === 1) {
|
||||||
|
setAppHistory([
|
||||||
|
{
|
||||||
|
type: "ai_create",
|
||||||
|
parentIndex: null,
|
||||||
|
code: variant,
|
||||||
|
inputs: { image_url: referenceImages[0] },
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
} else {
|
||||||
|
setAppHistory((prev) => {
|
||||||
|
const newHistory = [...prev];
|
||||||
|
newHistory[newHistory.length - 1].code = variant;
|
||||||
|
return newHistory;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (variants.length === 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mt-4 mb-4">
|
||||||
|
<div className="grid grid-cols-2 gap-2">
|
||||||
|
{variants.map((_, index) => (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className={`p-2 border rounded-md cursor-pointer ${
|
||||||
|
index === currentVariantIndex
|
||||||
|
? "bg-blue-100 dark:bg-blue-900"
|
||||||
|
: "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||||
|
}`}
|
||||||
|
onClick={() => switchVariant(index)}
|
||||||
|
>
|
||||||
|
<h3 className="font-medium mb-1">Option {index + 1}</h3>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Variants;
|
||||||
@ -11,12 +11,18 @@ const ERROR_MESSAGE =
|
|||||||
|
|
||||||
const CANCEL_MESSAGE = "Code generation cancelled";
|
const CANCEL_MESSAGE = "Code generation cancelled";
|
||||||
|
|
||||||
|
type WebSocketResponse = {
|
||||||
|
type: "chunk" | "status" | "setCode" | "error";
|
||||||
|
value: string;
|
||||||
|
variantIndex: number;
|
||||||
|
};
|
||||||
|
|
||||||
export function generateCode(
|
export function generateCode(
|
||||||
wsRef: React.MutableRefObject<WebSocket | null>,
|
wsRef: React.MutableRefObject<WebSocket | null>,
|
||||||
params: FullGenerationSettings,
|
params: FullGenerationSettings,
|
||||||
onChange: (chunk: string) => void,
|
onChange: (chunk: string, variantIndex: number) => void,
|
||||||
onSetCode: (code: string) => void,
|
onSetCode: (code: string, variantIndex: number) => void,
|
||||||
onStatusUpdate: (status: string) => void,
|
onStatusUpdate: (status: string, variantIndex: number) => void,
|
||||||
onCancel: () => void,
|
onCancel: () => void,
|
||||||
onComplete: () => void
|
onComplete: () => void
|
||||||
) {
|
) {
|
||||||
@ -31,13 +37,13 @@ export function generateCode(
|
|||||||
});
|
});
|
||||||
|
|
||||||
ws.addEventListener("message", async (event: MessageEvent) => {
|
ws.addEventListener("message", async (event: MessageEvent) => {
|
||||||
const response = JSON.parse(event.data);
|
const response = JSON.parse(event.data) as WebSocketResponse;
|
||||||
if (response.type === "chunk") {
|
if (response.type === "chunk") {
|
||||||
onChange(response.value);
|
onChange(response.value, response.variantIndex);
|
||||||
} else if (response.type === "status") {
|
} else if (response.type === "status") {
|
||||||
onStatusUpdate(response.value);
|
onStatusUpdate(response.value, response.variantIndex);
|
||||||
} else if (response.type === "setCode") {
|
} else if (response.type === "setCode") {
|
||||||
onSetCode(response.value);
|
onSetCode(response.value, response.variantIndex);
|
||||||
} else if (response.type === "error") {
|
} else if (response.type === "error") {
|
||||||
console.error("Error generating code", response.value);
|
console.error("Error generating code", response.value);
|
||||||
toast.error(response.value);
|
toast.error(response.value);
|
||||||
|
|||||||
@ -16,10 +16,17 @@ interface ProjectStore {
|
|||||||
setGeneratedCode: (
|
setGeneratedCode: (
|
||||||
updater: string | ((currentCode: string) => string)
|
updater: string | ((currentCode: string) => string)
|
||||||
) => void;
|
) => void;
|
||||||
executionConsole: string[];
|
|
||||||
setExecutionConsole: (
|
variants: string[];
|
||||||
updater: string[] | ((currentConsole: string[]) => string[])
|
currentVariantIndex: number;
|
||||||
) => void;
|
setCurrentVariantIndex: (index: number) => void;
|
||||||
|
setVariant: (code: string, index: number) => void;
|
||||||
|
appendToVariant: (newTokens: string, index: number) => void;
|
||||||
|
resetVariants: () => void;
|
||||||
|
|
||||||
|
executionConsoles: { [key: number]: string[] };
|
||||||
|
appendExecutionConsole: (variantIndex: number, line: string) => void;
|
||||||
|
resetExecutionConsoles: () => void;
|
||||||
|
|
||||||
// Tracks the currently shown version from app history
|
// Tracks the currently shown version from app history
|
||||||
// TODO: might want to move to appStore
|
// TODO: might want to move to appStore
|
||||||
@ -48,14 +55,41 @@ export const useProjectStore = create<ProjectStore>((set) => ({
|
|||||||
generatedCode:
|
generatedCode:
|
||||||
typeof updater === "function" ? updater(state.generatedCode) : updater,
|
typeof updater === "function" ? updater(state.generatedCode) : updater,
|
||||||
})),
|
})),
|
||||||
executionConsole: [],
|
|
||||||
setExecutionConsole: (updater) =>
|
variants: [],
|
||||||
|
currentVariantIndex: 0,
|
||||||
|
|
||||||
|
setCurrentVariantIndex: (index) => set({ currentVariantIndex: index }),
|
||||||
|
setVariant: (code: string, index: number) =>
|
||||||
|
set((state) => {
|
||||||
|
const newVariants = [...state.variants];
|
||||||
|
while (newVariants.length <= index) {
|
||||||
|
newVariants.push("");
|
||||||
|
}
|
||||||
|
newVariants[index] = code;
|
||||||
|
return { variants: newVariants };
|
||||||
|
}),
|
||||||
|
appendToVariant: (newTokens: string, index: number) =>
|
||||||
|
set((state) => {
|
||||||
|
const newVariants = [...state.variants];
|
||||||
|
newVariants[index] += newTokens;
|
||||||
|
return { variants: newVariants };
|
||||||
|
}),
|
||||||
|
resetVariants: () => set({ variants: [], currentVariantIndex: 0 }),
|
||||||
|
|
||||||
|
executionConsoles: {},
|
||||||
|
|
||||||
|
appendExecutionConsole: (variantIndex: number, line: string) =>
|
||||||
set((state) => ({
|
set((state) => ({
|
||||||
executionConsole:
|
executionConsoles: {
|
||||||
typeof updater === "function"
|
...state.executionConsoles,
|
||||||
? updater(state.executionConsole)
|
[variantIndex]: [
|
||||||
: updater,
|
...(state.executionConsoles[variantIndex] || []),
|
||||||
|
line,
|
||||||
|
],
|
||||||
|
},
|
||||||
})),
|
})),
|
||||||
|
resetExecutionConsoles: () => set({ executionConsoles: {} }),
|
||||||
|
|
||||||
currentVersion: null,
|
currentVersion: null,
|
||||||
setCurrentVersion: (version) => set({ currentVersion: version }),
|
setCurrentVersion: (version) => set({ currentVersion: version }),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user