merge
This commit is contained in:
commit
c107f4eda5
@ -33,10 +33,6 @@ We also just added experimental support for taking a video/screen recording of a
|
||||
|
||||
[Follow me on Twitter for updates](https://twitter.com/_abi_).
|
||||
|
||||
## Sponsors
|
||||
|
||||
<a href="https://konghq.com/products/kong-konnect?utm_medium=referral&utm_source=github&utm_campaign=platform&utm_content=screenshot-to-code" target="_blank" title="Kong - powering the API world"><img src="https://picoapps.xyz/s2c-sponsors/Kong-GitHub-240x100.png"></a>
|
||||
|
||||
## 🚀 Hosted Version
|
||||
|
||||
[Try it live on the hosted version (paid)](https://screenshottocode.com).
|
||||
|
||||
@ -7,15 +7,15 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: poetry-pytest
|
||||
name: Run pytest with Poetry
|
||||
entry: poetry run --directory backend pytest
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
files: ^backend/
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: poetry-pytest
|
||||
# name: Run pytest with Poetry
|
||||
# entry: poetry run --directory backend pytest
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
# always_run: true
|
||||
# files: ^backend/
|
||||
# - id: poetry-pyright
|
||||
# name: Run pyright with Poetry
|
||||
# entry: poetry run --directory backend pyright
|
||||
|
||||
@ -3,8 +3,12 @@
|
||||
# TODO: Should only be set to true when value is 'True', not any abitrary truthy value
|
||||
import os
|
||||
|
||||
NUM_VARIANTS = 2
|
||||
|
||||
# LLM-related
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
||||
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
|
||||
|
||||
# Image generation (optional)
|
||||
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
|
||||
|
||||
0
backend/fs_logging/__init__.py
Normal file
0
backend/fs_logging/__init__.py
Normal file
23
backend/fs_logging/core.py
Normal file
23
backend/fs_logging/core.py
Normal file
@ -0,0 +1,23 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
def write_logs(prompt_messages: list[ChatCompletionMessageParam], completion: str):
|
||||
# Get the logs path from environment, default to the current working directory
|
||||
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
|
||||
|
||||
# Create run_logs directory if it doesn't exist within the specified logs path
|
||||
logs_directory = os.path.join(logs_path, "run_logs")
|
||||
if not os.path.exists(logs_directory):
|
||||
os.makedirs(logs_directory)
|
||||
|
||||
print("Writing to logs directory:", logs_directory)
|
||||
|
||||
# Generate a unique filename using the current timestamp within the logs directory
|
||||
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
|
||||
|
||||
# Write the messages dict into a new file for each run
|
||||
with open(filename, "w") as f:
|
||||
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
||||
@ -28,7 +28,7 @@ async def process_tasks(
|
||||
|
||||
processed_results: List[Union[str, None]] = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
if isinstance(result, BaseException):
|
||||
print(f"An exception occurred: {result}")
|
||||
try:
|
||||
raise result
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, List, cast
|
||||
from anthropic import AsyncAnthropic
|
||||
@ -112,8 +113,12 @@ async def stream_claude_response(
|
||||
temperature = 0.0
|
||||
|
||||
# Translate OpenAI messages to Claude messages
|
||||
system_prompt = cast(str, messages[0].get("content"))
|
||||
claude_messages = [dict(message) for message in messages[1:]]
|
||||
|
||||
# Deep copy messages to avoid modifying the original list
|
||||
cloned_messages = copy.deepcopy(messages)
|
||||
|
||||
system_prompt = cast(str, cloned_messages[0].get("content"))
|
||||
claude_messages = [dict(message) for message in cloned_messages[1:]]
|
||||
for message in claude_messages:
|
||||
if not isinstance(message["content"], list):
|
||||
continue
|
||||
|
||||
@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
|
||||
|
||||
|
||||
async def mock_completion(
|
||||
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
|
||||
process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
|
||||
) -> str:
|
||||
code_to_return = (
|
||||
TALLY_FORM_VIDEO_PROMPT_MOCK
|
||||
@ -17,7 +17,7 @@ async def mock_completion(
|
||||
)
|
||||
|
||||
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
|
||||
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
|
||||
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if input_mode == "video":
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from typing import List, NoReturn, Union
|
||||
|
||||
from typing import Union
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||
from llm import Llm
|
||||
|
||||
from custom_types import InputMode
|
||||
from image_generation.core import create_alt_url_mapping
|
||||
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
||||
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
||||
from prompts.text_prompts import SYSTEM_PROMPTS as TEXT_SYSTEM_PROMPTS
|
||||
from prompts.types import Stack
|
||||
from video.utils import assemble_claude_prompt_video
|
||||
|
||||
|
||||
USER_PROMPT = """
|
||||
@ -18,9 +19,65 @@ Generate code for a SVG that looks exactly like this.
|
||||
"""
|
||||
|
||||
|
||||
async def create_prompt(
|
||||
params: dict[str, str], stack: Stack, input_mode: InputMode
|
||||
) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
|
||||
|
||||
image_cache: dict[str, str] = {}
|
||||
|
||||
# If this generation started off with imported code, we need to assemble the prompt differently
|
||||
if params.get("isImportedFromCode"):
|
||||
original_imported_code = params["history"][0]
|
||||
prompt_messages = assemble_imported_code_prompt(original_imported_code, stack)
|
||||
for index, text in enumerate(params["history"][1:]):
|
||||
if index % 2 == 0:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "user",
|
||||
"content": text,
|
||||
}
|
||||
else:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}
|
||||
prompt_messages.append(message)
|
||||
else:
|
||||
# Assemble the prompt for non-imported code
|
||||
if params.get("resultImage"):
|
||||
prompt_messages = assemble_prompt(
|
||||
params["image"], stack, params["resultImage"]
|
||||
)
|
||||
else:
|
||||
prompt_messages = assemble_prompt(params["image"], stack)
|
||||
|
||||
if params["generationType"] == "update":
|
||||
# Transform the history tree into message format
|
||||
# TODO: Move this to frontend
|
||||
for index, text in enumerate(params["history"]):
|
||||
if index % 2 == 0:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}
|
||||
else:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "user",
|
||||
"content": text,
|
||||
}
|
||||
prompt_messages.append(message)
|
||||
|
||||
image_cache = create_alt_url_mapping(params["history"][-2])
|
||||
|
||||
if input_mode == "video":
|
||||
video_data_url = params["image"]
|
||||
prompt_messages = await assemble_claude_prompt_video(video_data_url)
|
||||
|
||||
return prompt_messages, image_cache
|
||||
|
||||
|
||||
def assemble_imported_code_prompt(
|
||||
code: str, stack: Stack, model: Llm
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
code: str, stack: Stack
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
||||
|
||||
user_content = (
|
||||
@ -29,24 +86,12 @@ def assemble_imported_code_prompt(
|
||||
else "Here is the code of the SVG: " + code
|
||||
)
|
||||
|
||||
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_content + "\n " + user_content,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_content,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_content,
|
||||
},
|
||||
]
|
||||
# TODO: Use result_image_data_url
|
||||
|
||||
|
||||
@ -54,11 +99,11 @@ def assemble_prompt(
|
||||
image_data_url: str,
|
||||
stack: Stack,
|
||||
result_image_data_url: Union[str, None] = None,
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
system_content = SYSTEM_PROMPTS[stack]
|
||||
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
|
||||
|
||||
user_content: List[ChatCompletionContentPartParam] = [
|
||||
user_content: list[ChatCompletionContentPartParam] = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data_url, "detail": "high"},
|
||||
@ -93,7 +138,7 @@ def assemble_prompt(
|
||||
def assemble_text_prompt(
|
||||
text_prompt: str,
|
||||
stack: Stack,
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
|
||||
system_content = TEXT_SYSTEM_PROMPTS[stack]
|
||||
|
||||
|
||||
@ -391,63 +391,81 @@ def test_prompts():
|
||||
|
||||
|
||||
def test_imported_code_prompts():
|
||||
tailwind_prompt = assemble_imported_code_prompt(
|
||||
"code", "html_tailwind", Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
code = "Sample code"
|
||||
|
||||
tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind")
|
||||
expected_tailwind_prompt = [
|
||||
{"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the app: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
|
||||
+ "\n Here is the code of the app: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert tailwind_prompt == expected_tailwind_prompt
|
||||
|
||||
html_css_prompt = assemble_imported_code_prompt(
|
||||
"code", "html_css", Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
html_css_prompt = assemble_imported_code_prompt(code, "html_css")
|
||||
expected_html_css_prompt = [
|
||||
{"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the app: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT
|
||||
+ "\n Here is the code of the app: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert html_css_prompt == expected_html_css_prompt
|
||||
|
||||
react_tailwind_prompt = assemble_imported_code_prompt(
|
||||
"code", "react_tailwind", Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind")
|
||||
expected_react_tailwind_prompt = [
|
||||
{"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the app: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
|
||||
+ "\n Here is the code of the app: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert react_tailwind_prompt == expected_react_tailwind_prompt
|
||||
|
||||
bootstrap_prompt = assemble_imported_code_prompt(
|
||||
"code", "bootstrap", Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap")
|
||||
expected_bootstrap_prompt = [
|
||||
{"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the app: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
|
||||
+ "\n Here is the code of the app: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert bootstrap_prompt == expected_bootstrap_prompt
|
||||
|
||||
ionic_tailwind = assemble_imported_code_prompt(
|
||||
"code", "ionic_tailwind", Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind")
|
||||
expected_ionic_tailwind = [
|
||||
{"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the app: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
|
||||
+ "\n Here is the code of the app: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert ionic_tailwind == expected_ionic_tailwind
|
||||
|
||||
vue_tailwind = assemble_imported_code_prompt(
|
||||
"code", "vue_tailwind", Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind")
|
||||
expected_vue_tailwind = [
|
||||
{"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the app: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_VUE_TAILWIND_PROMPT
|
||||
+ "\n Here is the code of the app: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert vue_tailwind == expected_vue_tailwind
|
||||
|
||||
svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13)
|
||||
svg = assemble_imported_code_prompt(code, "svg")
|
||||
expected_svg = [
|
||||
{"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": "Here is the code of the SVG: code"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMPORTED_CODE_SVG_SYSTEM_PROMPT
|
||||
+ "\n Here is the code of the SVG: "
|
||||
+ code,
|
||||
}
|
||||
]
|
||||
assert svg == expected_svg
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
import sentry_sdk
|
||||
from codegen.utils import extract_html_content
|
||||
from config import (
|
||||
ANTHROPIC_API_KEY,
|
||||
IS_PROD,
|
||||
NUM_VARIANTS,
|
||||
OPENAI_API_KEY,
|
||||
OPENAI_BASE_URL,
|
||||
REPLICATE_API_KEY,
|
||||
SHOULD_MOCK_AI_RESPONSE,
|
||||
)
|
||||
@ -18,77 +21,101 @@ from llm import (
|
||||
stream_claude_response_native,
|
||||
stream_openai_response,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from mock_llm import mock_completion
|
||||
from typing import Dict, List, Union, cast, get_args
|
||||
from image_generation.core import create_alt_url_mapping, generate_images
|
||||
from prompts import assemble_imported_code_prompt, assemble_prompt, assemble_text_prompt
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Dict, List, cast, get_args
|
||||
from image_generation.core import generate_images
|
||||
from routes.logging_utils import PaymentMethod, send_to_saas_backend
|
||||
from routes.saas_utils import does_user_have_subscription_credits
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args
|
||||
from image_generation.core import generate_images
|
||||
from prompts import create_prompt
|
||||
from prompts.claude_prompts import VIDEO_PROMPT
|
||||
from prompts.types import Stack
|
||||
from utils import pprint_prompt
|
||||
|
||||
# from utils import pprint_prompt
|
||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
||||
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
|
||||
# Get the logs path from environment, default to the current working directory
|
||||
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
|
||||
|
||||
# Create run_logs directory if it doesn't exist within the specified logs path
|
||||
logs_directory = os.path.join(logs_path, "run_logs")
|
||||
if not os.path.exists(logs_directory):
|
||||
os.makedirs(logs_directory)
|
||||
|
||||
print("Writing to logs directory:", logs_directory)
|
||||
|
||||
# Generate a unique filename using the current timestamp within the logs directory
|
||||
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
|
||||
|
||||
# Write the messages dict into a new file for each run
|
||||
with open(filename, "w") as f:
|
||||
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
||||
# Auto-upgrade usage of older models
|
||||
def auto_upgrade_model(code_generation_model: Llm) -> Llm:
|
||||
if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
|
||||
print(
|
||||
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
|
||||
)
|
||||
return Llm.GPT_4O_2024_05_13
|
||||
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
||||
print(
|
||||
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
|
||||
)
|
||||
return Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
return code_generation_model
|
||||
|
||||
|
||||
@router.websocket("/generate-code")
|
||||
async def stream_code(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
print("Incoming websocket connection...")
|
||||
|
||||
async def throw_error(
|
||||
message: str,
|
||||
# Generate images, if needed
|
||||
async def perform_image_generation(
|
||||
completion: str,
|
||||
should_generate_images: bool,
|
||||
openai_api_key: str | None,
|
||||
openai_base_url: str | None,
|
||||
image_cache: dict[str, str],
|
||||
):
|
||||
await websocket.send_json({"type": "error", "value": message})
|
||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||
replicate_api_key = REPLICATE_API_KEY
|
||||
if not should_generate_images:
|
||||
return completion
|
||||
|
||||
# TODO: Are the values always strings?
|
||||
params: Dict[str, str] = await websocket.receive_json()
|
||||
if replicate_api_key:
|
||||
image_generation_model = "sdxl-lightning"
|
||||
api_key = replicate_api_key
|
||||
else:
|
||||
if not openai_api_key:
|
||||
print(
|
||||
"No OpenAI API key and Replicate key found. Skipping image generation."
|
||||
)
|
||||
return completion
|
||||
image_generation_model = "dalle3"
|
||||
api_key = openai_api_key
|
||||
|
||||
# Read the code config settings from the request. Fall back to default if not provided.
|
||||
generated_code_config = ""
|
||||
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
|
||||
generated_code_config = params["generatedCodeConfig"]
|
||||
if not generated_code_config in get_args(Stack):
|
||||
print("Generating images with model: ", image_generation_model)
|
||||
|
||||
return await generate_images(
|
||||
completion,
|
||||
api_key=api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
model=image_generation_model,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedParams:
|
||||
stack: Stack
|
||||
input_mode: InputMode
|
||||
code_generation_model: Llm
|
||||
should_generate_images: bool
|
||||
openai_api_key: str | None
|
||||
anthropic_api_key: str | None
|
||||
openai_base_url: str | None
|
||||
payment_method: PaymentMethod
|
||||
|
||||
|
||||
async def extract_params(
|
||||
params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]]
|
||||
) -> ExtractedParams:
|
||||
# Read the code config settings (stack) from the request.
|
||||
generated_code_config = params.get("generatedCodeConfig", "")
|
||||
if generated_code_config not in get_args(Stack):
|
||||
await throw_error(f"Invalid generated code config: {generated_code_config}")
|
||||
return
|
||||
# Cast the variable to the Stack type
|
||||
valid_stack = cast(Stack, generated_code_config)
|
||||
raise ValueError(f"Invalid generated code config: {generated_code_config}")
|
||||
validated_stack = cast(Stack, generated_code_config)
|
||||
|
||||
# Validate the input mode
|
||||
input_mode = params.get("inputMode", "image")
|
||||
if not input_mode in get_args(InputMode):
|
||||
input_mode = params.get("inputMode")
|
||||
if input_mode not in get_args(InputMode):
|
||||
await throw_error(f"Invalid input mode: {input_mode}")
|
||||
raise Exception(f"Invalid input mode: {input_mode}")
|
||||
# Cast the variable to the right type
|
||||
raise ValueError(f"Invalid input mode: {input_mode}")
|
||||
validated_input_mode = cast(InputMode, input_mode)
|
||||
|
||||
# Read the model from the request. Fall back to default if not provided.
|
||||
@ -97,42 +124,19 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
try:
|
||||
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
|
||||
except:
|
||||
except ValueError:
|
||||
await throw_error(f"Invalid model: {code_generation_model_str}")
|
||||
raise Exception(f"Invalid model: {code_generation_model_str}")
|
||||
|
||||
# Auto-upgrade usage of older models
|
||||
if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
|
||||
print(
|
||||
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
|
||||
)
|
||||
code_generation_model = Llm.GPT_4O_2024_05_13
|
||||
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
||||
print(
|
||||
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
|
||||
)
|
||||
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
|
||||
exact_llm_version = None
|
||||
|
||||
print(
|
||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
||||
)
|
||||
|
||||
# Track how this generation is being paid for
|
||||
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
||||
# Track the OpenAI API key to use
|
||||
openai_api_key = None
|
||||
raise ValueError(f"Invalid model: {code_generation_model_str}")
|
||||
|
||||
auth_token = params.get("authToken")
|
||||
if not auth_token:
|
||||
await throw_error("You need to be logged in to use screenshot to code")
|
||||
raise Exception("No auth token")
|
||||
|
||||
# Get the OpenAI key by waterfalling through the different payment methods
|
||||
# 1. Subscription
|
||||
# 2. User's API key from client-side settings dialog
|
||||
# 3. User's API key from environment variable
|
||||
openai_api_key = None
|
||||
|
||||
# Track how this generation is being paid for
|
||||
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
||||
|
||||
# If the user is a subscriber, use the platform API key
|
||||
# TODO: Rename does_user_have_subscription_credits
|
||||
@ -157,151 +161,152 @@ async def stream_code(websocket: WebSocket):
|
||||
await throw_error("Unknown error occurred. Contact support.")
|
||||
raise Exception("Unknown error occurred when checking subscription credits")
|
||||
|
||||
# For non-subscribers, use the user's API key from client-side settings dialog
|
||||
if not openai_api_key:
|
||||
openai_api_key = params.get("openAiApiKey", None)
|
||||
payment_method = PaymentMethod.OPENAI_API_KEY
|
||||
print("Using OpenAI API key from client-side settings dialog")
|
||||
|
||||
# If we still don't have an API key, use the user's API key from environment variable
|
||||
if not openai_api_key:
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
openai_api_key = get_from_settings_dialog_or_env(
|
||||
params, "openAiApiKey", OPENAI_API_KEY
|
||||
)
|
||||
payment_method = PaymentMethod.OPENAI_API_KEY
|
||||
if openai_api_key:
|
||||
# TODO
|
||||
print("Using OpenAI API key from environment variable")
|
||||
|
||||
if not openai_api_key and (
|
||||
code_generation_model == Llm.GPT_4_VISION
|
||||
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
||||
or code_generation_model == Llm.GPT_4O_2024_05_13
|
||||
):
|
||||
print("OpenAI API key not found")
|
||||
await throw_error(
|
||||
"Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
|
||||
)
|
||||
raise Exception("No OpenAI API key found")
|
||||
# if not openai_api_key and (
|
||||
# code_generation_model == Llm.GPT_4_VISION
|
||||
# or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
||||
# or code_generation_model == Llm.GPT_4O_2024_05_13
|
||||
# ):
|
||||
# print("OpenAI API key not found")
|
||||
# await throw_error(
|
||||
# "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
|
||||
# )
|
||||
# raise Exception("No OpenAI API key found")
|
||||
|
||||
# Get the Anthropic API key from the request. Fall back to environment variable if not provided.
|
||||
# TODO: Do not allow usage of key
|
||||
# If neither is provided, we throw an error later only if Claude is used.
|
||||
anthropic_api_key = None
|
||||
if "anthropicApiKey" in params and params["anthropicApiKey"]:
|
||||
anthropic_api_key = params["anthropicApiKey"]
|
||||
print("Using Anthropic API key from client-side settings dialog")
|
||||
else:
|
||||
anthropic_api_key = ANTHROPIC_API_KEY
|
||||
if anthropic_api_key:
|
||||
print("Using Anthropic API key from environment variable")
|
||||
anthropic_api_key = get_from_settings_dialog_or_env(
|
||||
params, "anthropicApiKey", ANTHROPIC_API_KEY
|
||||
)
|
||||
|
||||
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
|
||||
openai_base_url: Union[str, None] = None
|
||||
# Base URL for OpenAI API
|
||||
openai_base_url: str | None = None
|
||||
# Disable user-specified OpenAI Base URL in prod
|
||||
if not os.environ.get("IS_PROD"):
|
||||
if "openAiBaseURL" in params and params["openAiBaseURL"]:
|
||||
openai_base_url = params["openAiBaseURL"]
|
||||
print("Using OpenAI Base URL from client-side settings dialog")
|
||||
else:
|
||||
openai_base_url = os.environ.get("OPENAI_BASE_URL")
|
||||
if openai_base_url:
|
||||
print("Using OpenAI Base URL from environment variable")
|
||||
|
||||
if not IS_PROD:
|
||||
openai_base_url = get_from_settings_dialog_or_env(
|
||||
params, "openAiBaseURL", OPENAI_BASE_URL
|
||||
)
|
||||
if not openai_base_url:
|
||||
print("Using official OpenAI URL")
|
||||
|
||||
# Get the image generation flag from the request. Fall back to True if not provided.
|
||||
should_generate_images = (
|
||||
params["isImageGenerationEnabled"]
|
||||
if "isImageGenerationEnabled" in params
|
||||
else True
|
||||
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||
|
||||
return ExtractedParams(
|
||||
stack=validated_stack,
|
||||
input_mode=validated_input_mode,
|
||||
code_generation_model=code_generation_model,
|
||||
should_generate_images=should_generate_images,
|
||||
openai_api_key=openai_api_key,
|
||||
anthropic_api_key=anthropic_api_key,
|
||||
openai_base_url=openai_base_url,
|
||||
payment_method=payment_method,
|
||||
)
|
||||
|
||||
print("generating code...")
|
||||
await websocket.send_json({"type": "status", "value": "Generating code..."})
|
||||
|
||||
async def process_chunk(content: str):
|
||||
await websocket.send_json({"type": "chunk", "value": content})
|
||||
def get_from_settings_dialog_or_env(
|
||||
params: dict[str, str], key: str, env_var: str | None
|
||||
) -> str | None:
|
||||
value = params.get(key)
|
||||
if value:
|
||||
print(f"Using {key} from client-side settings dialog")
|
||||
return value
|
||||
|
||||
if env_var:
|
||||
print(f"Using {key} from environment variable")
|
||||
return env_var
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.websocket("/generate-code")
|
||||
async def stream_code(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
print("Incoming websocket connection...")
|
||||
|
||||
## Communication protocol setup
|
||||
async def throw_error(
|
||||
message: str,
|
||||
):
|
||||
print(message)
|
||||
await websocket.send_json({"type": "error", "value": message})
|
||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||
|
||||
async def send_message(
|
||||
type: Literal["chunk", "status", "setCode", "error"],
|
||||
value: str,
|
||||
variantIndex: int,
|
||||
):
|
||||
# Print for debugging on the backend
|
||||
if type == "error":
|
||||
print(f"Error (variant {variantIndex}): {value}")
|
||||
elif type == "status":
|
||||
print(f"Status (variant {variantIndex}): {value}")
|
||||
|
||||
await websocket.send_json(
|
||||
{"type": type, "value": value, "variantIndex": variantIndex}
|
||||
)
|
||||
|
||||
## Parameter extract and validation
|
||||
|
||||
# TODO: Are the values always strings?
|
||||
params: dict[str, str] = await websocket.receive_json()
|
||||
print("Received params")
|
||||
|
||||
extracted_params = await extract_params(params, throw_error)
|
||||
stack = extracted_params.stack
|
||||
input_mode = extracted_params.input_mode
|
||||
code_generation_model = extracted_params.code_generation_model
|
||||
openai_api_key = extracted_params.openai_api_key
|
||||
openai_base_url = extracted_params.openai_base_url
|
||||
anthropic_api_key = extracted_params.anthropic_api_key
|
||||
should_generate_images = extracted_params.should_generate_images
|
||||
payment_method = extracted_params.payment_method
|
||||
|
||||
# Auto-upgrade usage of older models
|
||||
code_generation_model = auto_upgrade_model(code_generation_model)
|
||||
|
||||
print(
|
||||
f"Generating {stack} code in {input_mode} mode using {code_generation_model}..."
|
||||
)
|
||||
|
||||
for i in range(NUM_VARIANTS):
|
||||
await send_message("status", "Generating code...", i)
|
||||
|
||||
### Prompt creation
|
||||
|
||||
# Image cache for updates so that we don't have to regenerate images
|
||||
image_cache: Dict[str, str] = {}
|
||||
|
||||
# If this generation started off with imported code, we need to assemble the prompt differently
|
||||
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
||||
original_imported_code = params["history"][0]
|
||||
prompt_messages = assemble_imported_code_prompt(
|
||||
original_imported_code, valid_stack, code_generation_model
|
||||
)
|
||||
for index, text in enumerate(params["history"][1:]):
|
||||
# TODO: Remove after "Select and edit" is fully implemented
|
||||
if "referring to this element specifically" in text:
|
||||
sentry_sdk.capture_exception(Exception("Point and edit used"))
|
||||
|
||||
if index % 2 == 0:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "user",
|
||||
"content": text,
|
||||
}
|
||||
else:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}
|
||||
prompt_messages.append(message)
|
||||
else:
|
||||
# Assemble the prompt
|
||||
try:
|
||||
if validated_input_mode == "image":
|
||||
if params.get("resultImage") and params["resultImage"]:
|
||||
prompt_messages = assemble_prompt(
|
||||
params["image"], valid_stack, params["resultImage"]
|
||||
)
|
||||
else:
|
||||
prompt_messages = assemble_prompt(params["image"], valid_stack)
|
||||
elif validated_input_mode == "text":
|
||||
sentry_sdk.capture_exception(Exception("Text generation used"))
|
||||
prompt_messages = assemble_text_prompt(params["image"], valid_stack)
|
||||
else:
|
||||
await throw_error("Invalid input mode")
|
||||
return
|
||||
prompt_messages, image_cache = await create_prompt(params, stack, input_mode)
|
||||
except:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"value": "Error assembling prompt. Contact support at support@picoapps.xyz",
|
||||
}
|
||||
await throw_error(
|
||||
"Error assembling prompt. Contact support at support@picoapps.xyz"
|
||||
)
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
# Transform the history tree into message format for updates
|
||||
if params["generationType"] == "update":
|
||||
# TODO: Move this to frontend
|
||||
for index, text in enumerate(params["history"]):
|
||||
if index % 2 == 0:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}
|
||||
else:
|
||||
message: ChatCompletionMessageParam = {
|
||||
"role": "user",
|
||||
"content": text,
|
||||
}
|
||||
prompt_messages.append(message)
|
||||
|
||||
image_cache = create_alt_url_mapping(params["history"][-2])
|
||||
|
||||
if validated_input_mode == "video":
|
||||
video_data_url = params["image"]
|
||||
prompt_messages = await assemble_claude_prompt_video(video_data_url)
|
||||
raise
|
||||
|
||||
# pprint_prompt(prompt_messages) # type: ignore
|
||||
|
||||
### Code generation
|
||||
|
||||
async def process_chunk(content: str, variantIndex: int):
|
||||
await send_message("chunk", content, variantIndex)
|
||||
|
||||
if SHOULD_MOCK_AI_RESPONSE:
|
||||
completion = await mock_completion(
|
||||
process_chunk, input_mode=validated_input_mode
|
||||
)
|
||||
completions = [await mock_completion(process_chunk, input_mode=input_mode)]
|
||||
else:
|
||||
try:
|
||||
if validated_input_mode == "video":
|
||||
if input_mode == "video":
|
||||
if IS_PROD:
|
||||
raise Exception("Video mode is not supported in prod")
|
||||
|
||||
@ -311,48 +316,66 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
raise Exception("No Anthropic key")
|
||||
|
||||
completion = await stream_claude_response_native(
|
||||
completions = [
|
||||
await stream_claude_response_native(
|
||||
system_prompt=VIDEO_PROMPT,
|
||||
messages=prompt_messages, # type: ignore
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x: process_chunk(x),
|
||||
callback=lambda x: process_chunk(x, 0),
|
||||
model=Llm.CLAUDE_3_OPUS,
|
||||
include_thinking=True,
|
||||
)
|
||||
exact_llm_version = Llm.CLAUDE_3_OPUS
|
||||
elif (
|
||||
code_generation_model == Llm.CLAUDE_3_SONNET
|
||||
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
):
|
||||
if not anthropic_api_key:
|
||||
await throw_error(
|
||||
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
|
||||
)
|
||||
raise Exception("No Anthropic key")
|
||||
|
||||
# Do not allow non-subscribers to use Claude
|
||||
if payment_method != PaymentMethod.SUBSCRIPTION:
|
||||
await throw_error(
|
||||
"Please subscribe to a paid plan to use the Claude models"
|
||||
)
|
||||
raise Exception("Not subscribed to a paid plan for Claude")
|
||||
|
||||
completion = await stream_claude_response(
|
||||
prompt_messages, # type: ignore
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=code_generation_model,
|
||||
)
|
||||
exact_llm_version = code_generation_model
|
||||
]
|
||||
else:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages, # type: ignore
|
||||
|
||||
# Depending on the presence and absence of various keys,
|
||||
# we decide which models to run
|
||||
variant_models = []
|
||||
if openai_api_key and anthropic_api_key:
|
||||
variant_models = ["openai", "anthropic"]
|
||||
elif openai_api_key:
|
||||
variant_models = ["openai", "openai"]
|
||||
elif anthropic_api_key:
|
||||
variant_models = ["anthropic", "anthropic"]
|
||||
else:
|
||||
await throw_error(
|
||||
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server."
|
||||
)
|
||||
raise Exception("No OpenAI or Anthropic key")
|
||||
|
||||
tasks: List[Coroutine[Any, Any, str]] = []
|
||||
for index, model in enumerate(variant_models):
|
||||
if model == "openai":
|
||||
if openai_api_key is None:
|
||||
await throw_error("OpenAI API key is missing.")
|
||||
raise Exception("OpenAI API key is missing.")
|
||||
|
||||
tasks.append(
|
||||
stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=code_generation_model,
|
||||
callback=lambda x, i=index: process_chunk(x, i),
|
||||
model=Llm.GPT_4O_2024_05_13,
|
||||
)
|
||||
exact_llm_version = code_generation_model
|
||||
)
|
||||
elif model == "anthropic":
|
||||
if anthropic_api_key is None:
|
||||
await throw_error("Anthropic API key is missing.")
|
||||
raise Exception("Anthropic API key is missing.")
|
||||
|
||||
tasks.append(
|
||||
stream_claude_response(
|
||||
prompt_messages,
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x, i=index: process_chunk(x, i),
|
||||
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||
)
|
||||
)
|
||||
|
||||
completions = await asyncio.gather(*tasks)
|
||||
print("Models used for generation: ", variant_models)
|
||||
|
||||
except openai.AuthenticationError as e:
|
||||
print("[GENERATE_CODE] Authentication failed", e)
|
||||
error_message = (
|
||||
@ -388,13 +411,10 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
return await throw_error(error_message)
|
||||
|
||||
if validated_input_mode == "video":
|
||||
completion = extract_tag_content("html", completion)
|
||||
|
||||
print("Exact used model for generation: ", exact_llm_version)
|
||||
## Post-processing
|
||||
|
||||
# Strip the completion of everything except the HTML content
|
||||
completion = extract_html_content(completion)
|
||||
completions = [extract_html_content(completion) for completion in completions]
|
||||
|
||||
# Write the messages dict into a log so that we can debug later
|
||||
# write_logs(prompt_messages, completion) # type: ignore
|
||||
@ -402,54 +422,44 @@ async def stream_code(websocket: WebSocket):
|
||||
if IS_PROD:
|
||||
# Catch any errors from sending to SaaS backend and continue
|
||||
try:
|
||||
assert exact_llm_version is not None, "exact_llm_version is not set"
|
||||
# TODO*
|
||||
# assert exact_llm_version is not None, "exact_llm_version is not set"
|
||||
await send_to_saas_backend(
|
||||
prompt_messages,
|
||||
completion,
|
||||
# TODO*: Store both completions
|
||||
completions[0],
|
||||
payment_method=payment_method,
|
||||
llm_version=exact_llm_version,
|
||||
stack=valid_stack,
|
||||
# TODO*
|
||||
llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||
stack=stack,
|
||||
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
|
||||
includes_result_image=bool(params.get("resultImage", False)),
|
||||
input_mode=validated_input_mode,
|
||||
input_mode=input_mode,
|
||||
auth_token=params["authToken"],
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error sending to SaaS backend", e)
|
||||
|
||||
try:
|
||||
if should_generate_images:
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Generating images..."}
|
||||
)
|
||||
image_generation_model = "sdxl-lightning" if REPLICATE_API_KEY else "dalle3"
|
||||
print("Generating images with model: ", image_generation_model)
|
||||
## Image Generation
|
||||
|
||||
updated_html = await generate_images(
|
||||
for index, _ in enumerate(completions):
|
||||
await send_message("status", "Generating images...", index)
|
||||
|
||||
image_generation_tasks = [
|
||||
perform_image_generation(
|
||||
completion,
|
||||
api_key=(
|
||||
REPLICATE_API_KEY
|
||||
if image_generation_model == "sdxl-lightning"
|
||||
else openai_api_key
|
||||
),
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
model=image_generation_model,
|
||||
)
|
||||
else:
|
||||
updated_html = completion
|
||||
await websocket.send_json({"type": "setCode", "value": updated_html})
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Code generation complete."}
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print("Image generation failed", e)
|
||||
# Send set code even if image generation fails since that triggers
|
||||
# the frontend to update history
|
||||
await websocket.send_json({"type": "setCode", "value": completion})
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Image generation failed but code is complete."}
|
||||
should_generate_images,
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
image_cache,
|
||||
)
|
||||
for completion in completions
|
||||
]
|
||||
|
||||
updated_completions = await asyncio.gather(*image_generation_tasks)
|
||||
|
||||
for index, updated_html in enumerate(updated_completions):
|
||||
await send_message("setCode", updated_html, index)
|
||||
await send_message("status", "Code generation complete.", index)
|
||||
|
||||
await websocket.close()
|
||||
|
||||
@ -43,6 +43,7 @@
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"html2canvas": "^1.4.1",
|
||||
"posthog-js": "^1.128.1",
|
||||
"nanoid": "^5.0.7",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-dropzone": "^14.2.3",
|
||||
|
||||
@ -9,8 +9,7 @@ import { usePersistedState } from "./hooks/usePersistedState";
|
||||
import TermsOfServiceDialog from "./components/TermsOfServiceDialog";
|
||||
import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants";
|
||||
import { addEvent } from "./lib/analytics";
|
||||
import { History } from "./components/history/history_types";
|
||||
import { extractHistoryTree } from "./components/history/utils";
|
||||
import { extractHistory } from "./components/history/utils";
|
||||
import toast from "react-hot-toast";
|
||||
import { useAuth } from "@clerk/clerk-react";
|
||||
import { useStore } from "./store/store";
|
||||
@ -27,6 +26,8 @@ import { GenerationSettings } from "./components/settings/GenerationSettings";
|
||||
import StartPane from "./components/start-pane/StartPane";
|
||||
import { takeScreenshot } from "./lib/takeScreenshot";
|
||||
import Sidebar from "./components/sidebar/Sidebar";
|
||||
import { Commit } from "./components/commits/types";
|
||||
import { createCommit } from "./components/commits/utils";
|
||||
|
||||
interface Props {
|
||||
navbarComponent?: JSX.Element;
|
||||
@ -49,13 +50,19 @@ function App({ navbarComponent }: Props) {
|
||||
referenceImages,
|
||||
setReferenceImages,
|
||||
|
||||
head,
|
||||
commits,
|
||||
addCommit,
|
||||
removeCommit,
|
||||
setHead,
|
||||
appendCommitCode,
|
||||
setCommitCode,
|
||||
resetCommits,
|
||||
resetHead,
|
||||
|
||||
// Outputs
|
||||
setGeneratedCode,
|
||||
setExecutionConsole,
|
||||
currentVersion,
|
||||
setCurrentVersion,
|
||||
appHistory,
|
||||
setAppHistory,
|
||||
appendExecutionConsole,
|
||||
resetExecutionConsoles,
|
||||
} = useProjectStore();
|
||||
|
||||
const {
|
||||
@ -119,32 +126,31 @@ function App({ navbarComponent }: Props) {
|
||||
// Functions
|
||||
const reset = () => {
|
||||
setAppState(AppState.INITIAL);
|
||||
setGeneratedCode("");
|
||||
setReferenceImages([]);
|
||||
setInitialPrompt("");
|
||||
setExecutionConsole([]);
|
||||
setUpdateInstruction("");
|
||||
setIsImportedFromCode(false);
|
||||
setAppHistory([]);
|
||||
setCurrentVersion(null);
|
||||
setShouldIncludeResultImage(false);
|
||||
setUpdateInstruction("");
|
||||
disableInSelectAndEditMode();
|
||||
resetExecutionConsoles();
|
||||
|
||||
resetCommits();
|
||||
resetHead();
|
||||
|
||||
// Inputs
|
||||
setInputMode("image");
|
||||
setReferenceImages([]);
|
||||
setIsImportedFromCode(false);
|
||||
};
|
||||
|
||||
const regenerate = () => {
|
||||
if (currentVersion === null) {
|
||||
// This would be a error that I log to Sentry
|
||||
addEvent("RegenerateCurrentVersionNull");
|
||||
if (head === null) {
|
||||
toast.error(
|
||||
"No current version set. Please open a Github issue as this shouldn't happen."
|
||||
"No current version set. Please contact support via chat or Github."
|
||||
);
|
||||
return;
|
||||
throw new Error("Regenerate called with no head");
|
||||
}
|
||||
|
||||
// Retrieve the previous command
|
||||
const previousCommand = appHistory[currentVersion];
|
||||
if (previousCommand.type !== "ai_create") {
|
||||
addEvent("RegenerateNotFirstVersion");
|
||||
const currentCommit = commits[head];
|
||||
if (currentCommit.type !== "ai_create") {
|
||||
toast.error("Only the first version can be regenerated.");
|
||||
return;
|
||||
}
|
||||
@ -165,28 +171,32 @@ function App({ navbarComponent }: Props) {
|
||||
addEvent("Cancel");
|
||||
|
||||
wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE);
|
||||
// make sure stop can correct the state even if the websocket is already closed
|
||||
cancelCodeGenerationAndReset();
|
||||
};
|
||||
|
||||
// Used for code generation failure as well
|
||||
const cancelCodeGenerationAndReset = () => {
|
||||
// When this is the first version, reset the entire app state
|
||||
if (currentVersion === null) {
|
||||
const cancelCodeGenerationAndReset = (commit: Commit) => {
|
||||
// When the current commit is the first version, reset the entire app state
|
||||
if (commit.type === "ai_create") {
|
||||
reset();
|
||||
} else {
|
||||
// Otherwise, revert to the last version
|
||||
setGeneratedCode(appHistory[currentVersion].code);
|
||||
// Otherwise, remove current commit from commits
|
||||
removeCommit(commit.hash);
|
||||
|
||||
// Revert to parent commit
|
||||
const parentCommitHash = commit.parentHash;
|
||||
if (parentCommitHash) {
|
||||
setHead(parentCommitHash);
|
||||
} else {
|
||||
throw new Error("Parent commit not found");
|
||||
}
|
||||
|
||||
setAppState(AppState.CODE_READY);
|
||||
}
|
||||
};
|
||||
|
||||
async function doGenerateCode(
|
||||
params: CodeGenerationParams,
|
||||
parentVersion: number | null
|
||||
) {
|
||||
async function doGenerateCode(params: CodeGenerationParams) {
|
||||
// Reset the execution console
|
||||
setExecutionConsole([]);
|
||||
resetExecutionConsoles();
|
||||
|
||||
// Set the app state
|
||||
setAppState(AppState.CODING);
|
||||
@ -199,69 +209,50 @@ function App({ navbarComponent }: Props) {
|
||||
authToken: authToken || undefined,
|
||||
};
|
||||
|
||||
const baseCommitObject = {
|
||||
variants: [{ code: "" }, { code: "" }],
|
||||
};
|
||||
|
||||
const commitInputObject =
|
||||
params.generationType === "create"
|
||||
? {
|
||||
...baseCommitObject,
|
||||
type: "ai_create" as const,
|
||||
parentHash: null,
|
||||
inputs: { image_url: referenceImages[0] },
|
||||
}
|
||||
: {
|
||||
...baseCommitObject,
|
||||
type: "ai_edit" as const,
|
||||
parentHash: head,
|
||||
inputs: {
|
||||
prompt: params.history
|
||||
? params.history[params.history.length - 1]
|
||||
: "",
|
||||
},
|
||||
};
|
||||
|
||||
// Create a new commit and set it as the head
|
||||
const commit = createCommit(commitInputObject);
|
||||
addCommit(commit);
|
||||
setHead(commit.hash);
|
||||
|
||||
generateCode(
|
||||
wsRef,
|
||||
updatedParams,
|
||||
// On change
|
||||
(token) => setGeneratedCode((prev) => prev + token),
|
||||
(token, variantIndex) => {
|
||||
appendCommitCode(commit.hash, variantIndex, token);
|
||||
},
|
||||
// On set code
|
||||
(code) => {
|
||||
setGeneratedCode(code);
|
||||
if (params.generationType === "create") {
|
||||
if (inputMode === "image" || inputMode === "video") {
|
||||
setAppHistory([
|
||||
{
|
||||
type: "ai_create",
|
||||
parentIndex: null,
|
||||
code,
|
||||
inputs: { image_url: referenceImages[0] },
|
||||
},
|
||||
]);
|
||||
} else {
|
||||
setAppHistory([
|
||||
{
|
||||
type: "ai_create",
|
||||
parentIndex: null,
|
||||
code,
|
||||
inputs: { text: params.image },
|
||||
},
|
||||
]);
|
||||
}
|
||||
setCurrentVersion(0);
|
||||
} else {
|
||||
setAppHistory((prev) => {
|
||||
// Validate parent version
|
||||
if (parentVersion === null) {
|
||||
toast.error(
|
||||
"No parent version set. Contact support or open a Github issue."
|
||||
);
|
||||
addEvent("ParentVersionNull");
|
||||
return prev;
|
||||
}
|
||||
|
||||
const newHistory: History = [
|
||||
...prev,
|
||||
{
|
||||
type: "ai_edit",
|
||||
parentIndex: parentVersion,
|
||||
code,
|
||||
inputs: {
|
||||
prompt: params.history
|
||||
? params.history[params.history.length - 1]
|
||||
: "", // History should never be empty when performing an edit
|
||||
},
|
||||
},
|
||||
];
|
||||
setCurrentVersion(newHistory.length - 1);
|
||||
return newHistory;
|
||||
});
|
||||
}
|
||||
(code, variantIndex) => {
|
||||
setCommitCode(commit.hash, variantIndex, code);
|
||||
},
|
||||
// On status update
|
||||
(line) => setExecutionConsole((prev) => [...prev, line]),
|
||||
(line, variantIndex) => appendExecutionConsole(variantIndex, line),
|
||||
// On cancel
|
||||
() => {
|
||||
cancelCodeGenerationAndReset();
|
||||
cancelCodeGenerationAndReset(commit);
|
||||
},
|
||||
// On complete
|
||||
() => {
|
||||
@ -285,14 +276,11 @@ function App({ navbarComponent }: Props) {
|
||||
// Kick off the code generation
|
||||
if (referenceImages.length > 0) {
|
||||
addEvent("Create");
|
||||
await doGenerateCode(
|
||||
{
|
||||
doGenerateCode({
|
||||
generationType: "create",
|
||||
image: referenceImages[0],
|
||||
inputMode,
|
||||
},
|
||||
currentVersion
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -302,14 +290,11 @@ function App({ navbarComponent }: Props) {
|
||||
|
||||
setInputMode("text");
|
||||
setInitialPrompt(text);
|
||||
doGenerateCode(
|
||||
{
|
||||
doGenerateCode({
|
||||
generationType: "create",
|
||||
inputMode: "text",
|
||||
image: text,
|
||||
},
|
||||
currentVersion
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// Subsequent updates
|
||||
@ -322,23 +307,22 @@ function App({ navbarComponent }: Props) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (currentVersion === null) {
|
||||
if (head === null) {
|
||||
toast.error(
|
||||
"No current version set. Contact support or open a Github issue."
|
||||
);
|
||||
addEvent("CurrentVersionNull");
|
||||
return;
|
||||
throw new Error("Update called with no head");
|
||||
}
|
||||
|
||||
let historyTree;
|
||||
try {
|
||||
historyTree = extractHistoryTree(appHistory, currentVersion);
|
||||
historyTree = extractHistory(head, commits);
|
||||
} catch {
|
||||
addEvent("HistoryTreeFailed");
|
||||
toast.error(
|
||||
"Version history is invalid. This shouldn't happen. Please contact support or open a Github issue."
|
||||
);
|
||||
return;
|
||||
throw new Error("Invalid version history");
|
||||
}
|
||||
|
||||
let modifiedUpdateInstruction = updateInstruction;
|
||||
@ -352,34 +336,19 @@ function App({ navbarComponent }: Props) {
|
||||
}
|
||||
|
||||
const updatedHistory = [...historyTree, modifiedUpdateInstruction];
|
||||
const resultImage = shouldIncludeResultImage
|
||||
? await takeScreenshot()
|
||||
: undefined;
|
||||
|
||||
if (shouldIncludeResultImage) {
|
||||
const resultImage = await takeScreenshot();
|
||||
await doGenerateCode(
|
||||
{
|
||||
doGenerateCode({
|
||||
generationType: "update",
|
||||
inputMode,
|
||||
image: referenceImages[0],
|
||||
resultImage: resultImage,
|
||||
resultImage,
|
||||
history: updatedHistory,
|
||||
isImportedFromCode,
|
||||
},
|
||||
currentVersion
|
||||
);
|
||||
} else {
|
||||
await doGenerateCode(
|
||||
{
|
||||
generationType: "update",
|
||||
inputMode,
|
||||
image: inputMode === "text" ? initialPrompt : referenceImages[0],
|
||||
history: updatedHistory,
|
||||
isImportedFromCode,
|
||||
},
|
||||
currentVersion
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
setGeneratedCode("");
|
||||
setUpdateInstruction("");
|
||||
}
|
||||
|
||||
@ -402,17 +371,17 @@ function App({ navbarComponent }: Props) {
|
||||
setIsImportedFromCode(true);
|
||||
|
||||
// Set up this project
|
||||
setGeneratedCode(code);
|
||||
setStack(stack);
|
||||
setAppHistory([
|
||||
{
|
||||
|
||||
// Create a new commit and set it as the head
|
||||
const commit = createCommit({
|
||||
type: "code_create",
|
||||
parentIndex: null,
|
||||
code,
|
||||
inputs: { code },
|
||||
},
|
||||
]);
|
||||
setCurrentVersion(0);
|
||||
parentHash: null,
|
||||
variants: [{ code }],
|
||||
inputs: null,
|
||||
});
|
||||
addCommit(commit);
|
||||
setHead(commit.hash);
|
||||
|
||||
// Set the app state
|
||||
setAppState(AppState.CODE_READY);
|
||||
|
||||
37
frontend/src/components/commits/types.ts
Normal file
37
frontend/src/components/commits/types.ts
Normal file
@ -0,0 +1,37 @@
|
||||
export type CommitHash = string;
|
||||
|
||||
export type Variant = {
|
||||
code: string;
|
||||
};
|
||||
|
||||
export type BaseCommit = {
|
||||
hash: CommitHash;
|
||||
parentHash: CommitHash | null;
|
||||
dateCreated: Date;
|
||||
isCommitted: boolean;
|
||||
variants: Variant[];
|
||||
selectedVariantIndex: number;
|
||||
};
|
||||
|
||||
export type CommitType = "ai_create" | "ai_edit" | "code_create";
|
||||
|
||||
export type AiCreateCommit = BaseCommit & {
|
||||
type: "ai_create";
|
||||
inputs: {
|
||||
image_url: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type AiEditCommit = BaseCommit & {
|
||||
type: "ai_edit";
|
||||
inputs: {
|
||||
prompt: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type CodeCreateCommit = BaseCommit & {
|
||||
type: "code_create";
|
||||
inputs: null;
|
||||
};
|
||||
|
||||
export type Commit = AiCreateCommit | AiEditCommit | CodeCreateCommit;
|
||||
32
frontend/src/components/commits/utils.ts
Normal file
32
frontend/src/components/commits/utils.ts
Normal file
@ -0,0 +1,32 @@
|
||||
import { nanoid } from "nanoid";
|
||||
import {
|
||||
AiCreateCommit,
|
||||
AiEditCommit,
|
||||
CodeCreateCommit,
|
||||
Commit,
|
||||
} from "./types";
|
||||
|
||||
export function createCommit(
|
||||
commit:
|
||||
| Omit<
|
||||
AiCreateCommit,
|
||||
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
|
||||
>
|
||||
| Omit<
|
||||
AiEditCommit,
|
||||
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
|
||||
>
|
||||
| Omit<
|
||||
CodeCreateCommit,
|
||||
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
|
||||
>
|
||||
): Commit {
|
||||
const hash = nanoid();
|
||||
return {
|
||||
...commit,
|
||||
hash,
|
||||
isCommitted: false,
|
||||
dateCreated: new Date(),
|
||||
selectedVariantIndex: 0,
|
||||
};
|
||||
}
|
||||
@ -17,19 +17,16 @@ interface Props {
|
||||
}
|
||||
|
||||
export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
||||
const {
|
||||
appHistory: history,
|
||||
currentVersion,
|
||||
setCurrentVersion,
|
||||
setGeneratedCode,
|
||||
} = useProjectStore();
|
||||
const renderedHistory = renderHistory(history, currentVersion);
|
||||
const { commits, head, setHead } = useProjectStore();
|
||||
|
||||
const revertToVersion = (index: number) => {
|
||||
if (index < 0 || index >= history.length || !history[index]) return;
|
||||
setCurrentVersion(index);
|
||||
setGeneratedCode(history[index].code);
|
||||
};
|
||||
// Put all commits into an array and sort by created date (oldest first)
|
||||
const flatHistory = Object.values(commits).sort(
|
||||
(a, b) =>
|
||||
new Date(a.dateCreated).getTime() - new Date(b.dateCreated).getTime()
|
||||
);
|
||||
|
||||
// Annotate history items with a summary, parent version, etc.
|
||||
const renderedHistory = renderHistory(flatHistory);
|
||||
|
||||
return renderedHistory.length === 0 ? null : (
|
||||
<div className="flex flex-col h-screen">
|
||||
@ -43,8 +40,8 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
||||
"flex items-center justify-between space-x-2 w-full pr-2",
|
||||
"border-b cursor-pointer",
|
||||
{
|
||||
" hover:bg-black hover:text-white": !item.isActive,
|
||||
"bg-slate-500 text-white": item.isActive,
|
||||
" hover:bg-black hover:text-white": item.hash === head,
|
||||
"bg-slate-500 text-white": item.hash === head,
|
||||
}
|
||||
)}
|
||||
>
|
||||
@ -55,14 +52,14 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
||||
? toast.error(
|
||||
"Please wait for code generation to complete before viewing an older version."
|
||||
)
|
||||
: revertToVersion(index)
|
||||
: setHead(item.hash)
|
||||
}
|
||||
>
|
||||
<div className="flex gap-x-1 truncate">
|
||||
<h2 className="text-sm truncate">{item.summary}</h2>
|
||||
{item.parentVersion !== null && (
|
||||
<h2 className="text-sm">
|
||||
(parent: {item.parentVersion})
|
||||
(parent: v{item.parentVersion})
|
||||
</h2>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -1,45 +0,0 @@
|
||||
export type HistoryItemType = "ai_create" | "ai_edit" | "code_create";
|
||||
|
||||
type CommonHistoryItem = {
|
||||
parentIndex: null | number;
|
||||
code: string;
|
||||
};
|
||||
|
||||
export type HistoryItem =
|
||||
| ({
|
||||
type: "ai_create";
|
||||
inputs: AiCreateInputs | AiCreateInputsText;
|
||||
} & CommonHistoryItem)
|
||||
| ({
|
||||
type: "ai_edit";
|
||||
inputs: AiEditInputs;
|
||||
} & CommonHistoryItem)
|
||||
| ({
|
||||
type: "code_create";
|
||||
inputs: CodeCreateInputs;
|
||||
} & CommonHistoryItem);
|
||||
|
||||
export type AiCreateInputs = {
|
||||
image_url: string;
|
||||
};
|
||||
|
||||
export type AiCreateInputsText = {
|
||||
text: string;
|
||||
};
|
||||
|
||||
export type AiEditInputs = {
|
||||
prompt: string;
|
||||
};
|
||||
|
||||
export type CodeCreateInputs = {
|
||||
code: string;
|
||||
};
|
||||
|
||||
export type History = HistoryItem[];
|
||||
|
||||
export type RenderedHistoryItem = {
|
||||
type: string;
|
||||
summary: string;
|
||||
parentVersion: string | null;
|
||||
isActive: boolean;
|
||||
};
|
||||
@ -1,91 +1,125 @@
|
||||
import { extractHistoryTree, renderHistory } from "./utils";
|
||||
import type { History } from "./history_types";
|
||||
import { extractHistory, renderHistory } from "./utils";
|
||||
import { Commit, CommitHash } from "../commits/types";
|
||||
|
||||
const basicLinearHistory: History = [
|
||||
{
|
||||
const basicLinearHistory: Record<CommitHash, Commit> = {
|
||||
"0": {
|
||||
hash: "0",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_create",
|
||||
parentIndex: null,
|
||||
code: "<html>1. create</html>",
|
||||
parentHash: null,
|
||||
variants: [{ code: "<html>1. create</html>" }],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
image_url: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"1": {
|
||||
hash: "1",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_edit",
|
||||
parentIndex: 0,
|
||||
code: "<html>2. edit with better icons</html>",
|
||||
parentHash: "0",
|
||||
variants: [{ code: "<html>2. edit with better icons</html>" }],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
prompt: "use better icons",
|
||||
},
|
||||
},
|
||||
{
|
||||
"2": {
|
||||
hash: "2",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_edit",
|
||||
parentIndex: 1,
|
||||
code: "<html>3. edit with better icons and red text</html>",
|
||||
parentHash: "1",
|
||||
variants: [{ code: "<html>3. edit with better icons and red text</html>" }],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
prompt: "make text red",
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const basicLinearHistoryWithCode: History = [
|
||||
{
|
||||
const basicLinearHistoryWithCode: Record<CommitHash, Commit> = {
|
||||
"0": {
|
||||
hash: "0",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "code_create",
|
||||
parentIndex: null,
|
||||
code: "<html>1. create</html>",
|
||||
inputs: {
|
||||
code: "<html>1. create</html>",
|
||||
parentHash: null,
|
||||
variants: [{ code: "<html>1. create</html>" }],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: null,
|
||||
},
|
||||
},
|
||||
...basicLinearHistory.slice(1),
|
||||
];
|
||||
...Object.fromEntries(Object.entries(basicLinearHistory).slice(1)),
|
||||
};
|
||||
|
||||
const basicBranchingHistory: History = [
|
||||
const basicBranchingHistory: Record<CommitHash, Commit> = {
|
||||
...basicLinearHistory,
|
||||
{
|
||||
"3": {
|
||||
hash: "3",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_edit",
|
||||
parentIndex: 1,
|
||||
code: "<html>4. edit with better icons and green text</html>",
|
||||
parentHash: "1",
|
||||
variants: [
|
||||
{ code: "<html>4. edit with better icons and green text</html>" },
|
||||
],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
prompt: "make text green",
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const longerBranchingHistory: History = [
|
||||
const longerBranchingHistory: Record<CommitHash, Commit> = {
|
||||
...basicBranchingHistory,
|
||||
{
|
||||
"4": {
|
||||
hash: "4",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_edit",
|
||||
parentIndex: 3,
|
||||
code: "<html>5. edit with better icons and green, bold text</html>",
|
||||
parentHash: "3",
|
||||
variants: [
|
||||
{ code: "<html>5. edit with better icons and green, bold text</html>" },
|
||||
],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
prompt: "make text bold",
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const basicBadHistory: History = [
|
||||
{
|
||||
const basicBadHistory: Record<CommitHash, Commit> = {
|
||||
"0": {
|
||||
hash: "0",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_create",
|
||||
parentIndex: null,
|
||||
code: "<html>1. create</html>",
|
||||
parentHash: null,
|
||||
variants: [{ code: "<html>1. create</html>" }],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
image_url: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"1": {
|
||||
hash: "1",
|
||||
dateCreated: new Date(),
|
||||
isCommitted: false,
|
||||
type: "ai_edit",
|
||||
parentIndex: 2, // <- Bad parent index
|
||||
code: "<html>2. edit with better icons</html>",
|
||||
parentHash: "2", // <- Bad parent hash
|
||||
variants: [{ code: "<html>2. edit with better icons</html>" }],
|
||||
selectedVariantIndex: 0,
|
||||
inputs: {
|
||||
prompt: "use better icons",
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
describe("History Utils", () => {
|
||||
test("should correctly extract the history tree", () => {
|
||||
expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([
|
||||
expect(extractHistory("2", basicLinearHistory)).toEqual([
|
||||
"<html>1. create</html>",
|
||||
"use better icons",
|
||||
"<html>2. edit with better icons</html>",
|
||||
@ -93,12 +127,12 @@ describe("History Utils", () => {
|
||||
"<html>3. edit with better icons and red text</html>",
|
||||
]);
|
||||
|
||||
expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([
|
||||
expect(extractHistory("0", basicLinearHistory)).toEqual([
|
||||
"<html>1. create</html>",
|
||||
]);
|
||||
|
||||
// Test branching
|
||||
expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([
|
||||
expect(extractHistory("3", basicBranchingHistory)).toEqual([
|
||||
"<html>1. create</html>",
|
||||
"use better icons",
|
||||
"<html>2. edit with better icons</html>",
|
||||
@ -106,7 +140,7 @@ describe("History Utils", () => {
|
||||
"<html>4. edit with better icons and green text</html>",
|
||||
]);
|
||||
|
||||
expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([
|
||||
expect(extractHistory("4", longerBranchingHistory)).toEqual([
|
||||
"<html>1. create</html>",
|
||||
"use better icons",
|
||||
"<html>2. edit with better icons</html>",
|
||||
@ -116,7 +150,7 @@ describe("History Utils", () => {
|
||||
"<html>5. edit with better icons and green, bold text</html>",
|
||||
]);
|
||||
|
||||
expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([
|
||||
expect(extractHistory("2", longerBranchingHistory)).toEqual([
|
||||
"<html>1. create</html>",
|
||||
"use better icons",
|
||||
"<html>2. edit with better icons</html>",
|
||||
@ -126,105 +160,82 @@ describe("History Utils", () => {
|
||||
|
||||
// Errors
|
||||
|
||||
// Bad index
|
||||
expect(() => extractHistoryTree(basicLinearHistory, 100)).toThrow();
|
||||
expect(() => extractHistoryTree(basicLinearHistory, -2)).toThrow();
|
||||
// Bad hash
|
||||
expect(() => extractHistory("100", basicLinearHistory)).toThrow();
|
||||
|
||||
// Bad tree
|
||||
expect(() => extractHistoryTree(basicBadHistory, 1)).toThrow();
|
||||
expect(() => extractHistory("1", basicBadHistory)).toThrow();
|
||||
});
|
||||
|
||||
test("should correctly render the history tree", () => {
|
||||
expect(renderHistory(basicLinearHistory, 2)).toEqual([
|
||||
expect(renderHistory(Object.values(basicLinearHistory))).toEqual([
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
summary: "Create",
|
||||
...basicLinearHistory["0"],
|
||||
type: "Create",
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
summary: "use better icons",
|
||||
type: "Edit",
|
||||
},
|
||||
{
|
||||
isActive: true,
|
||||
parentVersion: null,
|
||||
summary: "make text red",
|
||||
type: "Edit",
|
||||
},
|
||||
]);
|
||||
|
||||
// Current version is the first version
|
||||
expect(renderHistory(basicLinearHistory, 0)).toEqual([
|
||||
{
|
||||
isActive: true,
|
||||
parentVersion: null,
|
||||
summary: "Create",
|
||||
type: "Create",
|
||||
parentVersion: null,
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
...basicLinearHistory["1"],
|
||||
type: "Edit",
|
||||
summary: "use better icons",
|
||||
type: "Edit",
|
||||
parentVersion: null,
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
summary: "make text red",
|
||||
...basicLinearHistory["2"],
|
||||
type: "Edit",
|
||||
summary: "make text red",
|
||||
parentVersion: null,
|
||||
},
|
||||
]);
|
||||
|
||||
// Render a history with code
|
||||
expect(renderHistory(basicLinearHistoryWithCode, 0)).toEqual([
|
||||
expect(renderHistory(Object.values(basicLinearHistoryWithCode))).toEqual([
|
||||
{
|
||||
isActive: true,
|
||||
parentVersion: null,
|
||||
summary: "Imported from code",
|
||||
...basicLinearHistoryWithCode["0"],
|
||||
type: "Imported from code",
|
||||
summary: "Imported from code",
|
||||
parentVersion: null,
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
...basicLinearHistoryWithCode["1"],
|
||||
type: "Edit",
|
||||
summary: "use better icons",
|
||||
type: "Edit",
|
||||
parentVersion: null,
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
summary: "make text red",
|
||||
...basicLinearHistoryWithCode["2"],
|
||||
type: "Edit",
|
||||
summary: "make text red",
|
||||
parentVersion: null,
|
||||
},
|
||||
]);
|
||||
|
||||
// Render a non-linear history
|
||||
expect(renderHistory(basicBranchingHistory, 3)).toEqual([
|
||||
expect(renderHistory(Object.values(basicBranchingHistory))).toEqual([
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
summary: "Create",
|
||||
...basicBranchingHistory["0"],
|
||||
type: "Create",
|
||||
summary: "Create",
|
||||
parentVersion: null,
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
...basicBranchingHistory["1"],
|
||||
type: "Edit",
|
||||
summary: "use better icons",
|
||||
type: "Edit",
|
||||
},
|
||||
{
|
||||
isActive: false,
|
||||
parentVersion: null,
|
||||
summary: "make text red",
|
||||
type: "Edit",
|
||||
},
|
||||
{
|
||||
isActive: true,
|
||||
parentVersion: "v2",
|
||||
summary: "make text green",
|
||||
...basicBranchingHistory["2"],
|
||||
type: "Edit",
|
||||
summary: "make text red",
|
||||
parentVersion: null,
|
||||
},
|
||||
{
|
||||
...basicBranchingHistory["3"],
|
||||
type: "Edit",
|
||||
summary: "make text green",
|
||||
parentVersion: 2,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@ -1,33 +1,25 @@
|
||||
import {
|
||||
History,
|
||||
HistoryItem,
|
||||
HistoryItemType,
|
||||
RenderedHistoryItem,
|
||||
} from "./history_types";
|
||||
import { Commit, CommitHash, CommitType } from "../commits/types";
|
||||
|
||||
export function extractHistoryTree(
|
||||
history: History,
|
||||
version: number
|
||||
export function extractHistory(
|
||||
hash: CommitHash,
|
||||
commits: Record<CommitHash, Commit>
|
||||
): string[] {
|
||||
const flatHistory: string[] = [];
|
||||
|
||||
let currentIndex: number | null = version;
|
||||
while (currentIndex !== null) {
|
||||
const item: HistoryItem = history[currentIndex];
|
||||
let currentCommitHash: CommitHash | null = hash;
|
||||
while (currentCommitHash !== null) {
|
||||
const commit: Commit | null = commits[currentCommitHash];
|
||||
|
||||
if (item) {
|
||||
if (item.type === "ai_create") {
|
||||
// Don't include the image for ai_create
|
||||
flatHistory.unshift(item.code);
|
||||
} else if (item.type === "ai_edit") {
|
||||
flatHistory.unshift(item.code);
|
||||
flatHistory.unshift(item.inputs.prompt);
|
||||
} else if (item.type === "code_create") {
|
||||
flatHistory.unshift(item.code);
|
||||
if (commit) {
|
||||
flatHistory.unshift(commit.variants[commit.selectedVariantIndex].code);
|
||||
|
||||
// For edits, add the prompt to the history
|
||||
if (commit.type === "ai_edit") {
|
||||
flatHistory.unshift(commit.inputs.prompt);
|
||||
}
|
||||
|
||||
// Move to the parent of the current item
|
||||
currentIndex = item.parentIndex;
|
||||
currentCommitHash = commit.parentHash;
|
||||
} else {
|
||||
throw new Error("Malformed history: missing parent index");
|
||||
}
|
||||
@ -36,7 +28,7 @@ export function extractHistoryTree(
|
||||
return flatHistory;
|
||||
}
|
||||
|
||||
function displayHistoryItemType(itemType: HistoryItemType) {
|
||||
function displayHistoryItemType(itemType: CommitType) {
|
||||
switch (itemType) {
|
||||
case "ai_create":
|
||||
return "Create";
|
||||
@ -51,44 +43,48 @@ function displayHistoryItemType(itemType: HistoryItemType) {
|
||||
}
|
||||
}
|
||||
|
||||
function summarizeHistoryItem(item: HistoryItem) {
|
||||
const itemType = item.type;
|
||||
switch (itemType) {
|
||||
const setParentVersion = (commit: Commit, history: Commit[]) => {
|
||||
// If the commit has no parent, return null
|
||||
if (!commit.parentHash) return null;
|
||||
|
||||
const parentIndex = history.findIndex(
|
||||
(item) => item.hash === commit.parentHash
|
||||
);
|
||||
const currentIndex = history.findIndex((item) => item.hash === commit.hash);
|
||||
|
||||
// Only set parent version if the parent is not the previous commit
|
||||
// and parent exists
|
||||
return parentIndex !== -1 && parentIndex != currentIndex - 1
|
||||
? parentIndex + 1
|
||||
: null;
|
||||
};
|
||||
|
||||
export function summarizeHistoryItem(commit: Commit) {
|
||||
const commitType = commit.type;
|
||||
switch (commitType) {
|
||||
case "ai_create":
|
||||
return "Create";
|
||||
case "ai_edit":
|
||||
return item.inputs.prompt;
|
||||
return commit.inputs.prompt;
|
||||
case "code_create":
|
||||
return "Imported from code";
|
||||
default: {
|
||||
const exhaustiveCheck: never = itemType;
|
||||
const exhaustiveCheck: never = commitType;
|
||||
throw new Error(`Unhandled case: ${exhaustiveCheck}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const renderHistory = (
|
||||
history: History,
|
||||
currentVersion: number | null
|
||||
) => {
|
||||
const renderedHistory: RenderedHistoryItem[] = [];
|
||||
export const renderHistory = (history: Commit[]) => {
|
||||
const renderedHistory = [];
|
||||
|
||||
for (let i = 0; i < history.length; i++) {
|
||||
const item = history[i];
|
||||
// Only show the parent version if it's not the previous version
|
||||
// (i.e. it's the branching point) and if it's not the first version
|
||||
const parentVersion =
|
||||
item.parentIndex !== null && item.parentIndex !== i - 1
|
||||
? `v${(item.parentIndex || 0) + 1}`
|
||||
: null;
|
||||
const type = displayHistoryItemType(item.type);
|
||||
const isActive = i === currentVersion;
|
||||
const summary = summarizeHistoryItem(item);
|
||||
const commit = history[i];
|
||||
renderedHistory.push({
|
||||
isActive,
|
||||
summary: summary,
|
||||
parentVersion,
|
||||
type,
|
||||
...commit,
|
||||
type: displayHistoryItemType(commit.type),
|
||||
summary: summarizeHistoryItem(commit),
|
||||
parentVersion: setParentVersion(commit, history),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -23,12 +23,17 @@ interface Props {
|
||||
|
||||
function PreviewPane({ doUpdate, reset, settings }: Props) {
|
||||
const { appState } = useAppStore();
|
||||
const { inputMode, generatedCode, setGeneratedCode } = useProjectStore();
|
||||
const { inputMode, head, commits } = useProjectStore();
|
||||
|
||||
const currentCommit = head && commits[head] ? commits[head] : "";
|
||||
const currentCode = currentCommit
|
||||
? currentCommit.variants[currentCommit.selectedVariantIndex].code
|
||||
: "";
|
||||
|
||||
const previewCode =
|
||||
inputMode === "video" && appState === AppState.CODING
|
||||
? extractHtml(generatedCode)
|
||||
: generatedCode;
|
||||
? extractHtml(currentCode)
|
||||
: currentCode;
|
||||
|
||||
return (
|
||||
<div className="ml-4">
|
||||
@ -45,7 +50,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) {
|
||||
Reset
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => downloadCode(generatedCode)}
|
||||
onClick={() => downloadCode(previewCode)}
|
||||
variant="secondary"
|
||||
className="flex items-center gap-x-2 mr-4 dark:text-white dark:bg-gray-700 download-btn"
|
||||
>
|
||||
@ -84,11 +89,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) {
|
||||
/>
|
||||
</TabsContent>
|
||||
<TabsContent value="code">
|
||||
<CodeTab
|
||||
code={previewCode}
|
||||
setCode={setGeneratedCode}
|
||||
settings={settings}
|
||||
/>
|
||||
<CodeTab code={previewCode} setCode={() => {}} settings={settings} />
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
|
||||
@ -12,6 +12,7 @@ import { Button } from "../ui/button";
|
||||
import { Textarea } from "../ui/textarea";
|
||||
import { useEffect, useRef } from "react";
|
||||
import HistoryDisplay from "../history/HistoryDisplay";
|
||||
import Variants from "../variants/Variants";
|
||||
|
||||
interface SidebarProps {
|
||||
showSelectAndEditFeature: boolean;
|
||||
@ -35,9 +36,18 @@ function Sidebar({
|
||||
shouldIncludeResultImage,
|
||||
setShouldIncludeResultImage,
|
||||
} = useAppStore();
|
||||
const { inputMode, generatedCode, referenceImages, executionConsole } =
|
||||
|
||||
const { inputMode, referenceImages, executionConsoles, head, commits } =
|
||||
useProjectStore();
|
||||
|
||||
const viewedCode =
|
||||
head && commits[head]
|
||||
? commits[head].variants[commits[head].selectedVariantIndex].code
|
||||
: "";
|
||||
|
||||
const executionConsole =
|
||||
(head && executionConsoles[commits[head].selectedVariantIndex]) || [];
|
||||
|
||||
// When coding is complete, focus on the update instruction textarea
|
||||
useEffect(() => {
|
||||
if (appState === AppState.CODE_READY && textareaRef.current) {
|
||||
@ -47,6 +57,8 @@ function Sidebar({
|
||||
|
||||
return (
|
||||
<>
|
||||
<Variants />
|
||||
|
||||
{/* Show code preview only when coding */}
|
||||
{appState === AppState.CODING && (
|
||||
<div className="flex flex-col">
|
||||
@ -66,7 +78,7 @@ function Sidebar({
|
||||
{executionConsole.slice(-1)[0]}
|
||||
</div>
|
||||
|
||||
<CodePreview code={generatedCode} />
|
||||
<CodePreview code={viewedCode} />
|
||||
|
||||
<div className="flex w-full">
|
||||
<Button
|
||||
@ -158,15 +170,22 @@ function Sidebar({
|
||||
)}
|
||||
<div className="bg-gray-400 px-4 py-2 rounded text-sm hidden">
|
||||
<h2 className="text-lg mb-4 border-b border-gray-800">Console</h2>
|
||||
{executionConsole.map((line, index) => (
|
||||
{Object.entries(executionConsoles).map(([index, lines]) => (
|
||||
<div key={index}>
|
||||
{lines.map((line, lineIndex) => (
|
||||
<div
|
||||
key={index}
|
||||
key={`${index}-${lineIndex}`}
|
||||
className="border-b border-gray-400 mb-2 text-gray-600 font-mono"
|
||||
>
|
||||
<span className="font-bold mr-2">{`${index}:${
|
||||
lineIndex + 1
|
||||
}`}</span>
|
||||
{line}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<HistoryDisplay shouldDisableReverts={appState === AppState.CODING} />
|
||||
|
||||
42
frontend/src/components/variants/Variants.tsx
Normal file
42
frontend/src/components/variants/Variants.tsx
Normal file
@ -0,0 +1,42 @@
|
||||
import { useProjectStore } from "../../store/project-store";
|
||||
|
||||
function Variants() {
|
||||
const { inputMode, head, commits, updateSelectedVariantIndex } =
|
||||
useProjectStore();
|
||||
|
||||
// If there is no head, don't show the variants
|
||||
if (head === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const commit = commits[head];
|
||||
const variants = commit.variants;
|
||||
const selectedVariantIndex = commit.selectedVariantIndex;
|
||||
|
||||
// If there is only one variant or the commit is already committed, don't show the variants
|
||||
if (variants.length <= 1 || commit.isCommitted || inputMode === "video") {
|
||||
return <div className="mt-2"></div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-4 mb-4">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{variants.map((_, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`p-2 border rounded-md cursor-pointer ${
|
||||
index === selectedVariantIndex
|
||||
? "bg-blue-100 dark:bg-blue-900"
|
||||
: "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
}`}
|
||||
onClick={() => updateSelectedVariantIndex(head, index)}
|
||||
>
|
||||
<h3 className="font-medium mb-1">Option {index + 1}</h3>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default Variants;
|
||||
@ -11,12 +11,18 @@ const ERROR_MESSAGE =
|
||||
|
||||
const CANCEL_MESSAGE = "Code generation cancelled";
|
||||
|
||||
type WebSocketResponse = {
|
||||
type: "chunk" | "status" | "setCode" | "error";
|
||||
value: string;
|
||||
variantIndex: number;
|
||||
};
|
||||
|
||||
export function generateCode(
|
||||
wsRef: React.MutableRefObject<WebSocket | null>,
|
||||
params: FullGenerationSettings,
|
||||
onChange: (chunk: string) => void,
|
||||
onSetCode: (code: string) => void,
|
||||
onStatusUpdate: (status: string) => void,
|
||||
onChange: (chunk: string, variantIndex: number) => void,
|
||||
onSetCode: (code: string, variantIndex: number) => void,
|
||||
onStatusUpdate: (status: string, variantIndex: number) => void,
|
||||
onCancel: () => void,
|
||||
onComplete: () => void
|
||||
) {
|
||||
@ -31,13 +37,13 @@ export function generateCode(
|
||||
});
|
||||
|
||||
ws.addEventListener("message", async (event: MessageEvent) => {
|
||||
const response = JSON.parse(event.data);
|
||||
const response = JSON.parse(event.data) as WebSocketResponse;
|
||||
if (response.type === "chunk") {
|
||||
onChange(response.value);
|
||||
onChange(response.value, response.variantIndex);
|
||||
} else if (response.type === "status") {
|
||||
onStatusUpdate(response.value);
|
||||
onStatusUpdate(response.value, response.variantIndex);
|
||||
} else if (response.type === "setCode") {
|
||||
onSetCode(response.value);
|
||||
onSetCode(response.value, response.variantIndex);
|
||||
} else if (response.type === "error") {
|
||||
console.error("Error generating code", response.value);
|
||||
toast.error(response.value);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { create } from "zustand";
|
||||
import { History } from "../components/history/history_types";
|
||||
import { Commit, CommitHash } from "../components/commits/types";
|
||||
|
||||
// Store for app-wide state
|
||||
interface ProjectStore {
|
||||
@ -11,25 +11,28 @@ interface ProjectStore {
|
||||
referenceImages: string[];
|
||||
setReferenceImages: (images: string[]) => void;
|
||||
|
||||
// Outputs and other state
|
||||
generatedCode: string;
|
||||
setGeneratedCode: (
|
||||
updater: string | ((currentCode: string) => string)
|
||||
) => void;
|
||||
executionConsole: string[];
|
||||
setExecutionConsole: (
|
||||
updater: string[] | ((currentConsole: string[]) => string[])
|
||||
) => void;
|
||||
// Outputs
|
||||
commits: Record<string, Commit>;
|
||||
head: CommitHash | null;
|
||||
|
||||
// Tracks the currently shown version from app history
|
||||
// TODO: might want to move to appStore
|
||||
currentVersion: number | null;
|
||||
setCurrentVersion: (version: number | null) => void;
|
||||
addCommit: (commit: Commit) => void;
|
||||
removeCommit: (hash: CommitHash) => void;
|
||||
resetCommits: () => void;
|
||||
|
||||
appHistory: History;
|
||||
setAppHistory: (
|
||||
updater: History | ((currentHistory: History) => History)
|
||||
appendCommitCode: (
|
||||
hash: CommitHash,
|
||||
numVariant: number,
|
||||
code: string
|
||||
) => void;
|
||||
setCommitCode: (hash: CommitHash, numVariant: number, code: string) => void;
|
||||
updateSelectedVariantIndex: (hash: CommitHash, index: number) => void;
|
||||
|
||||
setHead: (hash: CommitHash) => void;
|
||||
resetHead: () => void;
|
||||
|
||||
executionConsoles: { [key: number]: string[] };
|
||||
appendExecutionConsole: (variantIndex: number, line: string) => void;
|
||||
resetExecutionConsoles: () => void;
|
||||
}
|
||||
|
||||
export const useProjectStore = create<ProjectStore>((set) => ({
|
||||
@ -41,28 +44,106 @@ export const useProjectStore = create<ProjectStore>((set) => ({
|
||||
referenceImages: [],
|
||||
setReferenceImages: (images) => set({ referenceImages: images }),
|
||||
|
||||
// Outputs and other state
|
||||
generatedCode: "",
|
||||
setGeneratedCode: (updater) =>
|
||||
set((state) => ({
|
||||
generatedCode:
|
||||
typeof updater === "function" ? updater(state.generatedCode) : updater,
|
||||
})),
|
||||
executionConsole: [],
|
||||
setExecutionConsole: (updater) =>
|
||||
set((state) => ({
|
||||
executionConsole:
|
||||
typeof updater === "function"
|
||||
? updater(state.executionConsole)
|
||||
: updater,
|
||||
})),
|
||||
// Outputs
|
||||
commits: {},
|
||||
head: null,
|
||||
|
||||
currentVersion: null,
|
||||
setCurrentVersion: (version) => set({ currentVersion: version }),
|
||||
appHistory: [],
|
||||
setAppHistory: (updater) =>
|
||||
addCommit: (commit: Commit) => {
|
||||
// When adding a new commit, make sure all existing commits are marked as committed
|
||||
set((state) => ({
|
||||
appHistory:
|
||||
typeof updater === "function" ? updater(state.appHistory) : updater,
|
||||
})),
|
||||
commits: {
|
||||
...Object.fromEntries(
|
||||
Object.entries(state.commits).map(([hash, existingCommit]) => [
|
||||
hash,
|
||||
{ ...existingCommit, isCommitted: true },
|
||||
])
|
||||
),
|
||||
[commit.hash]: commit,
|
||||
},
|
||||
}));
|
||||
},
|
||||
removeCommit: (hash: CommitHash) => {
|
||||
set((state) => {
|
||||
const newCommits = { ...state.commits };
|
||||
delete newCommits[hash];
|
||||
return { commits: newCommits };
|
||||
});
|
||||
},
|
||||
resetCommits: () => set({ commits: {} }),
|
||||
|
||||
appendCommitCode: (hash: CommitHash, numVariant: number, code: string) =>
|
||||
set((state) => {
|
||||
const commit = state.commits[hash];
|
||||
// Don't update if the commit is already committed
|
||||
if (commit.isCommitted) {
|
||||
throw new Error("Attempted to append code to a committed commit");
|
||||
}
|
||||
return {
|
||||
commits: {
|
||||
...state.commits,
|
||||
[hash]: {
|
||||
...commit,
|
||||
variants: commit.variants.map((variant, index) =>
|
||||
index === numVariant
|
||||
? { ...variant, code: variant.code + code }
|
||||
: variant
|
||||
),
|
||||
},
|
||||
},
|
||||
};
|
||||
}),
|
||||
setCommitCode: (hash: CommitHash, numVariant: number, code: string) =>
|
||||
set((state) => {
|
||||
const commit = state.commits[hash];
|
||||
// Don't update if the commit is already committed
|
||||
if (commit.isCommitted) {
|
||||
throw new Error("Attempted to set code of a committed commit");
|
||||
}
|
||||
return {
|
||||
commits: {
|
||||
...state.commits,
|
||||
[hash]: {
|
||||
...commit,
|
||||
variants: commit.variants.map((variant, index) =>
|
||||
index === numVariant ? { ...variant, code } : variant
|
||||
),
|
||||
},
|
||||
},
|
||||
};
|
||||
}),
|
||||
updateSelectedVariantIndex: (hash: CommitHash, index: number) =>
|
||||
set((state) => {
|
||||
const commit = state.commits[hash];
|
||||
// Don't update if the commit is already committed
|
||||
if (commit.isCommitted) {
|
||||
throw new Error(
|
||||
"Attempted to update selected variant index of a committed commit"
|
||||
);
|
||||
}
|
||||
return {
|
||||
commits: {
|
||||
...state.commits,
|
||||
[hash]: {
|
||||
...commit,
|
||||
selectedVariantIndex: index,
|
||||
},
|
||||
},
|
||||
};
|
||||
}),
|
||||
|
||||
setHead: (hash: CommitHash) => set({ head: hash }),
|
||||
resetHead: () => set({ head: null }),
|
||||
|
||||
executionConsoles: {},
|
||||
appendExecutionConsole: (variantIndex: number, line: string) =>
|
||||
set((state) => ({
|
||||
executionConsoles: {
|
||||
...state.executionConsoles,
|
||||
[variantIndex]: [
|
||||
...(state.executionConsoles[variantIndex] || []),
|
||||
line,
|
||||
],
|
||||
},
|
||||
})),
|
||||
resetExecutionConsoles: () => set({ executionConsoles: {} }),
|
||||
}));
|
||||
|
||||
@ -4582,6 +4582,11 @@ nanoid@^3.3.6, nanoid@^3.3.7:
|
||||
resolved "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz"
|
||||
integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==
|
||||
|
||||
nanoid@^5.0.7:
|
||||
version "5.0.7"
|
||||
resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-5.0.7.tgz#6452e8c5a816861fd9d2b898399f7e5fd6944cc6"
|
||||
integrity sha512-oLxFY2gd2IqnjcYyOXD8XGCftpGtZP2AbHbOkthDkvRywH5ayNtPVy9YlOPcHckXzbLTCHpkb7FB+yuxKV13pQ==
|
||||
|
||||
natural-compare@^1.4.0:
|
||||
version "1.4.0"
|
||||
resolved "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user