merge
This commit is contained in:
commit
c107f4eda5
@ -33,10 +33,6 @@ We also just added experimental support for taking a video/screen recording of a
|
|||||||
|
|
||||||
[Follow me on Twitter for updates](https://twitter.com/_abi_).
|
[Follow me on Twitter for updates](https://twitter.com/_abi_).
|
||||||
|
|
||||||
## Sponsors
|
|
||||||
|
|
||||||
<a href="https://konghq.com/products/kong-konnect?utm_medium=referral&utm_source=github&utm_campaign=platform&utm_content=screenshot-to-code" target="_blank" title="Kong - powering the API world"><img src="https://picoapps.xyz/s2c-sponsors/Kong-GitHub-240x100.png"></a>
|
|
||||||
|
|
||||||
## 🚀 Hosted Version
|
## 🚀 Hosted Version
|
||||||
|
|
||||||
[Try it live on the hosted version (paid)](https://screenshottocode.com).
|
[Try it live on the hosted version (paid)](https://screenshottocode.com).
|
||||||
|
|||||||
@ -7,15 +7,15 @@ repos:
|
|||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
- repo: local
|
# - repo: local
|
||||||
hooks:
|
# hooks:
|
||||||
- id: poetry-pytest
|
# - id: poetry-pytest
|
||||||
name: Run pytest with Poetry
|
# name: Run pytest with Poetry
|
||||||
entry: poetry run --directory backend pytest
|
# entry: poetry run --directory backend pytest
|
||||||
language: system
|
# language: system
|
||||||
pass_filenames: false
|
# pass_filenames: false
|
||||||
always_run: true
|
# always_run: true
|
||||||
files: ^backend/
|
# files: ^backend/
|
||||||
# - id: poetry-pyright
|
# - id: poetry-pyright
|
||||||
# name: Run pyright with Poetry
|
# name: Run pyright with Poetry
|
||||||
# entry: poetry run --directory backend pyright
|
# entry: poetry run --directory backend pyright
|
||||||
|
|||||||
@ -3,8 +3,12 @@
|
|||||||
# TODO: Should only be set to true when value is 'True', not any abitrary truthy value
|
# TODO: Should only be set to true when value is 'True', not any abitrary truthy value
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
NUM_VARIANTS = 2
|
||||||
|
|
||||||
|
# LLM-related
|
||||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
||||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
||||||
|
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
|
||||||
|
|
||||||
# Image generation (optional)
|
# Image generation (optional)
|
||||||
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
|
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
|
||||||
|
|||||||
0
backend/fs_logging/__init__.py
Normal file
0
backend/fs_logging/__init__.py
Normal file
23
backend/fs_logging/core.py
Normal file
23
backend/fs_logging/core.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
|
||||||
|
def write_logs(prompt_messages: list[ChatCompletionMessageParam], completion: str):
|
||||||
|
# Get the logs path from environment, default to the current working directory
|
||||||
|
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
|
||||||
|
|
||||||
|
# Create run_logs directory if it doesn't exist within the specified logs path
|
||||||
|
logs_directory = os.path.join(logs_path, "run_logs")
|
||||||
|
if not os.path.exists(logs_directory):
|
||||||
|
os.makedirs(logs_directory)
|
||||||
|
|
||||||
|
print("Writing to logs directory:", logs_directory)
|
||||||
|
|
||||||
|
# Generate a unique filename using the current timestamp within the logs directory
|
||||||
|
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
|
||||||
|
|
||||||
|
# Write the messages dict into a new file for each run
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
||||||
@ -28,7 +28,7 @@ async def process_tasks(
|
|||||||
|
|
||||||
processed_results: List[Union[str, None]] = []
|
processed_results: List[Union[str, None]] = []
|
||||||
for result in results:
|
for result in results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, BaseException):
|
||||||
print(f"An exception occurred: {result}")
|
print(f"An exception occurred: {result}")
|
||||||
try:
|
try:
|
||||||
raise result
|
raise result
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Awaitable, Callable, List, cast
|
from typing import Any, Awaitable, Callable, List, cast
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
@ -112,8 +113,12 @@ async def stream_claude_response(
|
|||||||
temperature = 0.0
|
temperature = 0.0
|
||||||
|
|
||||||
# Translate OpenAI messages to Claude messages
|
# Translate OpenAI messages to Claude messages
|
||||||
system_prompt = cast(str, messages[0].get("content"))
|
|
||||||
claude_messages = [dict(message) for message in messages[1:]]
|
# Deep copy messages to avoid modifying the original list
|
||||||
|
cloned_messages = copy.deepcopy(messages)
|
||||||
|
|
||||||
|
system_prompt = cast(str, cloned_messages[0].get("content"))
|
||||||
|
claude_messages = [dict(message) for message in cloned_messages[1:]]
|
||||||
for message in claude_messages:
|
for message in claude_messages:
|
||||||
if not isinstance(message["content"], list):
|
if not isinstance(message["content"], list):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
|
|||||||
|
|
||||||
|
|
||||||
async def mock_completion(
|
async def mock_completion(
|
||||||
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
|
process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
|
||||||
) -> str:
|
) -> str:
|
||||||
code_to_return = (
|
code_to_return = (
|
||||||
TALLY_FORM_VIDEO_PROMPT_MOCK
|
TALLY_FORM_VIDEO_PROMPT_MOCK
|
||||||
@ -17,7 +17,7 @@ async def mock_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
|
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
|
||||||
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
|
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
if input_mode == "video":
|
if input_mode == "video":
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
from typing import List, NoReturn, Union
|
from typing import Union
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||||
from llm import Llm
|
|
||||||
|
|
||||||
|
from custom_types import InputMode
|
||||||
|
from image_generation.core import create_alt_url_mapping
|
||||||
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
||||||
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
||||||
from prompts.text_prompts import SYSTEM_PROMPTS as TEXT_SYSTEM_PROMPTS
|
from prompts.text_prompts import SYSTEM_PROMPTS as TEXT_SYSTEM_PROMPTS
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
|
from video.utils import assemble_claude_prompt_video
|
||||||
|
|
||||||
|
|
||||||
USER_PROMPT = """
|
USER_PROMPT = """
|
||||||
@ -18,9 +19,65 @@ Generate code for a SVG that looks exactly like this.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def create_prompt(
|
||||||
|
params: dict[str, str], stack: Stack, input_mode: InputMode
|
||||||
|
) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
|
||||||
|
|
||||||
|
image_cache: dict[str, str] = {}
|
||||||
|
|
||||||
|
# If this generation started off with imported code, we need to assemble the prompt differently
|
||||||
|
if params.get("isImportedFromCode"):
|
||||||
|
original_imported_code = params["history"][0]
|
||||||
|
prompt_messages = assemble_imported_code_prompt(original_imported_code, stack)
|
||||||
|
for index, text in enumerate(params["history"][1:]):
|
||||||
|
if index % 2 == 0:
|
||||||
|
message: ChatCompletionMessageParam = {
|
||||||
|
"role": "user",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
message: ChatCompletionMessageParam = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
prompt_messages.append(message)
|
||||||
|
else:
|
||||||
|
# Assemble the prompt for non-imported code
|
||||||
|
if params.get("resultImage"):
|
||||||
|
prompt_messages = assemble_prompt(
|
||||||
|
params["image"], stack, params["resultImage"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt_messages = assemble_prompt(params["image"], stack)
|
||||||
|
|
||||||
|
if params["generationType"] == "update":
|
||||||
|
# Transform the history tree into message format
|
||||||
|
# TODO: Move this to frontend
|
||||||
|
for index, text in enumerate(params["history"]):
|
||||||
|
if index % 2 == 0:
|
||||||
|
message: ChatCompletionMessageParam = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
message: ChatCompletionMessageParam = {
|
||||||
|
"role": "user",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
prompt_messages.append(message)
|
||||||
|
|
||||||
|
image_cache = create_alt_url_mapping(params["history"][-2])
|
||||||
|
|
||||||
|
if input_mode == "video":
|
||||||
|
video_data_url = params["image"]
|
||||||
|
prompt_messages = await assemble_claude_prompt_video(video_data_url)
|
||||||
|
|
||||||
|
return prompt_messages, image_cache
|
||||||
|
|
||||||
|
|
||||||
def assemble_imported_code_prompt(
|
def assemble_imported_code_prompt(
|
||||||
code: str, stack: Stack, model: Llm
|
code: str, stack: Stack
|
||||||
) -> List[ChatCompletionMessageParam]:
|
) -> list[ChatCompletionMessageParam]:
|
||||||
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
||||||
|
|
||||||
user_content = (
|
user_content = (
|
||||||
@ -29,24 +86,12 @@ def assemble_imported_code_prompt(
|
|||||||
else "Here is the code of the SVG: " + code
|
else "Here is the code of the SVG: " + code
|
||||||
)
|
)
|
||||||
|
|
||||||
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": system_content + "\n " + user_content,
|
"content": system_content + "\n " + user_content,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": system_content,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_content,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
# TODO: Use result_image_data_url
|
# TODO: Use result_image_data_url
|
||||||
|
|
||||||
|
|
||||||
@ -54,11 +99,11 @@ def assemble_prompt(
|
|||||||
image_data_url: str,
|
image_data_url: str,
|
||||||
stack: Stack,
|
stack: Stack,
|
||||||
result_image_data_url: Union[str, None] = None,
|
result_image_data_url: Union[str, None] = None,
|
||||||
) -> List[ChatCompletionMessageParam]:
|
) -> list[ChatCompletionMessageParam]:
|
||||||
system_content = SYSTEM_PROMPTS[stack]
|
system_content = SYSTEM_PROMPTS[stack]
|
||||||
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
|
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
|
||||||
|
|
||||||
user_content: List[ChatCompletionContentPartParam] = [
|
user_content: list[ChatCompletionContentPartParam] = [
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": image_data_url, "detail": "high"},
|
"image_url": {"url": image_data_url, "detail": "high"},
|
||||||
@ -93,7 +138,7 @@ def assemble_prompt(
|
|||||||
def assemble_text_prompt(
|
def assemble_text_prompt(
|
||||||
text_prompt: str,
|
text_prompt: str,
|
||||||
stack: Stack,
|
stack: Stack,
|
||||||
) -> List[ChatCompletionMessageParam]:
|
) -> list[ChatCompletionMessageParam]:
|
||||||
|
|
||||||
system_content = TEXT_SYSTEM_PROMPTS[stack]
|
system_content = TEXT_SYSTEM_PROMPTS[stack]
|
||||||
|
|
||||||
|
|||||||
@ -391,63 +391,81 @@ def test_prompts():
|
|||||||
|
|
||||||
|
|
||||||
def test_imported_code_prompts():
|
def test_imported_code_prompts():
|
||||||
tailwind_prompt = assemble_imported_code_prompt(
|
code = "Sample code"
|
||||||
"code", "html_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind")
|
||||||
expected_tailwind_prompt = [
|
expected_tailwind_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert tailwind_prompt == expected_tailwind_prompt
|
assert tailwind_prompt == expected_tailwind_prompt
|
||||||
|
|
||||||
html_css_prompt = assemble_imported_code_prompt(
|
html_css_prompt = assemble_imported_code_prompt(code, "html_css")
|
||||||
"code", "html_css", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_html_css_prompt = [
|
expected_html_css_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert html_css_prompt == expected_html_css_prompt
|
assert html_css_prompt == expected_html_css_prompt
|
||||||
|
|
||||||
react_tailwind_prompt = assemble_imported_code_prompt(
|
react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind")
|
||||||
"code", "react_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_react_tailwind_prompt = [
|
expected_react_tailwind_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert react_tailwind_prompt == expected_react_tailwind_prompt
|
assert react_tailwind_prompt == expected_react_tailwind_prompt
|
||||||
|
|
||||||
bootstrap_prompt = assemble_imported_code_prompt(
|
bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap")
|
||||||
"code", "bootstrap", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_bootstrap_prompt = [
|
expected_bootstrap_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert bootstrap_prompt == expected_bootstrap_prompt
|
assert bootstrap_prompt == expected_bootstrap_prompt
|
||||||
|
|
||||||
ionic_tailwind = assemble_imported_code_prompt(
|
ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind")
|
||||||
"code", "ionic_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_ionic_tailwind = [
|
expected_ionic_tailwind = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert ionic_tailwind == expected_ionic_tailwind
|
assert ionic_tailwind == expected_ionic_tailwind
|
||||||
|
|
||||||
vue_tailwind = assemble_imported_code_prompt(
|
vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind")
|
||||||
"code", "vue_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_vue_tailwind = [
|
expected_vue_tailwind = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_VUE_TAILWIND_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert vue_tailwind == expected_vue_tailwind
|
assert vue_tailwind == expected_vue_tailwind
|
||||||
|
|
||||||
svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13)
|
svg = assemble_imported_code_prompt(code, "svg")
|
||||||
expected_svg = [
|
expected_svg = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the SVG: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_SVG_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the SVG: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert svg == expected_svg
|
assert svg == expected_svg
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
import os
|
import os
|
||||||
import traceback
|
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
import sentry_sdk
|
|
||||||
from codegen.utils import extract_html_content
|
from codegen.utils import extract_html_content
|
||||||
from config import (
|
from config import (
|
||||||
ANTHROPIC_API_KEY,
|
ANTHROPIC_API_KEY,
|
||||||
IS_PROD,
|
IS_PROD,
|
||||||
|
NUM_VARIANTS,
|
||||||
|
OPENAI_API_KEY,
|
||||||
|
OPENAI_BASE_URL,
|
||||||
REPLICATE_API_KEY,
|
REPLICATE_API_KEY,
|
||||||
SHOULD_MOCK_AI_RESPONSE,
|
SHOULD_MOCK_AI_RESPONSE,
|
||||||
)
|
)
|
||||||
@ -18,77 +21,101 @@ from llm import (
|
|||||||
stream_claude_response_native,
|
stream_claude_response_native,
|
||||||
stream_openai_response,
|
stream_openai_response,
|
||||||
)
|
)
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Dict, List, Union, cast, get_args
|
from typing import Dict, List, cast, get_args
|
||||||
from image_generation.core import create_alt_url_mapping, generate_images
|
from image_generation.core import generate_images
|
||||||
from prompts import assemble_imported_code_prompt, assemble_prompt, assemble_text_prompt
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
|
||||||
from routes.logging_utils import PaymentMethod, send_to_saas_backend
|
from routes.logging_utils import PaymentMethod, send_to_saas_backend
|
||||||
from routes.saas_utils import does_user_have_subscription_credits
|
from routes.saas_utils import does_user_have_subscription_credits
|
||||||
|
from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args
|
||||||
|
from image_generation.core import generate_images
|
||||||
|
from prompts import create_prompt
|
||||||
from prompts.claude_prompts import VIDEO_PROMPT
|
from prompts.claude_prompts import VIDEO_PROMPT
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
from utils import pprint_prompt
|
|
||||||
|
|
||||||
# from utils import pprint_prompt
|
# from utils import pprint_prompt
|
||||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
|
||||||
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
|
# Auto-upgrade usage of older models
|
||||||
# Get the logs path from environment, default to the current working directory
|
def auto_upgrade_model(code_generation_model: Llm) -> Llm:
|
||||||
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
|
if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
|
||||||
|
print(
|
||||||
# Create run_logs directory if it doesn't exist within the specified logs path
|
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
|
||||||
logs_directory = os.path.join(logs_path, "run_logs")
|
)
|
||||||
if not os.path.exists(logs_directory):
|
return Llm.GPT_4O_2024_05_13
|
||||||
os.makedirs(logs_directory)
|
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
||||||
|
print(
|
||||||
print("Writing to logs directory:", logs_directory)
|
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
|
||||||
|
)
|
||||||
# Generate a unique filename using the current timestamp within the logs directory
|
return Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||||
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
|
return code_generation_model
|
||||||
|
|
||||||
# Write the messages dict into a new file for each run
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/generate-code")
|
# Generate images, if needed
|
||||||
async def stream_code(websocket: WebSocket):
|
async def perform_image_generation(
|
||||||
await websocket.accept()
|
completion: str,
|
||||||
|
should_generate_images: bool,
|
||||||
print("Incoming websocket connection...")
|
openai_api_key: str | None,
|
||||||
|
openai_base_url: str | None,
|
||||||
async def throw_error(
|
image_cache: dict[str, str],
|
||||||
message: str,
|
|
||||||
):
|
):
|
||||||
await websocket.send_json({"type": "error", "value": message})
|
replicate_api_key = REPLICATE_API_KEY
|
||||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
if not should_generate_images:
|
||||||
|
return completion
|
||||||
|
|
||||||
# TODO: Are the values always strings?
|
if replicate_api_key:
|
||||||
params: Dict[str, str] = await websocket.receive_json()
|
image_generation_model = "sdxl-lightning"
|
||||||
|
api_key = replicate_api_key
|
||||||
|
else:
|
||||||
|
if not openai_api_key:
|
||||||
|
print(
|
||||||
|
"No OpenAI API key and Replicate key found. Skipping image generation."
|
||||||
|
)
|
||||||
|
return completion
|
||||||
|
image_generation_model = "dalle3"
|
||||||
|
api_key = openai_api_key
|
||||||
|
|
||||||
# Read the code config settings from the request. Fall back to default if not provided.
|
print("Generating images with model: ", image_generation_model)
|
||||||
generated_code_config = ""
|
|
||||||
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
|
return await generate_images(
|
||||||
generated_code_config = params["generatedCodeConfig"]
|
completion,
|
||||||
if not generated_code_config in get_args(Stack):
|
api_key=api_key,
|
||||||
|
base_url=openai_base_url,
|
||||||
|
image_cache=image_cache,
|
||||||
|
model=image_generation_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExtractedParams:
|
||||||
|
stack: Stack
|
||||||
|
input_mode: InputMode
|
||||||
|
code_generation_model: Llm
|
||||||
|
should_generate_images: bool
|
||||||
|
openai_api_key: str | None
|
||||||
|
anthropic_api_key: str | None
|
||||||
|
openai_base_url: str | None
|
||||||
|
payment_method: PaymentMethod
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_params(
|
||||||
|
params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]]
|
||||||
|
) -> ExtractedParams:
|
||||||
|
# Read the code config settings (stack) from the request.
|
||||||
|
generated_code_config = params.get("generatedCodeConfig", "")
|
||||||
|
if generated_code_config not in get_args(Stack):
|
||||||
await throw_error(f"Invalid generated code config: {generated_code_config}")
|
await throw_error(f"Invalid generated code config: {generated_code_config}")
|
||||||
return
|
raise ValueError(f"Invalid generated code config: {generated_code_config}")
|
||||||
# Cast the variable to the Stack type
|
validated_stack = cast(Stack, generated_code_config)
|
||||||
valid_stack = cast(Stack, generated_code_config)
|
|
||||||
|
|
||||||
# Validate the input mode
|
# Validate the input mode
|
||||||
input_mode = params.get("inputMode", "image")
|
input_mode = params.get("inputMode")
|
||||||
if not input_mode in get_args(InputMode):
|
if input_mode not in get_args(InputMode):
|
||||||
await throw_error(f"Invalid input mode: {input_mode}")
|
await throw_error(f"Invalid input mode: {input_mode}")
|
||||||
raise Exception(f"Invalid input mode: {input_mode}")
|
raise ValueError(f"Invalid input mode: {input_mode}")
|
||||||
# Cast the variable to the right type
|
|
||||||
validated_input_mode = cast(InputMode, input_mode)
|
validated_input_mode = cast(InputMode, input_mode)
|
||||||
|
|
||||||
# Read the model from the request. Fall back to default if not provided.
|
# Read the model from the request. Fall back to default if not provided.
|
||||||
@ -97,42 +124,19 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
|
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
|
||||||
except:
|
except ValueError:
|
||||||
await throw_error(f"Invalid model: {code_generation_model_str}")
|
await throw_error(f"Invalid model: {code_generation_model_str}")
|
||||||
raise Exception(f"Invalid model: {code_generation_model_str}")
|
raise ValueError(f"Invalid model: {code_generation_model_str}")
|
||||||
|
|
||||||
# Auto-upgrade usage of older models
|
|
||||||
if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
|
|
||||||
print(
|
|
||||||
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
|
|
||||||
)
|
|
||||||
code_generation_model = Llm.GPT_4O_2024_05_13
|
|
||||||
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
|
||||||
print(
|
|
||||||
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
|
|
||||||
)
|
|
||||||
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
|
||||||
|
|
||||||
exact_llm_version = None
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Track how this generation is being paid for
|
|
||||||
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
|
||||||
# Track the OpenAI API key to use
|
|
||||||
openai_api_key = None
|
|
||||||
|
|
||||||
auth_token = params.get("authToken")
|
auth_token = params.get("authToken")
|
||||||
if not auth_token:
|
if not auth_token:
|
||||||
await throw_error("You need to be logged in to use screenshot to code")
|
await throw_error("You need to be logged in to use screenshot to code")
|
||||||
raise Exception("No auth token")
|
raise Exception("No auth token")
|
||||||
|
|
||||||
# Get the OpenAI key by waterfalling through the different payment methods
|
openai_api_key = None
|
||||||
# 1. Subscription
|
|
||||||
# 2. User's API key from client-side settings dialog
|
# Track how this generation is being paid for
|
||||||
# 3. User's API key from environment variable
|
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
||||||
|
|
||||||
# If the user is a subscriber, use the platform API key
|
# If the user is a subscriber, use the platform API key
|
||||||
# TODO: Rename does_user_have_subscription_credits
|
# TODO: Rename does_user_have_subscription_credits
|
||||||
@ -157,151 +161,152 @@ async def stream_code(websocket: WebSocket):
|
|||||||
await throw_error("Unknown error occurred. Contact support.")
|
await throw_error("Unknown error occurred. Contact support.")
|
||||||
raise Exception("Unknown error occurred when checking subscription credits")
|
raise Exception("Unknown error occurred when checking subscription credits")
|
||||||
|
|
||||||
# For non-subscribers, use the user's API key from client-side settings dialog
|
|
||||||
if not openai_api_key:
|
|
||||||
openai_api_key = params.get("openAiApiKey", None)
|
|
||||||
payment_method = PaymentMethod.OPENAI_API_KEY
|
|
||||||
print("Using OpenAI API key from client-side settings dialog")
|
|
||||||
|
|
||||||
# If we still don't have an API key, use the user's API key from environment variable
|
# If we still don't have an API key, use the user's API key from environment variable
|
||||||
if not openai_api_key:
|
if not openai_api_key:
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = get_from_settings_dialog_or_env(
|
||||||
|
params, "openAiApiKey", OPENAI_API_KEY
|
||||||
|
)
|
||||||
payment_method = PaymentMethod.OPENAI_API_KEY
|
payment_method = PaymentMethod.OPENAI_API_KEY
|
||||||
if openai_api_key:
|
if openai_api_key:
|
||||||
|
# TODO
|
||||||
print("Using OpenAI API key from environment variable")
|
print("Using OpenAI API key from environment variable")
|
||||||
|
|
||||||
if not openai_api_key and (
|
# if not openai_api_key and (
|
||||||
code_generation_model == Llm.GPT_4_VISION
|
# code_generation_model == Llm.GPT_4_VISION
|
||||||
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
# or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
||||||
or code_generation_model == Llm.GPT_4O_2024_05_13
|
# or code_generation_model == Llm.GPT_4O_2024_05_13
|
||||||
):
|
# ):
|
||||||
print("OpenAI API key not found")
|
# print("OpenAI API key not found")
|
||||||
await throw_error(
|
# await throw_error(
|
||||||
"Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
|
# "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
|
||||||
)
|
# )
|
||||||
raise Exception("No OpenAI API key found")
|
# raise Exception("No OpenAI API key found")
|
||||||
|
|
||||||
# Get the Anthropic API key from the request. Fall back to environment variable if not provided.
|
# TODO: Do not allow usage of key
|
||||||
# If neither is provided, we throw an error later only if Claude is used.
|
# If neither is provided, we throw an error later only if Claude is used.
|
||||||
anthropic_api_key = None
|
anthropic_api_key = get_from_settings_dialog_or_env(
|
||||||
if "anthropicApiKey" in params and params["anthropicApiKey"]:
|
params, "anthropicApiKey", ANTHROPIC_API_KEY
|
||||||
anthropic_api_key = params["anthropicApiKey"]
|
)
|
||||||
print("Using Anthropic API key from client-side settings dialog")
|
|
||||||
else:
|
|
||||||
anthropic_api_key = ANTHROPIC_API_KEY
|
|
||||||
if anthropic_api_key:
|
|
||||||
print("Using Anthropic API key from environment variable")
|
|
||||||
|
|
||||||
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
|
# Base URL for OpenAI API
|
||||||
openai_base_url: Union[str, None] = None
|
openai_base_url: str | None = None
|
||||||
# Disable user-specified OpenAI Base URL in prod
|
# Disable user-specified OpenAI Base URL in prod
|
||||||
if not os.environ.get("IS_PROD"):
|
if not IS_PROD:
|
||||||
if "openAiBaseURL" in params and params["openAiBaseURL"]:
|
openai_base_url = get_from_settings_dialog_or_env(
|
||||||
openai_base_url = params["openAiBaseURL"]
|
params, "openAiBaseURL", OPENAI_BASE_URL
|
||||||
print("Using OpenAI Base URL from client-side settings dialog")
|
)
|
||||||
else:
|
|
||||||
openai_base_url = os.environ.get("OPENAI_BASE_URL")
|
|
||||||
if openai_base_url:
|
|
||||||
print("Using OpenAI Base URL from environment variable")
|
|
||||||
|
|
||||||
if not openai_base_url:
|
if not openai_base_url:
|
||||||
print("Using official OpenAI URL")
|
print("Using official OpenAI URL")
|
||||||
|
|
||||||
# Get the image generation flag from the request. Fall back to True if not provided.
|
# Get the image generation flag from the request. Fall back to True if not provided.
|
||||||
should_generate_images = (
|
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||||
params["isImageGenerationEnabled"]
|
|
||||||
if "isImageGenerationEnabled" in params
|
return ExtractedParams(
|
||||||
else True
|
stack=validated_stack,
|
||||||
|
input_mode=validated_input_mode,
|
||||||
|
code_generation_model=code_generation_model,
|
||||||
|
should_generate_images=should_generate_images,
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
anthropic_api_key=anthropic_api_key,
|
||||||
|
openai_base_url=openai_base_url,
|
||||||
|
payment_method=payment_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("generating code...")
|
|
||||||
await websocket.send_json({"type": "status", "value": "Generating code..."})
|
|
||||||
|
|
||||||
async def process_chunk(content: str):
|
def get_from_settings_dialog_or_env(
|
||||||
await websocket.send_json({"type": "chunk", "value": content})
|
params: dict[str, str], key: str, env_var: str | None
|
||||||
|
) -> str | None:
|
||||||
|
value = params.get(key)
|
||||||
|
if value:
|
||||||
|
print(f"Using {key} from client-side settings dialog")
|
||||||
|
return value
|
||||||
|
|
||||||
|
if env_var:
|
||||||
|
print(f"Using {key} from environment variable")
|
||||||
|
return env_var
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/generate-code")
|
||||||
|
async def stream_code(websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
print("Incoming websocket connection...")
|
||||||
|
|
||||||
|
## Communication protocol setup
|
||||||
|
async def throw_error(
|
||||||
|
message: str,
|
||||||
|
):
|
||||||
|
print(message)
|
||||||
|
await websocket.send_json({"type": "error", "value": message})
|
||||||
|
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
type: Literal["chunk", "status", "setCode", "error"],
|
||||||
|
value: str,
|
||||||
|
variantIndex: int,
|
||||||
|
):
|
||||||
|
# Print for debugging on the backend
|
||||||
|
if type == "error":
|
||||||
|
print(f"Error (variant {variantIndex}): {value}")
|
||||||
|
elif type == "status":
|
||||||
|
print(f"Status (variant {variantIndex}): {value}")
|
||||||
|
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": type, "value": value, "variantIndex": variantIndex}
|
||||||
|
)
|
||||||
|
|
||||||
|
## Parameter extract and validation
|
||||||
|
|
||||||
|
# TODO: Are the values always strings?
|
||||||
|
params: dict[str, str] = await websocket.receive_json()
|
||||||
|
print("Received params")
|
||||||
|
|
||||||
|
extracted_params = await extract_params(params, throw_error)
|
||||||
|
stack = extracted_params.stack
|
||||||
|
input_mode = extracted_params.input_mode
|
||||||
|
code_generation_model = extracted_params.code_generation_model
|
||||||
|
openai_api_key = extracted_params.openai_api_key
|
||||||
|
openai_base_url = extracted_params.openai_base_url
|
||||||
|
anthropic_api_key = extracted_params.anthropic_api_key
|
||||||
|
should_generate_images = extracted_params.should_generate_images
|
||||||
|
payment_method = extracted_params.payment_method
|
||||||
|
|
||||||
|
# Auto-upgrade usage of older models
|
||||||
|
code_generation_model = auto_upgrade_model(code_generation_model)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Generating {stack} code in {input_mode} mode using {code_generation_model}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(NUM_VARIANTS):
|
||||||
|
await send_message("status", "Generating code...", i)
|
||||||
|
|
||||||
|
### Prompt creation
|
||||||
|
|
||||||
# Image cache for updates so that we don't have to regenerate images
|
# Image cache for updates so that we don't have to regenerate images
|
||||||
image_cache: Dict[str, str] = {}
|
image_cache: Dict[str, str] = {}
|
||||||
|
|
||||||
# If this generation started off with imported code, we need to assemble the prompt differently
|
|
||||||
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
|
||||||
original_imported_code = params["history"][0]
|
|
||||||
prompt_messages = assemble_imported_code_prompt(
|
|
||||||
original_imported_code, valid_stack, code_generation_model
|
|
||||||
)
|
|
||||||
for index, text in enumerate(params["history"][1:]):
|
|
||||||
# TODO: Remove after "Select and edit" is fully implemented
|
|
||||||
if "referring to this element specifically" in text:
|
|
||||||
sentry_sdk.capture_exception(Exception("Point and edit used"))
|
|
||||||
|
|
||||||
if index % 2 == 0:
|
|
||||||
message: ChatCompletionMessageParam = {
|
|
||||||
"role": "user",
|
|
||||||
"content": text,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
message: ChatCompletionMessageParam = {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text,
|
|
||||||
}
|
|
||||||
prompt_messages.append(message)
|
|
||||||
else:
|
|
||||||
# Assemble the prompt
|
|
||||||
try:
|
try:
|
||||||
if validated_input_mode == "image":
|
prompt_messages, image_cache = await create_prompt(params, stack, input_mode)
|
||||||
if params.get("resultImage") and params["resultImage"]:
|
|
||||||
prompt_messages = assemble_prompt(
|
|
||||||
params["image"], valid_stack, params["resultImage"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt_messages = assemble_prompt(params["image"], valid_stack)
|
|
||||||
elif validated_input_mode == "text":
|
|
||||||
sentry_sdk.capture_exception(Exception("Text generation used"))
|
|
||||||
prompt_messages = assemble_text_prompt(params["image"], valid_stack)
|
|
||||||
else:
|
|
||||||
await throw_error("Invalid input mode")
|
|
||||||
return
|
|
||||||
except:
|
except:
|
||||||
await websocket.send_json(
|
await throw_error(
|
||||||
{
|
"Error assembling prompt. Contact support at support@picoapps.xyz"
|
||||||
"type": "error",
|
|
||||||
"value": "Error assembling prompt. Contact support at support@picoapps.xyz",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
await websocket.close()
|
raise
|
||||||
return
|
|
||||||
|
|
||||||
# Transform the history tree into message format for updates
|
|
||||||
if params["generationType"] == "update":
|
|
||||||
# TODO: Move this to frontend
|
|
||||||
for index, text in enumerate(params["history"]):
|
|
||||||
if index % 2 == 0:
|
|
||||||
message: ChatCompletionMessageParam = {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
message: ChatCompletionMessageParam = {
|
|
||||||
"role": "user",
|
|
||||||
"content": text,
|
|
||||||
}
|
|
||||||
prompt_messages.append(message)
|
|
||||||
|
|
||||||
image_cache = create_alt_url_mapping(params["history"][-2])
|
|
||||||
|
|
||||||
if validated_input_mode == "video":
|
|
||||||
video_data_url = params["image"]
|
|
||||||
prompt_messages = await assemble_claude_prompt_video(video_data_url)
|
|
||||||
|
|
||||||
# pprint_prompt(prompt_messages) # type: ignore
|
# pprint_prompt(prompt_messages) # type: ignore
|
||||||
|
|
||||||
|
### Code generation
|
||||||
|
|
||||||
|
async def process_chunk(content: str, variantIndex: int):
|
||||||
|
await send_message("chunk", content, variantIndex)
|
||||||
|
|
||||||
if SHOULD_MOCK_AI_RESPONSE:
|
if SHOULD_MOCK_AI_RESPONSE:
|
||||||
completion = await mock_completion(
|
completions = [await mock_completion(process_chunk, input_mode=input_mode)]
|
||||||
process_chunk, input_mode=validated_input_mode
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if validated_input_mode == "video":
|
if input_mode == "video":
|
||||||
if IS_PROD:
|
if IS_PROD:
|
||||||
raise Exception("Video mode is not supported in prod")
|
raise Exception("Video mode is not supported in prod")
|
||||||
|
|
||||||
@ -311,48 +316,66 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
raise Exception("No Anthropic key")
|
raise Exception("No Anthropic key")
|
||||||
|
|
||||||
completion = await stream_claude_response_native(
|
completions = [
|
||||||
|
await stream_claude_response_native(
|
||||||
system_prompt=VIDEO_PROMPT,
|
system_prompt=VIDEO_PROMPT,
|
||||||
messages=prompt_messages, # type: ignore
|
messages=prompt_messages, # type: ignore
|
||||||
api_key=anthropic_api_key,
|
api_key=anthropic_api_key,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x, 0),
|
||||||
model=Llm.CLAUDE_3_OPUS,
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
include_thinking=True,
|
include_thinking=True,
|
||||||
)
|
)
|
||||||
exact_llm_version = Llm.CLAUDE_3_OPUS
|
]
|
||||||
elif (
|
|
||||||
code_generation_model == Llm.CLAUDE_3_SONNET
|
|
||||||
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
|
|
||||||
):
|
|
||||||
if not anthropic_api_key:
|
|
||||||
await throw_error(
|
|
||||||
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
|
|
||||||
)
|
|
||||||
raise Exception("No Anthropic key")
|
|
||||||
|
|
||||||
# Do not allow non-subscribers to use Claude
|
|
||||||
if payment_method != PaymentMethod.SUBSCRIPTION:
|
|
||||||
await throw_error(
|
|
||||||
"Please subscribe to a paid plan to use the Claude models"
|
|
||||||
)
|
|
||||||
raise Exception("Not subscribed to a paid plan for Claude")
|
|
||||||
|
|
||||||
completion = await stream_claude_response(
|
|
||||||
prompt_messages, # type: ignore
|
|
||||||
api_key=anthropic_api_key,
|
|
||||||
callback=lambda x: process_chunk(x),
|
|
||||||
model=code_generation_model,
|
|
||||||
)
|
|
||||||
exact_llm_version = code_generation_model
|
|
||||||
else:
|
else:
|
||||||
completion = await stream_openai_response(
|
|
||||||
prompt_messages, # type: ignore
|
# Depending on the presence and absence of various keys,
|
||||||
|
# we decide which models to run
|
||||||
|
variant_models = []
|
||||||
|
if openai_api_key and anthropic_api_key:
|
||||||
|
variant_models = ["openai", "anthropic"]
|
||||||
|
elif openai_api_key:
|
||||||
|
variant_models = ["openai", "openai"]
|
||||||
|
elif anthropic_api_key:
|
||||||
|
variant_models = ["anthropic", "anthropic"]
|
||||||
|
else:
|
||||||
|
await throw_error(
|
||||||
|
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server."
|
||||||
|
)
|
||||||
|
raise Exception("No OpenAI or Anthropic key")
|
||||||
|
|
||||||
|
tasks: List[Coroutine[Any, Any, str]] = []
|
||||||
|
for index, model in enumerate(variant_models):
|
||||||
|
if model == "openai":
|
||||||
|
if openai_api_key is None:
|
||||||
|
await throw_error("OpenAI API key is missing.")
|
||||||
|
raise Exception("OpenAI API key is missing.")
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
stream_openai_response(
|
||||||
|
prompt_messages,
|
||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x, i=index: process_chunk(x, i),
|
||||||
model=code_generation_model,
|
model=Llm.GPT_4O_2024_05_13,
|
||||||
)
|
)
|
||||||
exact_llm_version = code_generation_model
|
)
|
||||||
|
elif model == "anthropic":
|
||||||
|
if anthropic_api_key is None:
|
||||||
|
await throw_error("Anthropic API key is missing.")
|
||||||
|
raise Exception("Anthropic API key is missing.")
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
stream_claude_response(
|
||||||
|
prompt_messages,
|
||||||
|
api_key=anthropic_api_key,
|
||||||
|
callback=lambda x, i=index: process_chunk(x, i),
|
||||||
|
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
completions = await asyncio.gather(*tasks)
|
||||||
|
print("Models used for generation: ", variant_models)
|
||||||
|
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
print("[GENERATE_CODE] Authentication failed", e)
|
print("[GENERATE_CODE] Authentication failed", e)
|
||||||
error_message = (
|
error_message = (
|
||||||
@ -388,13 +411,10 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
return await throw_error(error_message)
|
return await throw_error(error_message)
|
||||||
|
|
||||||
if validated_input_mode == "video":
|
## Post-processing
|
||||||
completion = extract_tag_content("html", completion)
|
|
||||||
|
|
||||||
print("Exact used model for generation: ", exact_llm_version)
|
|
||||||
|
|
||||||
# Strip the completion of everything except the HTML content
|
# Strip the completion of everything except the HTML content
|
||||||
completion = extract_html_content(completion)
|
completions = [extract_html_content(completion) for completion in completions]
|
||||||
|
|
||||||
# Write the messages dict into a log so that we can debug later
|
# Write the messages dict into a log so that we can debug later
|
||||||
# write_logs(prompt_messages, completion) # type: ignore
|
# write_logs(prompt_messages, completion) # type: ignore
|
||||||
@ -402,54 +422,44 @@ async def stream_code(websocket: WebSocket):
|
|||||||
if IS_PROD:
|
if IS_PROD:
|
||||||
# Catch any errors from sending to SaaS backend and continue
|
# Catch any errors from sending to SaaS backend and continue
|
||||||
try:
|
try:
|
||||||
assert exact_llm_version is not None, "exact_llm_version is not set"
|
# TODO*
|
||||||
|
# assert exact_llm_version is not None, "exact_llm_version is not set"
|
||||||
await send_to_saas_backend(
|
await send_to_saas_backend(
|
||||||
prompt_messages,
|
prompt_messages,
|
||||||
completion,
|
# TODO*: Store both completions
|
||||||
|
completions[0],
|
||||||
payment_method=payment_method,
|
payment_method=payment_method,
|
||||||
llm_version=exact_llm_version,
|
# TODO*
|
||||||
stack=valid_stack,
|
llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||||
|
stack=stack,
|
||||||
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
|
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
|
||||||
includes_result_image=bool(params.get("resultImage", False)),
|
includes_result_image=bool(params.get("resultImage", False)),
|
||||||
input_mode=validated_input_mode,
|
input_mode=input_mode,
|
||||||
auth_token=params["authToken"],
|
auth_token=params["authToken"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error sending to SaaS backend", e)
|
print("Error sending to SaaS backend", e)
|
||||||
|
|
||||||
try:
|
## Image Generation
|
||||||
if should_generate_images:
|
|
||||||
await websocket.send_json(
|
|
||||||
{"type": "status", "value": "Generating images..."}
|
|
||||||
)
|
|
||||||
image_generation_model = "sdxl-lightning" if REPLICATE_API_KEY else "dalle3"
|
|
||||||
print("Generating images with model: ", image_generation_model)
|
|
||||||
|
|
||||||
updated_html = await generate_images(
|
for index, _ in enumerate(completions):
|
||||||
|
await send_message("status", "Generating images...", index)
|
||||||
|
|
||||||
|
image_generation_tasks = [
|
||||||
|
perform_image_generation(
|
||||||
completion,
|
completion,
|
||||||
api_key=(
|
should_generate_images,
|
||||||
REPLICATE_API_KEY
|
openai_api_key,
|
||||||
if image_generation_model == "sdxl-lightning"
|
openai_base_url,
|
||||||
else openai_api_key
|
image_cache,
|
||||||
),
|
|
||||||
base_url=openai_base_url,
|
|
||||||
image_cache=image_cache,
|
|
||||||
model=image_generation_model,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
updated_html = completion
|
|
||||||
await websocket.send_json({"type": "setCode", "value": updated_html})
|
|
||||||
await websocket.send_json(
|
|
||||||
{"type": "status", "value": "Code generation complete."}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
print("Image generation failed", e)
|
|
||||||
# Send set code even if image generation fails since that triggers
|
|
||||||
# the frontend to update history
|
|
||||||
await websocket.send_json({"type": "setCode", "value": completion})
|
|
||||||
await websocket.send_json(
|
|
||||||
{"type": "status", "value": "Image generation failed but code is complete."}
|
|
||||||
)
|
)
|
||||||
|
for completion in completions
|
||||||
|
]
|
||||||
|
|
||||||
|
updated_completions = await asyncio.gather(*image_generation_tasks)
|
||||||
|
|
||||||
|
for index, updated_html in enumerate(updated_completions):
|
||||||
|
await send_message("setCode", updated_html, index)
|
||||||
|
await send_message("status", "Code generation complete.", index)
|
||||||
|
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|||||||
@ -43,6 +43,7 @@
|
|||||||
"copy-to-clipboard": "^3.3.3",
|
"copy-to-clipboard": "^3.3.3",
|
||||||
"html2canvas": "^1.4.1",
|
"html2canvas": "^1.4.1",
|
||||||
"posthog-js": "^1.128.1",
|
"posthog-js": "^1.128.1",
|
||||||
|
"nanoid": "^5.0.7",
|
||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-dropzone": "^14.2.3",
|
"react-dropzone": "^14.2.3",
|
||||||
|
|||||||
@ -9,8 +9,7 @@ import { usePersistedState } from "./hooks/usePersistedState";
|
|||||||
import TermsOfServiceDialog from "./components/TermsOfServiceDialog";
|
import TermsOfServiceDialog from "./components/TermsOfServiceDialog";
|
||||||
import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants";
|
import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants";
|
||||||
import { addEvent } from "./lib/analytics";
|
import { addEvent } from "./lib/analytics";
|
||||||
import { History } from "./components/history/history_types";
|
import { extractHistory } from "./components/history/utils";
|
||||||
import { extractHistoryTree } from "./components/history/utils";
|
|
||||||
import toast from "react-hot-toast";
|
import toast from "react-hot-toast";
|
||||||
import { useAuth } from "@clerk/clerk-react";
|
import { useAuth } from "@clerk/clerk-react";
|
||||||
import { useStore } from "./store/store";
|
import { useStore } from "./store/store";
|
||||||
@ -27,6 +26,8 @@ import { GenerationSettings } from "./components/settings/GenerationSettings";
|
|||||||
import StartPane from "./components/start-pane/StartPane";
|
import StartPane from "./components/start-pane/StartPane";
|
||||||
import { takeScreenshot } from "./lib/takeScreenshot";
|
import { takeScreenshot } from "./lib/takeScreenshot";
|
||||||
import Sidebar from "./components/sidebar/Sidebar";
|
import Sidebar from "./components/sidebar/Sidebar";
|
||||||
|
import { Commit } from "./components/commits/types";
|
||||||
|
import { createCommit } from "./components/commits/utils";
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
navbarComponent?: JSX.Element;
|
navbarComponent?: JSX.Element;
|
||||||
@ -49,13 +50,19 @@ function App({ navbarComponent }: Props) {
|
|||||||
referenceImages,
|
referenceImages,
|
||||||
setReferenceImages,
|
setReferenceImages,
|
||||||
|
|
||||||
|
head,
|
||||||
|
commits,
|
||||||
|
addCommit,
|
||||||
|
removeCommit,
|
||||||
|
setHead,
|
||||||
|
appendCommitCode,
|
||||||
|
setCommitCode,
|
||||||
|
resetCommits,
|
||||||
|
resetHead,
|
||||||
|
|
||||||
// Outputs
|
// Outputs
|
||||||
setGeneratedCode,
|
appendExecutionConsole,
|
||||||
setExecutionConsole,
|
resetExecutionConsoles,
|
||||||
currentVersion,
|
|
||||||
setCurrentVersion,
|
|
||||||
appHistory,
|
|
||||||
setAppHistory,
|
|
||||||
} = useProjectStore();
|
} = useProjectStore();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -119,32 +126,31 @@ function App({ navbarComponent }: Props) {
|
|||||||
// Functions
|
// Functions
|
||||||
const reset = () => {
|
const reset = () => {
|
||||||
setAppState(AppState.INITIAL);
|
setAppState(AppState.INITIAL);
|
||||||
setGeneratedCode("");
|
|
||||||
setReferenceImages([]);
|
|
||||||
setInitialPrompt("");
|
|
||||||
setExecutionConsole([]);
|
|
||||||
setUpdateInstruction("");
|
|
||||||
setIsImportedFromCode(false);
|
|
||||||
setAppHistory([]);
|
|
||||||
setCurrentVersion(null);
|
|
||||||
setShouldIncludeResultImage(false);
|
setShouldIncludeResultImage(false);
|
||||||
|
setUpdateInstruction("");
|
||||||
disableInSelectAndEditMode();
|
disableInSelectAndEditMode();
|
||||||
|
resetExecutionConsoles();
|
||||||
|
|
||||||
|
resetCommits();
|
||||||
|
resetHead();
|
||||||
|
|
||||||
|
// Inputs
|
||||||
|
setInputMode("image");
|
||||||
|
setReferenceImages([]);
|
||||||
|
setIsImportedFromCode(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
const regenerate = () => {
|
const regenerate = () => {
|
||||||
if (currentVersion === null) {
|
if (head === null) {
|
||||||
// This would be a error that I log to Sentry
|
|
||||||
addEvent("RegenerateCurrentVersionNull");
|
|
||||||
toast.error(
|
toast.error(
|
||||||
"No current version set. Please open a Github issue as this shouldn't happen."
|
"No current version set. Please contact support via chat or Github."
|
||||||
);
|
);
|
||||||
return;
|
throw new Error("Regenerate called with no head");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve the previous command
|
// Retrieve the previous command
|
||||||
const previousCommand = appHistory[currentVersion];
|
const currentCommit = commits[head];
|
||||||
if (previousCommand.type !== "ai_create") {
|
if (currentCommit.type !== "ai_create") {
|
||||||
addEvent("RegenerateNotFirstVersion");
|
|
||||||
toast.error("Only the first version can be regenerated.");
|
toast.error("Only the first version can be regenerated.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -165,28 +171,32 @@ function App({ navbarComponent }: Props) {
|
|||||||
addEvent("Cancel");
|
addEvent("Cancel");
|
||||||
|
|
||||||
wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE);
|
wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE);
|
||||||
// make sure stop can correct the state even if the websocket is already closed
|
|
||||||
cancelCodeGenerationAndReset();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Used for code generation failure as well
|
// Used for code generation failure as well
|
||||||
const cancelCodeGenerationAndReset = () => {
|
const cancelCodeGenerationAndReset = (commit: Commit) => {
|
||||||
// When this is the first version, reset the entire app state
|
// When the current commit is the first version, reset the entire app state
|
||||||
if (currentVersion === null) {
|
if (commit.type === "ai_create") {
|
||||||
reset();
|
reset();
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, revert to the last version
|
// Otherwise, remove current commit from commits
|
||||||
setGeneratedCode(appHistory[currentVersion].code);
|
removeCommit(commit.hash);
|
||||||
|
|
||||||
|
// Revert to parent commit
|
||||||
|
const parentCommitHash = commit.parentHash;
|
||||||
|
if (parentCommitHash) {
|
||||||
|
setHead(parentCommitHash);
|
||||||
|
} else {
|
||||||
|
throw new Error("Parent commit not found");
|
||||||
|
}
|
||||||
|
|
||||||
setAppState(AppState.CODE_READY);
|
setAppState(AppState.CODE_READY);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
async function doGenerateCode(
|
async function doGenerateCode(params: CodeGenerationParams) {
|
||||||
params: CodeGenerationParams,
|
|
||||||
parentVersion: number | null
|
|
||||||
) {
|
|
||||||
// Reset the execution console
|
// Reset the execution console
|
||||||
setExecutionConsole([]);
|
resetExecutionConsoles();
|
||||||
|
|
||||||
// Set the app state
|
// Set the app state
|
||||||
setAppState(AppState.CODING);
|
setAppState(AppState.CODING);
|
||||||
@ -199,69 +209,50 @@ function App({ navbarComponent }: Props) {
|
|||||||
authToken: authToken || undefined,
|
authToken: authToken || undefined,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const baseCommitObject = {
|
||||||
|
variants: [{ code: "" }, { code: "" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
const commitInputObject =
|
||||||
|
params.generationType === "create"
|
||||||
|
? {
|
||||||
|
...baseCommitObject,
|
||||||
|
type: "ai_create" as const,
|
||||||
|
parentHash: null,
|
||||||
|
inputs: { image_url: referenceImages[0] },
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
...baseCommitObject,
|
||||||
|
type: "ai_edit" as const,
|
||||||
|
parentHash: head,
|
||||||
|
inputs: {
|
||||||
|
prompt: params.history
|
||||||
|
? params.history[params.history.length - 1]
|
||||||
|
: "",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create a new commit and set it as the head
|
||||||
|
const commit = createCommit(commitInputObject);
|
||||||
|
addCommit(commit);
|
||||||
|
setHead(commit.hash);
|
||||||
|
|
||||||
generateCode(
|
generateCode(
|
||||||
wsRef,
|
wsRef,
|
||||||
updatedParams,
|
updatedParams,
|
||||||
// On change
|
// On change
|
||||||
(token) => setGeneratedCode((prev) => prev + token),
|
(token, variantIndex) => {
|
||||||
|
appendCommitCode(commit.hash, variantIndex, token);
|
||||||
|
},
|
||||||
// On set code
|
// On set code
|
||||||
(code) => {
|
(code, variantIndex) => {
|
||||||
setGeneratedCode(code);
|
setCommitCode(commit.hash, variantIndex, code);
|
||||||
if (params.generationType === "create") {
|
|
||||||
if (inputMode === "image" || inputMode === "video") {
|
|
||||||
setAppHistory([
|
|
||||||
{
|
|
||||||
type: "ai_create",
|
|
||||||
parentIndex: null,
|
|
||||||
code,
|
|
||||||
inputs: { image_url: referenceImages[0] },
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
} else {
|
|
||||||
setAppHistory([
|
|
||||||
{
|
|
||||||
type: "ai_create",
|
|
||||||
parentIndex: null,
|
|
||||||
code,
|
|
||||||
inputs: { text: params.image },
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
setCurrentVersion(0);
|
|
||||||
} else {
|
|
||||||
setAppHistory((prev) => {
|
|
||||||
// Validate parent version
|
|
||||||
if (parentVersion === null) {
|
|
||||||
toast.error(
|
|
||||||
"No parent version set. Contact support or open a Github issue."
|
|
||||||
);
|
|
||||||
addEvent("ParentVersionNull");
|
|
||||||
return prev;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newHistory: History = [
|
|
||||||
...prev,
|
|
||||||
{
|
|
||||||
type: "ai_edit",
|
|
||||||
parentIndex: parentVersion,
|
|
||||||
code,
|
|
||||||
inputs: {
|
|
||||||
prompt: params.history
|
|
||||||
? params.history[params.history.length - 1]
|
|
||||||
: "", // History should never be empty when performing an edit
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
setCurrentVersion(newHistory.length - 1);
|
|
||||||
return newHistory;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
// On status update
|
// On status update
|
||||||
(line) => setExecutionConsole((prev) => [...prev, line]),
|
(line, variantIndex) => appendExecutionConsole(variantIndex, line),
|
||||||
// On cancel
|
// On cancel
|
||||||
() => {
|
() => {
|
||||||
cancelCodeGenerationAndReset();
|
cancelCodeGenerationAndReset(commit);
|
||||||
},
|
},
|
||||||
// On complete
|
// On complete
|
||||||
() => {
|
() => {
|
||||||
@ -285,14 +276,11 @@ function App({ navbarComponent }: Props) {
|
|||||||
// Kick off the code generation
|
// Kick off the code generation
|
||||||
if (referenceImages.length > 0) {
|
if (referenceImages.length > 0) {
|
||||||
addEvent("Create");
|
addEvent("Create");
|
||||||
await doGenerateCode(
|
doGenerateCode({
|
||||||
{
|
|
||||||
generationType: "create",
|
generationType: "create",
|
||||||
image: referenceImages[0],
|
image: referenceImages[0],
|
||||||
inputMode,
|
inputMode,
|
||||||
},
|
});
|
||||||
currentVersion
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -302,14 +290,11 @@ function App({ navbarComponent }: Props) {
|
|||||||
|
|
||||||
setInputMode("text");
|
setInputMode("text");
|
||||||
setInitialPrompt(text);
|
setInitialPrompt(text);
|
||||||
doGenerateCode(
|
doGenerateCode({
|
||||||
{
|
|
||||||
generationType: "create",
|
generationType: "create",
|
||||||
inputMode: "text",
|
inputMode: "text",
|
||||||
image: text,
|
image: text,
|
||||||
},
|
});
|
||||||
currentVersion
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subsequent updates
|
// Subsequent updates
|
||||||
@ -322,23 +307,22 @@ function App({ navbarComponent }: Props) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (currentVersion === null) {
|
if (head === null) {
|
||||||
toast.error(
|
toast.error(
|
||||||
"No current version set. Contact support or open a Github issue."
|
"No current version set. Contact support or open a Github issue."
|
||||||
);
|
);
|
||||||
addEvent("CurrentVersionNull");
|
throw new Error("Update called with no head");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let historyTree;
|
let historyTree;
|
||||||
try {
|
try {
|
||||||
historyTree = extractHistoryTree(appHistory, currentVersion);
|
historyTree = extractHistory(head, commits);
|
||||||
} catch {
|
} catch {
|
||||||
addEvent("HistoryTreeFailed");
|
addEvent("HistoryTreeFailed");
|
||||||
toast.error(
|
toast.error(
|
||||||
"Version history is invalid. This shouldn't happen. Please contact support or open a Github issue."
|
"Version history is invalid. This shouldn't happen. Please contact support or open a Github issue."
|
||||||
);
|
);
|
||||||
return;
|
throw new Error("Invalid version history");
|
||||||
}
|
}
|
||||||
|
|
||||||
let modifiedUpdateInstruction = updateInstruction;
|
let modifiedUpdateInstruction = updateInstruction;
|
||||||
@ -352,34 +336,19 @@ function App({ navbarComponent }: Props) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const updatedHistory = [...historyTree, modifiedUpdateInstruction];
|
const updatedHistory = [...historyTree, modifiedUpdateInstruction];
|
||||||
|
const resultImage = shouldIncludeResultImage
|
||||||
|
? await takeScreenshot()
|
||||||
|
: undefined;
|
||||||
|
|
||||||
if (shouldIncludeResultImage) {
|
doGenerateCode({
|
||||||
const resultImage = await takeScreenshot();
|
|
||||||
await doGenerateCode(
|
|
||||||
{
|
|
||||||
generationType: "update",
|
generationType: "update",
|
||||||
inputMode,
|
inputMode,
|
||||||
image: referenceImages[0],
|
image: referenceImages[0],
|
||||||
resultImage: resultImage,
|
resultImage,
|
||||||
history: updatedHistory,
|
history: updatedHistory,
|
||||||
isImportedFromCode,
|
isImportedFromCode,
|
||||||
},
|
});
|
||||||
currentVersion
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
await doGenerateCode(
|
|
||||||
{
|
|
||||||
generationType: "update",
|
|
||||||
inputMode,
|
|
||||||
image: inputMode === "text" ? initialPrompt : referenceImages[0],
|
|
||||||
history: updatedHistory,
|
|
||||||
isImportedFromCode,
|
|
||||||
},
|
|
||||||
currentVersion
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
setGeneratedCode("");
|
|
||||||
setUpdateInstruction("");
|
setUpdateInstruction("");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -402,17 +371,17 @@ function App({ navbarComponent }: Props) {
|
|||||||
setIsImportedFromCode(true);
|
setIsImportedFromCode(true);
|
||||||
|
|
||||||
// Set up this project
|
// Set up this project
|
||||||
setGeneratedCode(code);
|
|
||||||
setStack(stack);
|
setStack(stack);
|
||||||
setAppHistory([
|
|
||||||
{
|
// Create a new commit and set it as the head
|
||||||
|
const commit = createCommit({
|
||||||
type: "code_create",
|
type: "code_create",
|
||||||
parentIndex: null,
|
parentHash: null,
|
||||||
code,
|
variants: [{ code }],
|
||||||
inputs: { code },
|
inputs: null,
|
||||||
},
|
});
|
||||||
]);
|
addCommit(commit);
|
||||||
setCurrentVersion(0);
|
setHead(commit.hash);
|
||||||
|
|
||||||
// Set the app state
|
// Set the app state
|
||||||
setAppState(AppState.CODE_READY);
|
setAppState(AppState.CODE_READY);
|
||||||
|
|||||||
37
frontend/src/components/commits/types.ts
Normal file
37
frontend/src/components/commits/types.ts
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
export type CommitHash = string;
|
||||||
|
|
||||||
|
export type Variant = {
|
||||||
|
code: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type BaseCommit = {
|
||||||
|
hash: CommitHash;
|
||||||
|
parentHash: CommitHash | null;
|
||||||
|
dateCreated: Date;
|
||||||
|
isCommitted: boolean;
|
||||||
|
variants: Variant[];
|
||||||
|
selectedVariantIndex: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type CommitType = "ai_create" | "ai_edit" | "code_create";
|
||||||
|
|
||||||
|
export type AiCreateCommit = BaseCommit & {
|
||||||
|
type: "ai_create";
|
||||||
|
inputs: {
|
||||||
|
image_url: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export type AiEditCommit = BaseCommit & {
|
||||||
|
type: "ai_edit";
|
||||||
|
inputs: {
|
||||||
|
prompt: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export type CodeCreateCommit = BaseCommit & {
|
||||||
|
type: "code_create";
|
||||||
|
inputs: null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type Commit = AiCreateCommit | AiEditCommit | CodeCreateCommit;
|
||||||
32
frontend/src/components/commits/utils.ts
Normal file
32
frontend/src/components/commits/utils.ts
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import { nanoid } from "nanoid";
|
||||||
|
import {
|
||||||
|
AiCreateCommit,
|
||||||
|
AiEditCommit,
|
||||||
|
CodeCreateCommit,
|
||||||
|
Commit,
|
||||||
|
} from "./types";
|
||||||
|
|
||||||
|
export function createCommit(
|
||||||
|
commit:
|
||||||
|
| Omit<
|
||||||
|
AiCreateCommit,
|
||||||
|
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
|
||||||
|
>
|
||||||
|
| Omit<
|
||||||
|
AiEditCommit,
|
||||||
|
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
|
||||||
|
>
|
||||||
|
| Omit<
|
||||||
|
CodeCreateCommit,
|
||||||
|
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
|
||||||
|
>
|
||||||
|
): Commit {
|
||||||
|
const hash = nanoid();
|
||||||
|
return {
|
||||||
|
...commit,
|
||||||
|
hash,
|
||||||
|
isCommitted: false,
|
||||||
|
dateCreated: new Date(),
|
||||||
|
selectedVariantIndex: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
@ -17,19 +17,16 @@ interface Props {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
||||||
const {
|
const { commits, head, setHead } = useProjectStore();
|
||||||
appHistory: history,
|
|
||||||
currentVersion,
|
|
||||||
setCurrentVersion,
|
|
||||||
setGeneratedCode,
|
|
||||||
} = useProjectStore();
|
|
||||||
const renderedHistory = renderHistory(history, currentVersion);
|
|
||||||
|
|
||||||
const revertToVersion = (index: number) => {
|
// Put all commits into an array and sort by created date (oldest first)
|
||||||
if (index < 0 || index >= history.length || !history[index]) return;
|
const flatHistory = Object.values(commits).sort(
|
||||||
setCurrentVersion(index);
|
(a, b) =>
|
||||||
setGeneratedCode(history[index].code);
|
new Date(a.dateCreated).getTime() - new Date(b.dateCreated).getTime()
|
||||||
};
|
);
|
||||||
|
|
||||||
|
// Annotate history items with a summary, parent version, etc.
|
||||||
|
const renderedHistory = renderHistory(flatHistory);
|
||||||
|
|
||||||
return renderedHistory.length === 0 ? null : (
|
return renderedHistory.length === 0 ? null : (
|
||||||
<div className="flex flex-col h-screen">
|
<div className="flex flex-col h-screen">
|
||||||
@ -43,8 +40,8 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
|||||||
"flex items-center justify-between space-x-2 w-full pr-2",
|
"flex items-center justify-between space-x-2 w-full pr-2",
|
||||||
"border-b cursor-pointer",
|
"border-b cursor-pointer",
|
||||||
{
|
{
|
||||||
" hover:bg-black hover:text-white": !item.isActive,
|
" hover:bg-black hover:text-white": item.hash === head,
|
||||||
"bg-slate-500 text-white": item.isActive,
|
"bg-slate-500 text-white": item.hash === head,
|
||||||
}
|
}
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
@ -55,14 +52,14 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) {
|
|||||||
? toast.error(
|
? toast.error(
|
||||||
"Please wait for code generation to complete before viewing an older version."
|
"Please wait for code generation to complete before viewing an older version."
|
||||||
)
|
)
|
||||||
: revertToVersion(index)
|
: setHead(item.hash)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
<div className="flex gap-x-1 truncate">
|
<div className="flex gap-x-1 truncate">
|
||||||
<h2 className="text-sm truncate">{item.summary}</h2>
|
<h2 className="text-sm truncate">{item.summary}</h2>
|
||||||
{item.parentVersion !== null && (
|
{item.parentVersion !== null && (
|
||||||
<h2 className="text-sm">
|
<h2 className="text-sm">
|
||||||
(parent: {item.parentVersion})
|
(parent: v{item.parentVersion})
|
||||||
</h2>
|
</h2>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -1,45 +0,0 @@
|
|||||||
export type HistoryItemType = "ai_create" | "ai_edit" | "code_create";
|
|
||||||
|
|
||||||
type CommonHistoryItem = {
|
|
||||||
parentIndex: null | number;
|
|
||||||
code: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type HistoryItem =
|
|
||||||
| ({
|
|
||||||
type: "ai_create";
|
|
||||||
inputs: AiCreateInputs | AiCreateInputsText;
|
|
||||||
} & CommonHistoryItem)
|
|
||||||
| ({
|
|
||||||
type: "ai_edit";
|
|
||||||
inputs: AiEditInputs;
|
|
||||||
} & CommonHistoryItem)
|
|
||||||
| ({
|
|
||||||
type: "code_create";
|
|
||||||
inputs: CodeCreateInputs;
|
|
||||||
} & CommonHistoryItem);
|
|
||||||
|
|
||||||
export type AiCreateInputs = {
|
|
||||||
image_url: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type AiCreateInputsText = {
|
|
||||||
text: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type AiEditInputs = {
|
|
||||||
prompt: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type CodeCreateInputs = {
|
|
||||||
code: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type History = HistoryItem[];
|
|
||||||
|
|
||||||
export type RenderedHistoryItem = {
|
|
||||||
type: string;
|
|
||||||
summary: string;
|
|
||||||
parentVersion: string | null;
|
|
||||||
isActive: boolean;
|
|
||||||
};
|
|
||||||
@ -1,91 +1,125 @@
|
|||||||
import { extractHistoryTree, renderHistory } from "./utils";
|
import { extractHistory, renderHistory } from "./utils";
|
||||||
import type { History } from "./history_types";
|
import { Commit, CommitHash } from "../commits/types";
|
||||||
|
|
||||||
const basicLinearHistory: History = [
|
const basicLinearHistory: Record<CommitHash, Commit> = {
|
||||||
{
|
"0": {
|
||||||
|
hash: "0",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_create",
|
type: "ai_create",
|
||||||
parentIndex: null,
|
parentHash: null,
|
||||||
code: "<html>1. create</html>",
|
variants: [{ code: "<html>1. create</html>" }],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
image_url: "",
|
image_url: "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
"1": {
|
||||||
|
hash: "1",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_edit",
|
type: "ai_edit",
|
||||||
parentIndex: 0,
|
parentHash: "0",
|
||||||
code: "<html>2. edit with better icons</html>",
|
variants: [{ code: "<html>2. edit with better icons</html>" }],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
prompt: "use better icons",
|
prompt: "use better icons",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
"2": {
|
||||||
|
hash: "2",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_edit",
|
type: "ai_edit",
|
||||||
parentIndex: 1,
|
parentHash: "1",
|
||||||
code: "<html>3. edit with better icons and red text</html>",
|
variants: [{ code: "<html>3. edit with better icons and red text</html>" }],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
prompt: "make text red",
|
prompt: "make text red",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
];
|
};
|
||||||
|
|
||||||
const basicLinearHistoryWithCode: History = [
|
const basicLinearHistoryWithCode: Record<CommitHash, Commit> = {
|
||||||
{
|
"0": {
|
||||||
|
hash: "0",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "code_create",
|
type: "code_create",
|
||||||
parentIndex: null,
|
parentHash: null,
|
||||||
code: "<html>1. create</html>",
|
variants: [{ code: "<html>1. create</html>" }],
|
||||||
inputs: {
|
selectedVariantIndex: 0,
|
||||||
code: "<html>1. create</html>",
|
inputs: null,
|
||||||
},
|
},
|
||||||
},
|
...Object.fromEntries(Object.entries(basicLinearHistory).slice(1)),
|
||||||
...basicLinearHistory.slice(1),
|
};
|
||||||
];
|
|
||||||
|
|
||||||
const basicBranchingHistory: History = [
|
const basicBranchingHistory: Record<CommitHash, Commit> = {
|
||||||
...basicLinearHistory,
|
...basicLinearHistory,
|
||||||
{
|
"3": {
|
||||||
|
hash: "3",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_edit",
|
type: "ai_edit",
|
||||||
parentIndex: 1,
|
parentHash: "1",
|
||||||
code: "<html>4. edit with better icons and green text</html>",
|
variants: [
|
||||||
|
{ code: "<html>4. edit with better icons and green text</html>" },
|
||||||
|
],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
prompt: "make text green",
|
prompt: "make text green",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
];
|
};
|
||||||
|
|
||||||
const longerBranchingHistory: History = [
|
const longerBranchingHistory: Record<CommitHash, Commit> = {
|
||||||
...basicBranchingHistory,
|
...basicBranchingHistory,
|
||||||
{
|
"4": {
|
||||||
|
hash: "4",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_edit",
|
type: "ai_edit",
|
||||||
parentIndex: 3,
|
parentHash: "3",
|
||||||
code: "<html>5. edit with better icons and green, bold text</html>",
|
variants: [
|
||||||
|
{ code: "<html>5. edit with better icons and green, bold text</html>" },
|
||||||
|
],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
prompt: "make text bold",
|
prompt: "make text bold",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
];
|
};
|
||||||
|
|
||||||
const basicBadHistory: History = [
|
const basicBadHistory: Record<CommitHash, Commit> = {
|
||||||
{
|
"0": {
|
||||||
|
hash: "0",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_create",
|
type: "ai_create",
|
||||||
parentIndex: null,
|
parentHash: null,
|
||||||
code: "<html>1. create</html>",
|
variants: [{ code: "<html>1. create</html>" }],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
image_url: "",
|
image_url: "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
"1": {
|
||||||
|
hash: "1",
|
||||||
|
dateCreated: new Date(),
|
||||||
|
isCommitted: false,
|
||||||
type: "ai_edit",
|
type: "ai_edit",
|
||||||
parentIndex: 2, // <- Bad parent index
|
parentHash: "2", // <- Bad parent hash
|
||||||
code: "<html>2. edit with better icons</html>",
|
variants: [{ code: "<html>2. edit with better icons</html>" }],
|
||||||
|
selectedVariantIndex: 0,
|
||||||
inputs: {
|
inputs: {
|
||||||
prompt: "use better icons",
|
prompt: "use better icons",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
];
|
};
|
||||||
|
|
||||||
describe("History Utils", () => {
|
describe("History Utils", () => {
|
||||||
test("should correctly extract the history tree", () => {
|
test("should correctly extract the history tree", () => {
|
||||||
expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([
|
expect(extractHistory("2", basicLinearHistory)).toEqual([
|
||||||
"<html>1. create</html>",
|
"<html>1. create</html>",
|
||||||
"use better icons",
|
"use better icons",
|
||||||
"<html>2. edit with better icons</html>",
|
"<html>2. edit with better icons</html>",
|
||||||
@ -93,12 +127,12 @@ describe("History Utils", () => {
|
|||||||
"<html>3. edit with better icons and red text</html>",
|
"<html>3. edit with better icons and red text</html>",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([
|
expect(extractHistory("0", basicLinearHistory)).toEqual([
|
||||||
"<html>1. create</html>",
|
"<html>1. create</html>",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Test branching
|
// Test branching
|
||||||
expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([
|
expect(extractHistory("3", basicBranchingHistory)).toEqual([
|
||||||
"<html>1. create</html>",
|
"<html>1. create</html>",
|
||||||
"use better icons",
|
"use better icons",
|
||||||
"<html>2. edit with better icons</html>",
|
"<html>2. edit with better icons</html>",
|
||||||
@ -106,7 +140,7 @@ describe("History Utils", () => {
|
|||||||
"<html>4. edit with better icons and green text</html>",
|
"<html>4. edit with better icons and green text</html>",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([
|
expect(extractHistory("4", longerBranchingHistory)).toEqual([
|
||||||
"<html>1. create</html>",
|
"<html>1. create</html>",
|
||||||
"use better icons",
|
"use better icons",
|
||||||
"<html>2. edit with better icons</html>",
|
"<html>2. edit with better icons</html>",
|
||||||
@ -116,7 +150,7 @@ describe("History Utils", () => {
|
|||||||
"<html>5. edit with better icons and green, bold text</html>",
|
"<html>5. edit with better icons and green, bold text</html>",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([
|
expect(extractHistory("2", longerBranchingHistory)).toEqual([
|
||||||
"<html>1. create</html>",
|
"<html>1. create</html>",
|
||||||
"use better icons",
|
"use better icons",
|
||||||
"<html>2. edit with better icons</html>",
|
"<html>2. edit with better icons</html>",
|
||||||
@ -126,105 +160,82 @@ describe("History Utils", () => {
|
|||||||
|
|
||||||
// Errors
|
// Errors
|
||||||
|
|
||||||
// Bad index
|
// Bad hash
|
||||||
expect(() => extractHistoryTree(basicLinearHistory, 100)).toThrow();
|
expect(() => extractHistory("100", basicLinearHistory)).toThrow();
|
||||||
expect(() => extractHistoryTree(basicLinearHistory, -2)).toThrow();
|
|
||||||
|
|
||||||
// Bad tree
|
// Bad tree
|
||||||
expect(() => extractHistoryTree(basicBadHistory, 1)).toThrow();
|
expect(() => extractHistory("1", basicBadHistory)).toThrow();
|
||||||
});
|
});
|
||||||
|
|
||||||
test("should correctly render the history tree", () => {
|
test("should correctly render the history tree", () => {
|
||||||
expect(renderHistory(basicLinearHistory, 2)).toEqual([
|
expect(renderHistory(Object.values(basicLinearHistory))).toEqual([
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicLinearHistory["0"],
|
||||||
parentVersion: null,
|
|
||||||
summary: "Create",
|
|
||||||
type: "Create",
|
type: "Create",
|
||||||
},
|
|
||||||
{
|
|
||||||
isActive: false,
|
|
||||||
parentVersion: null,
|
|
||||||
summary: "use better icons",
|
|
||||||
type: "Edit",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
isActive: true,
|
|
||||||
parentVersion: null,
|
|
||||||
summary: "make text red",
|
|
||||||
type: "Edit",
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Current version is the first version
|
|
||||||
expect(renderHistory(basicLinearHistory, 0)).toEqual([
|
|
||||||
{
|
|
||||||
isActive: true,
|
|
||||||
parentVersion: null,
|
|
||||||
summary: "Create",
|
summary: "Create",
|
||||||
type: "Create",
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicLinearHistory["1"],
|
||||||
parentVersion: null,
|
type: "Edit",
|
||||||
summary: "use better icons",
|
summary: "use better icons",
|
||||||
type: "Edit",
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicLinearHistory["2"],
|
||||||
parentVersion: null,
|
|
||||||
summary: "make text red",
|
|
||||||
type: "Edit",
|
type: "Edit",
|
||||||
|
summary: "make text red",
|
||||||
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Render a history with code
|
// Render a history with code
|
||||||
expect(renderHistory(basicLinearHistoryWithCode, 0)).toEqual([
|
expect(renderHistory(Object.values(basicLinearHistoryWithCode))).toEqual([
|
||||||
{
|
{
|
||||||
isActive: true,
|
...basicLinearHistoryWithCode["0"],
|
||||||
parentVersion: null,
|
|
||||||
summary: "Imported from code",
|
|
||||||
type: "Imported from code",
|
type: "Imported from code",
|
||||||
|
summary: "Imported from code",
|
||||||
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicLinearHistoryWithCode["1"],
|
||||||
parentVersion: null,
|
type: "Edit",
|
||||||
summary: "use better icons",
|
summary: "use better icons",
|
||||||
type: "Edit",
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicLinearHistoryWithCode["2"],
|
||||||
parentVersion: null,
|
|
||||||
summary: "make text red",
|
|
||||||
type: "Edit",
|
type: "Edit",
|
||||||
|
summary: "make text red",
|
||||||
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Render a non-linear history
|
// Render a non-linear history
|
||||||
expect(renderHistory(basicBranchingHistory, 3)).toEqual([
|
expect(renderHistory(Object.values(basicBranchingHistory))).toEqual([
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicBranchingHistory["0"],
|
||||||
parentVersion: null,
|
|
||||||
summary: "Create",
|
|
||||||
type: "Create",
|
type: "Create",
|
||||||
|
summary: "Create",
|
||||||
|
parentVersion: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
isActive: false,
|
...basicBranchingHistory["1"],
|
||||||
parentVersion: null,
|
type: "Edit",
|
||||||
summary: "use better icons",
|
summary: "use better icons",
|
||||||
type: "Edit",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
isActive: false,
|
|
||||||
parentVersion: null,
|
parentVersion: null,
|
||||||
summary: "make text red",
|
|
||||||
type: "Edit",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
isActive: true,
|
...basicBranchingHistory["2"],
|
||||||
parentVersion: "v2",
|
|
||||||
summary: "make text green",
|
|
||||||
type: "Edit",
|
type: "Edit",
|
||||||
|
summary: "make text red",
|
||||||
|
parentVersion: null,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
...basicBranchingHistory["3"],
|
||||||
|
type: "Edit",
|
||||||
|
summary: "make text green",
|
||||||
|
parentVersion: 2,
|
||||||
},
|
},
|
||||||
]);
|
]);
|
||||||
});
|
});
|
||||||
|
|||||||
@ -1,33 +1,25 @@
|
|||||||
import {
|
import { Commit, CommitHash, CommitType } from "../commits/types";
|
||||||
History,
|
|
||||||
HistoryItem,
|
|
||||||
HistoryItemType,
|
|
||||||
RenderedHistoryItem,
|
|
||||||
} from "./history_types";
|
|
||||||
|
|
||||||
export function extractHistoryTree(
|
export function extractHistory(
|
||||||
history: History,
|
hash: CommitHash,
|
||||||
version: number
|
commits: Record<CommitHash, Commit>
|
||||||
): string[] {
|
): string[] {
|
||||||
const flatHistory: string[] = [];
|
const flatHistory: string[] = [];
|
||||||
|
|
||||||
let currentIndex: number | null = version;
|
let currentCommitHash: CommitHash | null = hash;
|
||||||
while (currentIndex !== null) {
|
while (currentCommitHash !== null) {
|
||||||
const item: HistoryItem = history[currentIndex];
|
const commit: Commit | null = commits[currentCommitHash];
|
||||||
|
|
||||||
if (item) {
|
if (commit) {
|
||||||
if (item.type === "ai_create") {
|
flatHistory.unshift(commit.variants[commit.selectedVariantIndex].code);
|
||||||
// Don't include the image for ai_create
|
|
||||||
flatHistory.unshift(item.code);
|
// For edits, add the prompt to the history
|
||||||
} else if (item.type === "ai_edit") {
|
if (commit.type === "ai_edit") {
|
||||||
flatHistory.unshift(item.code);
|
flatHistory.unshift(commit.inputs.prompt);
|
||||||
flatHistory.unshift(item.inputs.prompt);
|
|
||||||
} else if (item.type === "code_create") {
|
|
||||||
flatHistory.unshift(item.code);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move to the parent of the current item
|
// Move to the parent of the current item
|
||||||
currentIndex = item.parentIndex;
|
currentCommitHash = commit.parentHash;
|
||||||
} else {
|
} else {
|
||||||
throw new Error("Malformed history: missing parent index");
|
throw new Error("Malformed history: missing parent index");
|
||||||
}
|
}
|
||||||
@ -36,7 +28,7 @@ export function extractHistoryTree(
|
|||||||
return flatHistory;
|
return flatHistory;
|
||||||
}
|
}
|
||||||
|
|
||||||
function displayHistoryItemType(itemType: HistoryItemType) {
|
function displayHistoryItemType(itemType: CommitType) {
|
||||||
switch (itemType) {
|
switch (itemType) {
|
||||||
case "ai_create":
|
case "ai_create":
|
||||||
return "Create";
|
return "Create";
|
||||||
@ -51,44 +43,48 @@ function displayHistoryItemType(itemType: HistoryItemType) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function summarizeHistoryItem(item: HistoryItem) {
|
const setParentVersion = (commit: Commit, history: Commit[]) => {
|
||||||
const itemType = item.type;
|
// If the commit has no parent, return null
|
||||||
switch (itemType) {
|
if (!commit.parentHash) return null;
|
||||||
|
|
||||||
|
const parentIndex = history.findIndex(
|
||||||
|
(item) => item.hash === commit.parentHash
|
||||||
|
);
|
||||||
|
const currentIndex = history.findIndex((item) => item.hash === commit.hash);
|
||||||
|
|
||||||
|
// Only set parent version if the parent is not the previous commit
|
||||||
|
// and parent exists
|
||||||
|
return parentIndex !== -1 && parentIndex != currentIndex - 1
|
||||||
|
? parentIndex + 1
|
||||||
|
: null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function summarizeHistoryItem(commit: Commit) {
|
||||||
|
const commitType = commit.type;
|
||||||
|
switch (commitType) {
|
||||||
case "ai_create":
|
case "ai_create":
|
||||||
return "Create";
|
return "Create";
|
||||||
case "ai_edit":
|
case "ai_edit":
|
||||||
return item.inputs.prompt;
|
return commit.inputs.prompt;
|
||||||
case "code_create":
|
case "code_create":
|
||||||
return "Imported from code";
|
return "Imported from code";
|
||||||
default: {
|
default: {
|
||||||
const exhaustiveCheck: never = itemType;
|
const exhaustiveCheck: never = commitType;
|
||||||
throw new Error(`Unhandled case: ${exhaustiveCheck}`);
|
throw new Error(`Unhandled case: ${exhaustiveCheck}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const renderHistory = (
|
export const renderHistory = (history: Commit[]) => {
|
||||||
history: History,
|
const renderedHistory = [];
|
||||||
currentVersion: number | null
|
|
||||||
) => {
|
|
||||||
const renderedHistory: RenderedHistoryItem[] = [];
|
|
||||||
|
|
||||||
for (let i = 0; i < history.length; i++) {
|
for (let i = 0; i < history.length; i++) {
|
||||||
const item = history[i];
|
const commit = history[i];
|
||||||
// Only show the parent version if it's not the previous version
|
|
||||||
// (i.e. it's the branching point) and if it's not the first version
|
|
||||||
const parentVersion =
|
|
||||||
item.parentIndex !== null && item.parentIndex !== i - 1
|
|
||||||
? `v${(item.parentIndex || 0) + 1}`
|
|
||||||
: null;
|
|
||||||
const type = displayHistoryItemType(item.type);
|
|
||||||
const isActive = i === currentVersion;
|
|
||||||
const summary = summarizeHistoryItem(item);
|
|
||||||
renderedHistory.push({
|
renderedHistory.push({
|
||||||
isActive,
|
...commit,
|
||||||
summary: summary,
|
type: displayHistoryItemType(commit.type),
|
||||||
parentVersion,
|
summary: summarizeHistoryItem(commit),
|
||||||
type,
|
parentVersion: setParentVersion(commit, history),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -23,12 +23,17 @@ interface Props {
|
|||||||
|
|
||||||
function PreviewPane({ doUpdate, reset, settings }: Props) {
|
function PreviewPane({ doUpdate, reset, settings }: Props) {
|
||||||
const { appState } = useAppStore();
|
const { appState } = useAppStore();
|
||||||
const { inputMode, generatedCode, setGeneratedCode } = useProjectStore();
|
const { inputMode, head, commits } = useProjectStore();
|
||||||
|
|
||||||
|
const currentCommit = head && commits[head] ? commits[head] : "";
|
||||||
|
const currentCode = currentCommit
|
||||||
|
? currentCommit.variants[currentCommit.selectedVariantIndex].code
|
||||||
|
: "";
|
||||||
|
|
||||||
const previewCode =
|
const previewCode =
|
||||||
inputMode === "video" && appState === AppState.CODING
|
inputMode === "video" && appState === AppState.CODING
|
||||||
? extractHtml(generatedCode)
|
? extractHtml(currentCode)
|
||||||
: generatedCode;
|
: currentCode;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="ml-4">
|
<div className="ml-4">
|
||||||
@ -45,7 +50,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) {
|
|||||||
Reset
|
Reset
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => downloadCode(generatedCode)}
|
onClick={() => downloadCode(previewCode)}
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
className="flex items-center gap-x-2 mr-4 dark:text-white dark:bg-gray-700 download-btn"
|
className="flex items-center gap-x-2 mr-4 dark:text-white dark:bg-gray-700 download-btn"
|
||||||
>
|
>
|
||||||
@ -84,11 +89,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) {
|
|||||||
/>
|
/>
|
||||||
</TabsContent>
|
</TabsContent>
|
||||||
<TabsContent value="code">
|
<TabsContent value="code">
|
||||||
<CodeTab
|
<CodeTab code={previewCode} setCode={() => {}} settings={settings} />
|
||||||
code={previewCode}
|
|
||||||
setCode={setGeneratedCode}
|
|
||||||
settings={settings}
|
|
||||||
/>
|
|
||||||
</TabsContent>
|
</TabsContent>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import { Button } from "../ui/button";
|
|||||||
import { Textarea } from "../ui/textarea";
|
import { Textarea } from "../ui/textarea";
|
||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import HistoryDisplay from "../history/HistoryDisplay";
|
import HistoryDisplay from "../history/HistoryDisplay";
|
||||||
|
import Variants from "../variants/Variants";
|
||||||
|
|
||||||
interface SidebarProps {
|
interface SidebarProps {
|
||||||
showSelectAndEditFeature: boolean;
|
showSelectAndEditFeature: boolean;
|
||||||
@ -35,9 +36,18 @@ function Sidebar({
|
|||||||
shouldIncludeResultImage,
|
shouldIncludeResultImage,
|
||||||
setShouldIncludeResultImage,
|
setShouldIncludeResultImage,
|
||||||
} = useAppStore();
|
} = useAppStore();
|
||||||
const { inputMode, generatedCode, referenceImages, executionConsole } =
|
|
||||||
|
const { inputMode, referenceImages, executionConsoles, head, commits } =
|
||||||
useProjectStore();
|
useProjectStore();
|
||||||
|
|
||||||
|
const viewedCode =
|
||||||
|
head && commits[head]
|
||||||
|
? commits[head].variants[commits[head].selectedVariantIndex].code
|
||||||
|
: "";
|
||||||
|
|
||||||
|
const executionConsole =
|
||||||
|
(head && executionConsoles[commits[head].selectedVariantIndex]) || [];
|
||||||
|
|
||||||
// When coding is complete, focus on the update instruction textarea
|
// When coding is complete, focus on the update instruction textarea
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (appState === AppState.CODE_READY && textareaRef.current) {
|
if (appState === AppState.CODE_READY && textareaRef.current) {
|
||||||
@ -47,6 +57,8 @@ function Sidebar({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
<Variants />
|
||||||
|
|
||||||
{/* Show code preview only when coding */}
|
{/* Show code preview only when coding */}
|
||||||
{appState === AppState.CODING && (
|
{appState === AppState.CODING && (
|
||||||
<div className="flex flex-col">
|
<div className="flex flex-col">
|
||||||
@ -66,7 +78,7 @@ function Sidebar({
|
|||||||
{executionConsole.slice(-1)[0]}
|
{executionConsole.slice(-1)[0]}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<CodePreview code={generatedCode} />
|
<CodePreview code={viewedCode} />
|
||||||
|
|
||||||
<div className="flex w-full">
|
<div className="flex w-full">
|
||||||
<Button
|
<Button
|
||||||
@ -158,15 +170,22 @@ function Sidebar({
|
|||||||
)}
|
)}
|
||||||
<div className="bg-gray-400 px-4 py-2 rounded text-sm hidden">
|
<div className="bg-gray-400 px-4 py-2 rounded text-sm hidden">
|
||||||
<h2 className="text-lg mb-4 border-b border-gray-800">Console</h2>
|
<h2 className="text-lg mb-4 border-b border-gray-800">Console</h2>
|
||||||
{executionConsole.map((line, index) => (
|
{Object.entries(executionConsoles).map(([index, lines]) => (
|
||||||
|
<div key={index}>
|
||||||
|
{lines.map((line, lineIndex) => (
|
||||||
<div
|
<div
|
||||||
key={index}
|
key={`${index}-${lineIndex}`}
|
||||||
className="border-b border-gray-400 mb-2 text-gray-600 font-mono"
|
className="border-b border-gray-400 mb-2 text-gray-600 font-mono"
|
||||||
>
|
>
|
||||||
|
<span className="font-bold mr-2">{`${index}:${
|
||||||
|
lineIndex + 1
|
||||||
|
}`}</span>
|
||||||
{line}
|
{line}
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<HistoryDisplay shouldDisableReverts={appState === AppState.CODING} />
|
<HistoryDisplay shouldDisableReverts={appState === AppState.CODING} />
|
||||||
|
|||||||
42
frontend/src/components/variants/Variants.tsx
Normal file
42
frontend/src/components/variants/Variants.tsx
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import { useProjectStore } from "../../store/project-store";
|
||||||
|
|
||||||
|
function Variants() {
|
||||||
|
const { inputMode, head, commits, updateSelectedVariantIndex } =
|
||||||
|
useProjectStore();
|
||||||
|
|
||||||
|
// If there is no head, don't show the variants
|
||||||
|
if (head === null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const commit = commits[head];
|
||||||
|
const variants = commit.variants;
|
||||||
|
const selectedVariantIndex = commit.selectedVariantIndex;
|
||||||
|
|
||||||
|
// If there is only one variant or the commit is already committed, don't show the variants
|
||||||
|
if (variants.length <= 1 || commit.isCommitted || inputMode === "video") {
|
||||||
|
return <div className="mt-2"></div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mt-4 mb-4">
|
||||||
|
<div className="grid grid-cols-2 gap-2">
|
||||||
|
{variants.map((_, index) => (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className={`p-2 border rounded-md cursor-pointer ${
|
||||||
|
index === selectedVariantIndex
|
||||||
|
? "bg-blue-100 dark:bg-blue-900"
|
||||||
|
: "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||||
|
}`}
|
||||||
|
onClick={() => updateSelectedVariantIndex(head, index)}
|
||||||
|
>
|
||||||
|
<h3 className="font-medium mb-1">Option {index + 1}</h3>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Variants;
|
||||||
@ -11,12 +11,18 @@ const ERROR_MESSAGE =
|
|||||||
|
|
||||||
const CANCEL_MESSAGE = "Code generation cancelled";
|
const CANCEL_MESSAGE = "Code generation cancelled";
|
||||||
|
|
||||||
|
type WebSocketResponse = {
|
||||||
|
type: "chunk" | "status" | "setCode" | "error";
|
||||||
|
value: string;
|
||||||
|
variantIndex: number;
|
||||||
|
};
|
||||||
|
|
||||||
export function generateCode(
|
export function generateCode(
|
||||||
wsRef: React.MutableRefObject<WebSocket | null>,
|
wsRef: React.MutableRefObject<WebSocket | null>,
|
||||||
params: FullGenerationSettings,
|
params: FullGenerationSettings,
|
||||||
onChange: (chunk: string) => void,
|
onChange: (chunk: string, variantIndex: number) => void,
|
||||||
onSetCode: (code: string) => void,
|
onSetCode: (code: string, variantIndex: number) => void,
|
||||||
onStatusUpdate: (status: string) => void,
|
onStatusUpdate: (status: string, variantIndex: number) => void,
|
||||||
onCancel: () => void,
|
onCancel: () => void,
|
||||||
onComplete: () => void
|
onComplete: () => void
|
||||||
) {
|
) {
|
||||||
@ -31,13 +37,13 @@ export function generateCode(
|
|||||||
});
|
});
|
||||||
|
|
||||||
ws.addEventListener("message", async (event: MessageEvent) => {
|
ws.addEventListener("message", async (event: MessageEvent) => {
|
||||||
const response = JSON.parse(event.data);
|
const response = JSON.parse(event.data) as WebSocketResponse;
|
||||||
if (response.type === "chunk") {
|
if (response.type === "chunk") {
|
||||||
onChange(response.value);
|
onChange(response.value, response.variantIndex);
|
||||||
} else if (response.type === "status") {
|
} else if (response.type === "status") {
|
||||||
onStatusUpdate(response.value);
|
onStatusUpdate(response.value, response.variantIndex);
|
||||||
} else if (response.type === "setCode") {
|
} else if (response.type === "setCode") {
|
||||||
onSetCode(response.value);
|
onSetCode(response.value, response.variantIndex);
|
||||||
} else if (response.type === "error") {
|
} else if (response.type === "error") {
|
||||||
console.error("Error generating code", response.value);
|
console.error("Error generating code", response.value);
|
||||||
toast.error(response.value);
|
toast.error(response.value);
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { create } from "zustand";
|
import { create } from "zustand";
|
||||||
import { History } from "../components/history/history_types";
|
import { Commit, CommitHash } from "../components/commits/types";
|
||||||
|
|
||||||
// Store for app-wide state
|
// Store for app-wide state
|
||||||
interface ProjectStore {
|
interface ProjectStore {
|
||||||
@ -11,25 +11,28 @@ interface ProjectStore {
|
|||||||
referenceImages: string[];
|
referenceImages: string[];
|
||||||
setReferenceImages: (images: string[]) => void;
|
setReferenceImages: (images: string[]) => void;
|
||||||
|
|
||||||
// Outputs and other state
|
// Outputs
|
||||||
generatedCode: string;
|
commits: Record<string, Commit>;
|
||||||
setGeneratedCode: (
|
head: CommitHash | null;
|
||||||
updater: string | ((currentCode: string) => string)
|
|
||||||
) => void;
|
|
||||||
executionConsole: string[];
|
|
||||||
setExecutionConsole: (
|
|
||||||
updater: string[] | ((currentConsole: string[]) => string[])
|
|
||||||
) => void;
|
|
||||||
|
|
||||||
// Tracks the currently shown version from app history
|
addCommit: (commit: Commit) => void;
|
||||||
// TODO: might want to move to appStore
|
removeCommit: (hash: CommitHash) => void;
|
||||||
currentVersion: number | null;
|
resetCommits: () => void;
|
||||||
setCurrentVersion: (version: number | null) => void;
|
|
||||||
|
|
||||||
appHistory: History;
|
appendCommitCode: (
|
||||||
setAppHistory: (
|
hash: CommitHash,
|
||||||
updater: History | ((currentHistory: History) => History)
|
numVariant: number,
|
||||||
|
code: string
|
||||||
) => void;
|
) => void;
|
||||||
|
setCommitCode: (hash: CommitHash, numVariant: number, code: string) => void;
|
||||||
|
updateSelectedVariantIndex: (hash: CommitHash, index: number) => void;
|
||||||
|
|
||||||
|
setHead: (hash: CommitHash) => void;
|
||||||
|
resetHead: () => void;
|
||||||
|
|
||||||
|
executionConsoles: { [key: number]: string[] };
|
||||||
|
appendExecutionConsole: (variantIndex: number, line: string) => void;
|
||||||
|
resetExecutionConsoles: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useProjectStore = create<ProjectStore>((set) => ({
|
export const useProjectStore = create<ProjectStore>((set) => ({
|
||||||
@ -41,28 +44,106 @@ export const useProjectStore = create<ProjectStore>((set) => ({
|
|||||||
referenceImages: [],
|
referenceImages: [],
|
||||||
setReferenceImages: (images) => set({ referenceImages: images }),
|
setReferenceImages: (images) => set({ referenceImages: images }),
|
||||||
|
|
||||||
// Outputs and other state
|
// Outputs
|
||||||
generatedCode: "",
|
commits: {},
|
||||||
setGeneratedCode: (updater) =>
|
head: null,
|
||||||
set((state) => ({
|
|
||||||
generatedCode:
|
|
||||||
typeof updater === "function" ? updater(state.generatedCode) : updater,
|
|
||||||
})),
|
|
||||||
executionConsole: [],
|
|
||||||
setExecutionConsole: (updater) =>
|
|
||||||
set((state) => ({
|
|
||||||
executionConsole:
|
|
||||||
typeof updater === "function"
|
|
||||||
? updater(state.executionConsole)
|
|
||||||
: updater,
|
|
||||||
})),
|
|
||||||
|
|
||||||
currentVersion: null,
|
addCommit: (commit: Commit) => {
|
||||||
setCurrentVersion: (version) => set({ currentVersion: version }),
|
// When adding a new commit, make sure all existing commits are marked as committed
|
||||||
appHistory: [],
|
|
||||||
setAppHistory: (updater) =>
|
|
||||||
set((state) => ({
|
set((state) => ({
|
||||||
appHistory:
|
commits: {
|
||||||
typeof updater === "function" ? updater(state.appHistory) : updater,
|
...Object.fromEntries(
|
||||||
})),
|
Object.entries(state.commits).map(([hash, existingCommit]) => [
|
||||||
|
hash,
|
||||||
|
{ ...existingCommit, isCommitted: true },
|
||||||
|
])
|
||||||
|
),
|
||||||
|
[commit.hash]: commit,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
removeCommit: (hash: CommitHash) => {
|
||||||
|
set((state) => {
|
||||||
|
const newCommits = { ...state.commits };
|
||||||
|
delete newCommits[hash];
|
||||||
|
return { commits: newCommits };
|
||||||
|
});
|
||||||
|
},
|
||||||
|
resetCommits: () => set({ commits: {} }),
|
||||||
|
|
||||||
|
appendCommitCode: (hash: CommitHash, numVariant: number, code: string) =>
|
||||||
|
set((state) => {
|
||||||
|
const commit = state.commits[hash];
|
||||||
|
// Don't update if the commit is already committed
|
||||||
|
if (commit.isCommitted) {
|
||||||
|
throw new Error("Attempted to append code to a committed commit");
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
commits: {
|
||||||
|
...state.commits,
|
||||||
|
[hash]: {
|
||||||
|
...commit,
|
||||||
|
variants: commit.variants.map((variant, index) =>
|
||||||
|
index === numVariant
|
||||||
|
? { ...variant, code: variant.code + code }
|
||||||
|
: variant
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
setCommitCode: (hash: CommitHash, numVariant: number, code: string) =>
|
||||||
|
set((state) => {
|
||||||
|
const commit = state.commits[hash];
|
||||||
|
// Don't update if the commit is already committed
|
||||||
|
if (commit.isCommitted) {
|
||||||
|
throw new Error("Attempted to set code of a committed commit");
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
commits: {
|
||||||
|
...state.commits,
|
||||||
|
[hash]: {
|
||||||
|
...commit,
|
||||||
|
variants: commit.variants.map((variant, index) =>
|
||||||
|
index === numVariant ? { ...variant, code } : variant
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
updateSelectedVariantIndex: (hash: CommitHash, index: number) =>
|
||||||
|
set((state) => {
|
||||||
|
const commit = state.commits[hash];
|
||||||
|
// Don't update if the commit is already committed
|
||||||
|
if (commit.isCommitted) {
|
||||||
|
throw new Error(
|
||||||
|
"Attempted to update selected variant index of a committed commit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
commits: {
|
||||||
|
...state.commits,
|
||||||
|
[hash]: {
|
||||||
|
...commit,
|
||||||
|
selectedVariantIndex: index,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
|
||||||
|
setHead: (hash: CommitHash) => set({ head: hash }),
|
||||||
|
resetHead: () => set({ head: null }),
|
||||||
|
|
||||||
|
executionConsoles: {},
|
||||||
|
appendExecutionConsole: (variantIndex: number, line: string) =>
|
||||||
|
set((state) => ({
|
||||||
|
executionConsoles: {
|
||||||
|
...state.executionConsoles,
|
||||||
|
[variantIndex]: [
|
||||||
|
...(state.executionConsoles[variantIndex] || []),
|
||||||
|
line,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
resetExecutionConsoles: () => set({ executionConsoles: {} }),
|
||||||
}));
|
}));
|
||||||
|
|||||||
@ -4582,6 +4582,11 @@ nanoid@^3.3.6, nanoid@^3.3.7:
|
|||||||
resolved "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz"
|
resolved "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz"
|
||||||
integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==
|
integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==
|
||||||
|
|
||||||
|
nanoid@^5.0.7:
|
||||||
|
version "5.0.7"
|
||||||
|
resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-5.0.7.tgz#6452e8c5a816861fd9d2b898399f7e5fd6944cc6"
|
||||||
|
integrity sha512-oLxFY2gd2IqnjcYyOXD8XGCftpGtZP2AbHbOkthDkvRywH5ayNtPVy9YlOPcHckXzbLTCHpkb7FB+yuxKV13pQ==
|
||||||
|
|
||||||
natural-compare@^1.4.0:
|
natural-compare@^1.4.0:
|
||||||
version "1.4.0"
|
version "1.4.0"
|
||||||
resolved "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz"
|
resolved "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user