This commit is contained in:
Abi Raja 2024-09-08 16:50:18 +02:00
commit c107f4eda5
25 changed files with 1037 additions and 784 deletions

View File

@ -33,10 +33,6 @@ We also just added experimental support for taking a video/screen recording of a
[Follow me on Twitter for updates](https://twitter.com/_abi_).
## Sponsors
<a href="https://konghq.com/products/kong-konnect?utm_medium=referral&utm_source=github&utm_campaign=platform&utm_content=screenshot-to-code" target="_blank" title="Kong - powering the API world"><img src="https://picoapps.xyz/s2c-sponsors/Kong-GitHub-240x100.png"></a>
## 🚀 Hosted Version
[Try it live on the hosted version (paid)](https://screenshottocode.com).

View File

@ -7,15 +7,15 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: local
hooks:
- id: poetry-pytest
name: Run pytest with Poetry
entry: poetry run --directory backend pytest
language: system
pass_filenames: false
always_run: true
files: ^backend/
# - repo: local
# hooks:
# - id: poetry-pytest
# name: Run pytest with Poetry
# entry: poetry run --directory backend pytest
# language: system
# pass_filenames: false
# always_run: true
# files: ^backend/
# - id: poetry-pyright
# name: Run pyright with Poetry
# entry: poetry run --directory backend pyright

View File

@ -3,8 +3,12 @@
# TODO: Should only be set to true when value is 'True', not any arbitrary truthy value
import os
NUM_VARIANTS = 2
# LLM-related
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
# Image generation (optional)
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)

View File

View File

@ -0,0 +1,23 @@
from datetime import datetime
import json
import os
from openai.types.chat import ChatCompletionMessageParam
def write_logs(prompt_messages: list[ChatCompletionMessageParam], completion: str):
    """Persist one generation run (prompt + completion) as a timestamped JSON file.

    The file lands in ``$LOGS_PATH/run_logs`` (LOGS_PATH defaults to the
    current working directory); the directory is created on demand.

    Args:
        prompt_messages: The full message list sent to the LLM.
        completion: The raw completion text returned by the LLM.
    """
    # Get the logs path from environment, default to the current working directory
    logs_path = os.environ.get("LOGS_PATH", os.getcwd())

    # Create run_logs directory if it doesn't exist within the specified logs path.
    # exist_ok=True avoids the check-then-create race of the previous
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    logs_directory = os.path.join(logs_path, "run_logs")
    os.makedirs(logs_directory, exist_ok=True)

    print("Writing to logs directory:", logs_directory)

    # Generate a unique filename using the current timestamp; join paths with
    # os.path.join rather than splicing the directory into the strftime format.
    filename = os.path.join(
        logs_directory, datetime.now().strftime("messages_%Y%m%d_%H%M%S.json")
    )

    # Write the messages dict into a new file for each run
    with open(filename, "w") as f:
        f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))

View File

@ -28,7 +28,7 @@ async def process_tasks(
processed_results: List[Union[str, None]] = []
for result in results:
if isinstance(result, Exception):
if isinstance(result, BaseException):
print(f"An exception occurred: {result}")
try:
raise result

View File

@ -1,3 +1,4 @@
import copy
from enum import Enum
from typing import Any, Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic
@ -112,8 +113,12 @@ async def stream_claude_response(
temperature = 0.0
# Translate OpenAI messages to Claude messages
system_prompt = cast(str, messages[0].get("content"))
claude_messages = [dict(message) for message in messages[1:]]
# Deep copy messages to avoid modifying the original list
cloned_messages = copy.deepcopy(messages)
system_prompt = cast(str, cloned_messages[0].get("content"))
claude_messages = [dict(message) for message in cloned_messages[1:]]
for message in claude_messages:
if not isinstance(message["content"], list):
continue

View File

@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
async def mock_completion(
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
) -> str:
code_to_return = (
TALLY_FORM_VIDEO_PROMPT_MOCK
@ -17,7 +17,7 @@ async def mock_completion(
)
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
await asyncio.sleep(0.01)
if input_mode == "video":

View File

@ -1,12 +1,13 @@
from typing import List, NoReturn, Union
from typing import Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
from llm import Llm
from custom_types import InputMode
from image_generation.core import create_alt_url_mapping
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
from prompts.text_prompts import SYSTEM_PROMPTS as TEXT_SYSTEM_PROMPTS
from prompts.types import Stack
from video.utils import assemble_claude_prompt_video
USER_PROMPT = """
@ -18,9 +19,65 @@ Generate code for a SVG that looks exactly like this.
"""
async def create_prompt(
    params: dict[str, str], stack: Stack, input_mode: InputMode
) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
    """Assemble the LLM message list for a code-generation request.

    Returns the prompt messages plus an image cache (alt text -> URL) that
    updates reuse so already-generated images are not regenerated.

    NOTE(review): params is annotated dict[str, str] but "history" is indexed
    like a list — presumably the real payload is richer; confirm against caller.
    """
    image_cache: dict[str, str] = {}

    # If this generation started off with imported code, we need to assemble the prompt differently
    if params.get("isImportedFromCode"):
        original_imported_code = params["history"][0]
        prompt_messages = assemble_imported_code_prompt(original_imported_code, stack)
        # After the imported code, history alternates turns starting with the user.
        for index, text in enumerate(params["history"][1:]):
            if index % 2 == 0:
                message: ChatCompletionMessageParam = {
                    "role": "user",
                    "content": text,
                }
            else:
                message: ChatCompletionMessageParam = {
                    "role": "assistant",
                    "content": text,
                }
            prompt_messages.append(message)
    else:
        # Assemble the prompt for non-imported code
        if params.get("resultImage"):
            prompt_messages = assemble_prompt(
                params["image"], stack, params["resultImage"]
            )
        else:
            prompt_messages = assemble_prompt(params["image"], stack)

        if params["generationType"] == "update":
            # Transform the history tree into message format
            # TODO: Move this to frontend
            # Here the alternation starts with the assistant (the first
            # generated code) — the opposite parity of the imported-code branch.
            for index, text in enumerate(params["history"]):
                if index % 2 == 0:
                    message: ChatCompletionMessageParam = {
                        "role": "assistant",
                        "content": text,
                    }
                else:
                    message: ChatCompletionMessageParam = {
                        "role": "user",
                        "content": text,
                    }
                prompt_messages.append(message)

            # Seed the image cache from the previous completion so updates can
            # reuse the images it already generated.
            image_cache = create_alt_url_mapping(params["history"][-2])

    if input_mode == "video":
        # Video mode replaces the assembled messages entirely with a
        # Claude-specific video prompt built from the data URL.
        video_data_url = params["image"]
        prompt_messages = await assemble_claude_prompt_video(video_data_url)

    return prompt_messages, image_cache
def assemble_imported_code_prompt(
code: str, stack: Stack, model: Llm
) -> List[ChatCompletionMessageParam]:
code: str, stack: Stack
) -> list[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
user_content = (
@ -29,24 +86,12 @@ def assemble_imported_code_prompt(
else "Here is the code of the SVG: " + code
)
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
return [
{
"role": "system",
"content": system_content + "\n " + user_content,
}
]
else:
return [
{
"role": "system",
"content": system_content,
},
{
"role": "user",
"content": user_content,
},
]
# TODO: Use result_image_data_url
@ -54,11 +99,11 @@ def assemble_prompt(
image_data_url: str,
stack: Stack,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
) -> list[ChatCompletionMessageParam]:
system_content = SYSTEM_PROMPTS[stack]
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
user_content: List[ChatCompletionContentPartParam] = [
user_content: list[ChatCompletionContentPartParam] = [
{
"type": "image_url",
"image_url": {"url": image_data_url, "detail": "high"},
@ -93,7 +138,7 @@ def assemble_prompt(
def assemble_text_prompt(
text_prompt: str,
stack: Stack,
) -> List[ChatCompletionMessageParam]:
) -> list[ChatCompletionMessageParam]:
system_content = TEXT_SYSTEM_PROMPTS[stack]

View File

@ -391,63 +391,81 @@ def test_prompts():
def test_imported_code_prompts():
tailwind_prompt = assemble_imported_code_prompt(
"code", "html_tailwind", Llm.GPT_4O_2024_05_13
)
code = "Sample code"
tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind")
expected_tailwind_prompt = [
{"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
{
"role": "system",
"content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
+ "\n Here is the code of the app: "
+ code,
}
]
assert tailwind_prompt == expected_tailwind_prompt
html_css_prompt = assemble_imported_code_prompt(
"code", "html_css", Llm.GPT_4O_2024_05_13
)
html_css_prompt = assemble_imported_code_prompt(code, "html_css")
expected_html_css_prompt = [
{"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
{
"role": "system",
"content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT
+ "\n Here is the code of the app: "
+ code,
}
]
assert html_css_prompt == expected_html_css_prompt
react_tailwind_prompt = assemble_imported_code_prompt(
"code", "react_tailwind", Llm.GPT_4O_2024_05_13
)
react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind")
expected_react_tailwind_prompt = [
{"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
{
"role": "system",
"content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
+ "\n Here is the code of the app: "
+ code,
}
]
assert react_tailwind_prompt == expected_react_tailwind_prompt
bootstrap_prompt = assemble_imported_code_prompt(
"code", "bootstrap", Llm.GPT_4O_2024_05_13
)
bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap")
expected_bootstrap_prompt = [
{"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
{
"role": "system",
"content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
+ "\n Here is the code of the app: "
+ code,
}
]
assert bootstrap_prompt == expected_bootstrap_prompt
ionic_tailwind = assemble_imported_code_prompt(
"code", "ionic_tailwind", Llm.GPT_4O_2024_05_13
)
ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind")
expected_ionic_tailwind = [
{"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
{
"role": "system",
"content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
+ "\n Here is the code of the app: "
+ code,
}
]
assert ionic_tailwind == expected_ionic_tailwind
vue_tailwind = assemble_imported_code_prompt(
"code", "vue_tailwind", Llm.GPT_4O_2024_05_13
)
vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind")
expected_vue_tailwind = [
{"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
{
"role": "system",
"content": IMPORTED_CODE_VUE_TAILWIND_PROMPT
+ "\n Here is the code of the app: "
+ code,
}
]
assert vue_tailwind == expected_vue_tailwind
svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13)
svg = assemble_imported_code_prompt(code, "svg")
expected_svg = [
{"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT},
{"role": "user", "content": "Here is the code of the SVG: code"},
{
"role": "system",
"content": IMPORTED_CODE_SVG_SYSTEM_PROMPT
+ "\n Here is the code of the SVG: "
+ code,
}
]
assert svg == expected_svg

View File

@ -1,12 +1,15 @@
import asyncio
from dataclasses import dataclass
import os
import traceback
from fastapi import APIRouter, WebSocket
import openai
import sentry_sdk
from codegen.utils import extract_html_content
from config import (
ANTHROPIC_API_KEY,
IS_PROD,
NUM_VARIANTS,
OPENAI_API_KEY,
OPENAI_BASE_URL,
REPLICATE_API_KEY,
SHOULD_MOCK_AI_RESPONSE,
)
@ -18,77 +21,101 @@ from llm import (
stream_claude_response_native,
stream_openai_response,
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from typing import Dict, List, Union, cast, get_args
from image_generation.core import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt, assemble_text_prompt
from datetime import datetime
import json
from typing import Dict, List, cast, get_args
from image_generation.core import generate_images
from routes.logging_utils import PaymentMethod, send_to_saas_backend
from routes.saas_utils import does_user_have_subscription_credits
from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args
from image_generation.core import generate_images
from prompts import create_prompt
from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack
from utils import pprint_prompt
# from utils import pprint_prompt
from video.utils import extract_tag_content, assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
router = APIRouter()
def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
# Get the logs path from environment, default to the current working directory
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
# Create run_logs directory if it doesn't exist within the specified logs path
logs_directory = os.path.join(logs_path, "run_logs")
if not os.path.exists(logs_directory):
os.makedirs(logs_directory)
print("Writing to logs directory:", logs_directory)
# Generate a unique filename using the current timestamp within the logs directory
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
# Write the messages dict into a new file for each run
with open(filename, "w") as f:
f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
def auto_upgrade_model(code_generation_model: Llm) -> Llm:
    """Swap deprecated model selections for their supported successors.

    Current models pass through unchanged; upgrades are announced on stdout.
    """
    if code_generation_model == Llm.CLAUDE_3_SONNET:
        print(
            f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
        )
        return Llm.CLAUDE_3_5_SONNET_2024_06_20

    if code_generation_model in (Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09):
        print(
            f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
        )
        return Llm.GPT_4O_2024_05_13

    # Not deprecated — keep the caller's choice.
    return code_generation_model
@router.websocket("/generate-code")
async def stream_code(websocket: WebSocket):
await websocket.accept()
print("Incoming websocket connection...")
async def throw_error(
message: str,
# Generate images, if needed
async def perform_image_generation(
completion: str,
should_generate_images: bool,
openai_api_key: str | None,
openai_base_url: str | None,
image_cache: dict[str, str],
):
await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
replicate_api_key = REPLICATE_API_KEY
if not should_generate_images:
return completion
# TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json()
if replicate_api_key:
image_generation_model = "sdxl-lightning"
api_key = replicate_api_key
else:
if not openai_api_key:
print(
"No OpenAI API key and Replicate key found. Skipping image generation."
)
return completion
image_generation_model = "dalle3"
api_key = openai_api_key
# Read the code config settings from the request. Fall back to default if not provided.
generated_code_config = ""
if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
generated_code_config = params["generatedCodeConfig"]
if not generated_code_config in get_args(Stack):
print("Generating images with model: ", image_generation_model)
return await generate_images(
completion,
api_key=api_key,
base_url=openai_base_url,
image_cache=image_cache,
model=image_generation_model,
)
@dataclass
class ExtractedParams:
    """Validated, typed bundle of everything extracted from the raw
    websocket request params dict."""

    stack: Stack  # target output stack (e.g. html_tailwind)
    input_mode: InputMode  # generation input kind (e.g. "image" or "video")
    code_generation_model: Llm  # model chosen by the frontend (pre auto-upgrade)
    should_generate_images: bool  # whether to replace placeholder images
    openai_api_key: str | None  # resolved via subscription / dialog / env waterfall
    anthropic_api_key: str | None  # from settings dialog or environment
    openai_base_url: str | None  # user override honored only outside prod
    payment_method: PaymentMethod  # how this generation is being paid for
async def extract_params(
params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]]
) -> ExtractedParams:
# Read the code config settings (stack) from the request.
generated_code_config = params.get("generatedCodeConfig", "")
if generated_code_config not in get_args(Stack):
await throw_error(f"Invalid generated code config: {generated_code_config}")
return
# Cast the variable to the Stack type
valid_stack = cast(Stack, generated_code_config)
raise ValueError(f"Invalid generated code config: {generated_code_config}")
validated_stack = cast(Stack, generated_code_config)
# Validate the input mode
input_mode = params.get("inputMode", "image")
if not input_mode in get_args(InputMode):
input_mode = params.get("inputMode")
if input_mode not in get_args(InputMode):
await throw_error(f"Invalid input mode: {input_mode}")
raise Exception(f"Invalid input mode: {input_mode}")
# Cast the variable to the right type
raise ValueError(f"Invalid input mode: {input_mode}")
validated_input_mode = cast(InputMode, input_mode)
# Read the model from the request. Fall back to default if not provided.
@ -97,42 +124,19 @@ async def stream_code(websocket: WebSocket):
)
try:
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
except:
except ValueError:
await throw_error(f"Invalid model: {code_generation_model_str}")
raise Exception(f"Invalid model: {code_generation_model_str}")
# Auto-upgrade usage of older models
if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
print(
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
)
code_generation_model = Llm.GPT_4O_2024_05_13
elif code_generation_model == Llm.CLAUDE_3_SONNET:
print(
f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
)
code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
exact_llm_version = None
print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
)
# Track how this generation is being paid for
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
# Track the OpenAI API key to use
openai_api_key = None
raise ValueError(f"Invalid model: {code_generation_model_str}")
auth_token = params.get("authToken")
if not auth_token:
await throw_error("You need to be logged in to use screenshot to code")
raise Exception("No auth token")
# Get the OpenAI key by waterfalling through the different payment methods
# 1. Subscription
# 2. User's API key from client-side settings dialog
# 3. User's API key from environment variable
openai_api_key = None
# Track how this generation is being paid for
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
# If the user is a subscriber, use the platform API key
# TODO: Rename does_user_have_subscription_credits
@ -157,151 +161,152 @@ async def stream_code(websocket: WebSocket):
await throw_error("Unknown error occurred. Contact support.")
raise Exception("Unknown error occurred when checking subscription credits")
# For non-subscribers, use the user's API key from client-side settings dialog
if not openai_api_key:
openai_api_key = params.get("openAiApiKey", None)
payment_method = PaymentMethod.OPENAI_API_KEY
print("Using OpenAI API key from client-side settings dialog")
# If we still don't have an API key, use the user's API key from environment variable
if not openai_api_key:
openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_api_key = get_from_settings_dialog_or_env(
params, "openAiApiKey", OPENAI_API_KEY
)
payment_method = PaymentMethod.OPENAI_API_KEY
if openai_api_key:
# TODO
print("Using OpenAI API key from environment variable")
if not openai_api_key and (
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
or code_generation_model == Llm.GPT_4O_2024_05_13
):
print("OpenAI API key not found")
await throw_error(
"Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
)
raise Exception("No OpenAI API key found")
# if not openai_api_key and (
# code_generation_model == Llm.GPT_4_VISION
# or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
# or code_generation_model == Llm.GPT_4O_2024_05_13
# ):
# print("OpenAI API key not found")
# await throw_error(
# "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
# )
# raise Exception("No OpenAI API key found")
# Get the Anthropic API key from the request. Fall back to environment variable if not provided.
# TODO: Do not allow usage of key
# If neither is provided, we throw an error later only if Claude is used.
anthropic_api_key = None
if "anthropicApiKey" in params and params["anthropicApiKey"]:
anthropic_api_key = params["anthropicApiKey"]
print("Using Anthropic API key from client-side settings dialog")
else:
anthropic_api_key = ANTHROPIC_API_KEY
if anthropic_api_key:
print("Using Anthropic API key from environment variable")
anthropic_api_key = get_from_settings_dialog_or_env(
params, "anthropicApiKey", ANTHROPIC_API_KEY
)
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url: Union[str, None] = None
# Base URL for OpenAI API
openai_base_url: str | None = None
# Disable user-specified OpenAI Base URL in prod
if not os.environ.get("IS_PROD"):
if "openAiBaseURL" in params and params["openAiBaseURL"]:
openai_base_url = params["openAiBaseURL"]
print("Using OpenAI Base URL from client-side settings dialog")
else:
openai_base_url = os.environ.get("OPENAI_BASE_URL")
if openai_base_url:
print("Using OpenAI Base URL from environment variable")
if not IS_PROD:
openai_base_url = get_from_settings_dialog_or_env(
params, "openAiBaseURL", OPENAI_BASE_URL
)
if not openai_base_url:
print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided.
should_generate_images = (
params["isImageGenerationEnabled"]
if "isImageGenerationEnabled" in params
else True
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
return ExtractedParams(
stack=validated_stack,
input_mode=validated_input_mode,
code_generation_model=code_generation_model,
should_generate_images=should_generate_images,
openai_api_key=openai_api_key,
anthropic_api_key=anthropic_api_key,
openai_base_url=openai_base_url,
payment_method=payment_method,
)
print("generating code...")
await websocket.send_json({"type": "status", "value": "Generating code..."})
async def process_chunk(content: str):
await websocket.send_json({"type": "chunk", "value": content})
def get_from_settings_dialog_or_env(
params: dict[str, str], key: str, env_var: str | None
) -> str | None:
value = params.get(key)
if value:
print(f"Using {key} from client-side settings dialog")
return value
if env_var:
print(f"Using {key} from environment variable")
return env_var
return None
@router.websocket("/generate-code")
async def stream_code(websocket: WebSocket):
await websocket.accept()
print("Incoming websocket connection...")
## Communication protocol setup
async def throw_error(
message: str,
):
print(message)
await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
async def send_message(
type: Literal["chunk", "status", "setCode", "error"],
value: str,
variantIndex: int,
):
# Print for debugging on the backend
if type == "error":
print(f"Error (variant {variantIndex}): {value}")
elif type == "status":
print(f"Status (variant {variantIndex}): {value}")
await websocket.send_json(
{"type": type, "value": value, "variantIndex": variantIndex}
)
## Parameter extract and validation
# TODO: Are the values always strings?
params: dict[str, str] = await websocket.receive_json()
print("Received params")
extracted_params = await extract_params(params, throw_error)
stack = extracted_params.stack
input_mode = extracted_params.input_mode
code_generation_model = extracted_params.code_generation_model
openai_api_key = extracted_params.openai_api_key
openai_base_url = extracted_params.openai_base_url
anthropic_api_key = extracted_params.anthropic_api_key
should_generate_images = extracted_params.should_generate_images
payment_method = extracted_params.payment_method
# Auto-upgrade usage of older models
code_generation_model = auto_upgrade_model(code_generation_model)
print(
f"Generating {stack} code in {input_mode} mode using {code_generation_model}..."
)
for i in range(NUM_VARIANTS):
await send_message("status", "Generating code...", i)
### Prompt creation
# Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {}
# If this generation started off with imported code, we need to assemble the prompt differently
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
original_imported_code = params["history"][0]
prompt_messages = assemble_imported_code_prompt(
original_imported_code, valid_stack, code_generation_model
)
for index, text in enumerate(params["history"][1:]):
# TODO: Remove after "Select and edit" is fully implemented
if "referring to this element specifically" in text:
sentry_sdk.capture_exception(Exception("Point and edit used"))
if index % 2 == 0:
message: ChatCompletionMessageParam = {
"role": "user",
"content": text,
}
else:
message: ChatCompletionMessageParam = {
"role": "assistant",
"content": text,
}
prompt_messages.append(message)
else:
# Assemble the prompt
try:
if validated_input_mode == "image":
if params.get("resultImage") and params["resultImage"]:
prompt_messages = assemble_prompt(
params["image"], valid_stack, params["resultImage"]
)
else:
prompt_messages = assemble_prompt(params["image"], valid_stack)
elif validated_input_mode == "text":
sentry_sdk.capture_exception(Exception("Text generation used"))
prompt_messages = assemble_text_prompt(params["image"], valid_stack)
else:
await throw_error("Invalid input mode")
return
prompt_messages, image_cache = await create_prompt(params, stack, input_mode)
except:
await websocket.send_json(
{
"type": "error",
"value": "Error assembling prompt. Contact support at support@picoapps.xyz",
}
await throw_error(
"Error assembling prompt. Contact support at support@picoapps.xyz"
)
await websocket.close()
return
# Transform the history tree into message format for updates
if params["generationType"] == "update":
# TODO: Move this to frontend
for index, text in enumerate(params["history"]):
if index % 2 == 0:
message: ChatCompletionMessageParam = {
"role": "assistant",
"content": text,
}
else:
message: ChatCompletionMessageParam = {
"role": "user",
"content": text,
}
prompt_messages.append(message)
image_cache = create_alt_url_mapping(params["history"][-2])
if validated_input_mode == "video":
video_data_url = params["image"]
prompt_messages = await assemble_claude_prompt_video(video_data_url)
raise
# pprint_prompt(prompt_messages) # type: ignore
### Code generation
async def process_chunk(content: str, variantIndex: int):
await send_message("chunk", content, variantIndex)
if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(
process_chunk, input_mode=validated_input_mode
)
completions = [await mock_completion(process_chunk, input_mode=input_mode)]
else:
try:
if validated_input_mode == "video":
if input_mode == "video":
if IS_PROD:
raise Exception("Video mode is not supported in prod")
@ -311,48 +316,66 @@ async def stream_code(websocket: WebSocket):
)
raise Exception("No Anthropic key")
completion = await stream_claude_response_native(
completions = [
await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
callback=lambda x: process_chunk(x, 0),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
exact_llm_version = Llm.CLAUDE_3_OPUS
elif (
code_generation_model == Llm.CLAUDE_3_SONNET
or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
):
if not anthropic_api_key:
await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
)
raise Exception("No Anthropic key")
# Do not allow non-subscribers to use Claude
if payment_method != PaymentMethod.SUBSCRIPTION:
await throw_error(
"Please subscribe to a paid plan to use the Claude models"
)
raise Exception("Not subscribed to a paid plan for Claude")
completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=code_generation_model,
)
exact_llm_version = code_generation_model
]
else:
completion = await stream_openai_response(
prompt_messages, # type: ignore
# Depending on the presence and absence of various keys,
# we decide which models to run
variant_models = []
if openai_api_key and anthropic_api_key:
variant_models = ["openai", "anthropic"]
elif openai_api_key:
variant_models = ["openai", "openai"]
elif anthropic_api_key:
variant_models = ["anthropic", "anthropic"]
else:
await throw_error(
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server."
)
raise Exception("No OpenAI or Anthropic key")
tasks: List[Coroutine[Any, Any, str]] = []
for index, model in enumerate(variant_models):
if model == "openai":
if openai_api_key is None:
await throw_error("OpenAI API key is missing.")
raise Exception("OpenAI API key is missing.")
tasks.append(
stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
model=code_generation_model,
callback=lambda x, i=index: process_chunk(x, i),
model=Llm.GPT_4O_2024_05_13,
)
exact_llm_version = code_generation_model
)
elif model == "anthropic":
if anthropic_api_key is None:
await throw_error("Anthropic API key is missing.")
raise Exception("Anthropic API key is missing.")
tasks.append(
stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x, i=index: process_chunk(x, i),
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
)
)
completions = await asyncio.gather(*tasks)
print("Models used for generation: ", variant_models)
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@ -388,13 +411,10 @@ async def stream_code(websocket: WebSocket):
)
return await throw_error(error_message)
if validated_input_mode == "video":
completion = extract_tag_content("html", completion)
print("Exact used model for generation: ", exact_llm_version)
## Post-processing
# Strip the completion of everything except the HTML content
completion = extract_html_content(completion)
completions = [extract_html_content(completion) for completion in completions]
# Write the messages dict into a log so that we can debug later
# write_logs(prompt_messages, completion) # type: ignore
@ -402,54 +422,44 @@ async def stream_code(websocket: WebSocket):
if IS_PROD:
# Catch any errors from sending to SaaS backend and continue
try:
assert exact_llm_version is not None, "exact_llm_version is not set"
# TODO*
# assert exact_llm_version is not None, "exact_llm_version is not set"
await send_to_saas_backend(
prompt_messages,
completion,
# TODO*: Store both completions
completions[0],
payment_method=payment_method,
llm_version=exact_llm_version,
stack=valid_stack,
# TODO*
llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20,
stack=stack,
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
includes_result_image=bool(params.get("resultImage", False)),
input_mode=validated_input_mode,
input_mode=input_mode,
auth_token=params["authToken"],
)
except Exception as e:
print("Error sending to SaaS backend", e)
try:
if should_generate_images:
await websocket.send_json(
{"type": "status", "value": "Generating images..."}
)
image_generation_model = "sdxl-lightning" if REPLICATE_API_KEY else "dalle3"
print("Generating images with model: ", image_generation_model)
## Image Generation
updated_html = await generate_images(
for index, _ in enumerate(completions):
await send_message("status", "Generating images...", index)
image_generation_tasks = [
perform_image_generation(
completion,
api_key=(
REPLICATE_API_KEY
if image_generation_model == "sdxl-lightning"
else openai_api_key
),
base_url=openai_base_url,
image_cache=image_cache,
model=image_generation_model,
)
else:
updated_html = completion
await websocket.send_json({"type": "setCode", "value": updated_html})
await websocket.send_json(
{"type": "status", "value": "Code generation complete."}
)
except Exception as e:
traceback.print_exc()
print("Image generation failed", e)
# Send set code even if image generation fails since that triggers
# the frontend to update history
await websocket.send_json({"type": "setCode", "value": completion})
await websocket.send_json(
{"type": "status", "value": "Image generation failed but code is complete."}
should_generate_images,
openai_api_key,
openai_base_url,
image_cache,
)
for completion in completions
]
updated_completions = await asyncio.gather(*image_generation_tasks)
for index, updated_html in enumerate(updated_completions):
await send_message("setCode", updated_html, index)
await send_message("status", "Code generation complete.", index)
await websocket.close()

View File

@ -43,6 +43,7 @@
"copy-to-clipboard": "^3.3.3",
"html2canvas": "^1.4.1",
"posthog-js": "^1.128.1",
"nanoid": "^5.0.7",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",

View File

@ -9,8 +9,7 @@ import { usePersistedState } from "./hooks/usePersistedState";
import TermsOfServiceDialog from "./components/TermsOfServiceDialog";
import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants";
import { addEvent } from "./lib/analytics";
import { History } from "./components/history/history_types";
import { extractHistoryTree } from "./components/history/utils";
import { extractHistory } from "./components/history/utils";
import toast from "react-hot-toast";
import { useAuth } from "@clerk/clerk-react";
import { useStore } from "./store/store";
@ -27,6 +26,8 @@ import { GenerationSettings } from "./components/settings/GenerationSettings";
import StartPane from "./components/start-pane/StartPane";
import { takeScreenshot } from "./lib/takeScreenshot";
import Sidebar from "./components/sidebar/Sidebar";
import { Commit } from "./components/commits/types";
import { createCommit } from "./components/commits/utils";
interface Props {
navbarComponent?: JSX.Element;
@ -49,13 +50,19 @@ function App({ navbarComponent }: Props) {
referenceImages,
setReferenceImages,
head,
commits,
addCommit,
removeCommit,
setHead,
appendCommitCode,
setCommitCode,
resetCommits,
resetHead,
// Outputs
setGeneratedCode,
setExecutionConsole,
currentVersion,
setCurrentVersion,
appHistory,
setAppHistory,
appendExecutionConsole,
resetExecutionConsoles,
} = useProjectStore();
const {
@ -119,32 +126,31 @@ function App({ navbarComponent }: Props) {
// Functions
const reset = () => {
setAppState(AppState.INITIAL);
setGeneratedCode("");
setReferenceImages([]);
setInitialPrompt("");
setExecutionConsole([]);
setUpdateInstruction("");
setIsImportedFromCode(false);
setAppHistory([]);
setCurrentVersion(null);
setShouldIncludeResultImage(false);
setUpdateInstruction("");
disableInSelectAndEditMode();
resetExecutionConsoles();
resetCommits();
resetHead();
// Inputs
setInputMode("image");
setReferenceImages([]);
setIsImportedFromCode(false);
};
const regenerate = () => {
if (currentVersion === null) {
// This would be a error that I log to Sentry
addEvent("RegenerateCurrentVersionNull");
if (head === null) {
toast.error(
"No current version set. Please open a Github issue as this shouldn't happen."
"No current version set. Please contact support via chat or Github."
);
return;
throw new Error("Regenerate called with no head");
}
// Retrieve the previous command
const previousCommand = appHistory[currentVersion];
if (previousCommand.type !== "ai_create") {
addEvent("RegenerateNotFirstVersion");
const currentCommit = commits[head];
if (currentCommit.type !== "ai_create") {
toast.error("Only the first version can be regenerated.");
return;
}
@ -165,28 +171,32 @@ function App({ navbarComponent }: Props) {
addEvent("Cancel");
wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE);
// make sure stop can correct the state even if the websocket is already closed
cancelCodeGenerationAndReset();
};
// Used for code generation failure as well
const cancelCodeGenerationAndReset = () => {
// When this is the first version, reset the entire app state
if (currentVersion === null) {
const cancelCodeGenerationAndReset = (commit: Commit) => {
// When the current commit is the first version, reset the entire app state
if (commit.type === "ai_create") {
reset();
} else {
// Otherwise, revert to the last version
setGeneratedCode(appHistory[currentVersion].code);
// Otherwise, remove current commit from commits
removeCommit(commit.hash);
// Revert to parent commit
const parentCommitHash = commit.parentHash;
if (parentCommitHash) {
setHead(parentCommitHash);
} else {
throw new Error("Parent commit not found");
}
setAppState(AppState.CODE_READY);
}
};
async function doGenerateCode(
params: CodeGenerationParams,
parentVersion: number | null
) {
async function doGenerateCode(params: CodeGenerationParams) {
// Reset the execution console
setExecutionConsole([]);
resetExecutionConsoles();
// Set the app state
setAppState(AppState.CODING);
@ -199,69 +209,50 @@ function App({ navbarComponent }: Props) {
authToken: authToken || undefined,
};
const baseCommitObject = {
variants: [{ code: "" }, { code: "" }],
};
const commitInputObject =
params.generationType === "create"
? {
...baseCommitObject,
type: "ai_create" as const,
parentHash: null,
inputs: { image_url: referenceImages[0] },
}
: {
...baseCommitObject,
type: "ai_edit" as const,
parentHash: head,
inputs: {
prompt: params.history
? params.history[params.history.length - 1]
: "",
},
};
// Create a new commit and set it as the head
const commit = createCommit(commitInputObject);
addCommit(commit);
setHead(commit.hash);
generateCode(
wsRef,
updatedParams,
// On change
(token) => setGeneratedCode((prev) => prev + token),
(token, variantIndex) => {
appendCommitCode(commit.hash, variantIndex, token);
},
// On set code
(code) => {
setGeneratedCode(code);
if (params.generationType === "create") {
if (inputMode === "image" || inputMode === "video") {
setAppHistory([
{
type: "ai_create",
parentIndex: null,
code,
inputs: { image_url: referenceImages[0] },
},
]);
} else {
setAppHistory([
{
type: "ai_create",
parentIndex: null,
code,
inputs: { text: params.image },
},
]);
}
setCurrentVersion(0);
} else {
setAppHistory((prev) => {
// Validate parent version
if (parentVersion === null) {
toast.error(
"No parent version set. Contact support or open a Github issue."
);
addEvent("ParentVersionNull");
return prev;
}
const newHistory: History = [
...prev,
{
type: "ai_edit",
parentIndex: parentVersion,
code,
inputs: {
prompt: params.history
? params.history[params.history.length - 1]
: "", // History should never be empty when performing an edit
},
},
];
setCurrentVersion(newHistory.length - 1);
return newHistory;
});
}
(code, variantIndex) => {
setCommitCode(commit.hash, variantIndex, code);
},
// On status update
(line) => setExecutionConsole((prev) => [...prev, line]),
(line, variantIndex) => appendExecutionConsole(variantIndex, line),
// On cancel
() => {
cancelCodeGenerationAndReset();
cancelCodeGenerationAndReset(commit);
},
// On complete
() => {
@ -285,14 +276,11 @@ function App({ navbarComponent }: Props) {
// Kick off the code generation
if (referenceImages.length > 0) {
addEvent("Create");
await doGenerateCode(
{
doGenerateCode({
generationType: "create",
image: referenceImages[0],
inputMode,
},
currentVersion
);
});
}
}
@ -302,14 +290,11 @@ function App({ navbarComponent }: Props) {
setInputMode("text");
setInitialPrompt(text);
doGenerateCode(
{
doGenerateCode({
generationType: "create",
inputMode: "text",
image: text,
},
currentVersion
);
});
}
// Subsequent updates
@ -322,23 +307,22 @@ function App({ navbarComponent }: Props) {
return;
}
if (currentVersion === null) {
if (head === null) {
toast.error(
"No current version set. Contact support or open a Github issue."
);
addEvent("CurrentVersionNull");
return;
throw new Error("Update called with no head");
}
let historyTree;
try {
historyTree = extractHistoryTree(appHistory, currentVersion);
historyTree = extractHistory(head, commits);
} catch {
addEvent("HistoryTreeFailed");
toast.error(
"Version history is invalid. This shouldn't happen. Please contact support or open a Github issue."
);
return;
throw new Error("Invalid version history");
}
let modifiedUpdateInstruction = updateInstruction;
@ -352,34 +336,19 @@ function App({ navbarComponent }: Props) {
}
const updatedHistory = [...historyTree, modifiedUpdateInstruction];
const resultImage = shouldIncludeResultImage
? await takeScreenshot()
: undefined;
if (shouldIncludeResultImage) {
const resultImage = await takeScreenshot();
await doGenerateCode(
{
doGenerateCode({
generationType: "update",
inputMode,
image: referenceImages[0],
resultImage: resultImage,
resultImage,
history: updatedHistory,
isImportedFromCode,
},
currentVersion
);
} else {
await doGenerateCode(
{
generationType: "update",
inputMode,
image: inputMode === "text" ? initialPrompt : referenceImages[0],
history: updatedHistory,
isImportedFromCode,
},
currentVersion
);
}
});
setGeneratedCode("");
setUpdateInstruction("");
}
@ -402,17 +371,17 @@ function App({ navbarComponent }: Props) {
setIsImportedFromCode(true);
// Set up this project
setGeneratedCode(code);
setStack(stack);
setAppHistory([
{
// Create a new commit and set it as the head
const commit = createCommit({
type: "code_create",
parentIndex: null,
code,
inputs: { code },
},
]);
setCurrentVersion(0);
parentHash: null,
variants: [{ code }],
inputs: null,
});
addCommit(commit);
setHead(commit.hash);
// Set the app state
setAppState(AppState.CODE_READY);

View File

@ -0,0 +1,37 @@
export type CommitHash = string;
export type Variant = {
code: string;
};
export type BaseCommit = {
hash: CommitHash;
parentHash: CommitHash | null;
dateCreated: Date;
isCommitted: boolean;
variants: Variant[];
selectedVariantIndex: number;
};
export type CommitType = "ai_create" | "ai_edit" | "code_create";
export type AiCreateCommit = BaseCommit & {
type: "ai_create";
inputs: {
image_url: string;
};
};
export type AiEditCommit = BaseCommit & {
type: "ai_edit";
inputs: {
prompt: string;
};
};
export type CodeCreateCommit = BaseCommit & {
type: "code_create";
inputs: null;
};
export type Commit = AiCreateCommit | AiEditCommit | CodeCreateCommit;

View File

@ -0,0 +1,32 @@
import { nanoid } from "nanoid";
import {
AiCreateCommit,
AiEditCommit,
CodeCreateCommit,
Commit,
} from "./types";
export function createCommit(
commit:
| Omit<
AiCreateCommit,
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
>
| Omit<
AiEditCommit,
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
>
| Omit<
CodeCreateCommit,
"hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
>
): Commit {
const hash = nanoid();
return {
...commit,
hash,
isCommitted: false,
dateCreated: new Date(),
selectedVariantIndex: 0,
};
}

View File

@ -17,19 +17,16 @@ interface Props {
}
export default function HistoryDisplay({ shouldDisableReverts }: Props) {
const {
appHistory: history,
currentVersion,
setCurrentVersion,
setGeneratedCode,
} = useProjectStore();
const renderedHistory = renderHistory(history, currentVersion);
const { commits, head, setHead } = useProjectStore();
const revertToVersion = (index: number) => {
if (index < 0 || index >= history.length || !history[index]) return;
setCurrentVersion(index);
setGeneratedCode(history[index].code);
};
// Put all commits into an array and sort by created date (oldest first)
const flatHistory = Object.values(commits).sort(
(a, b) =>
new Date(a.dateCreated).getTime() - new Date(b.dateCreated).getTime()
);
// Annotate history items with a summary, parent version, etc.
const renderedHistory = renderHistory(flatHistory);
return renderedHistory.length === 0 ? null : (
<div className="flex flex-col h-screen">
@ -43,8 +40,8 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) {
"flex items-center justify-between space-x-2 w-full pr-2",
"border-b cursor-pointer",
{
" hover:bg-black hover:text-white": !item.isActive,
"bg-slate-500 text-white": item.isActive,
" hover:bg-black hover:text-white": item.hash === head,
"bg-slate-500 text-white": item.hash === head,
}
)}
>
@ -55,14 +52,14 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) {
? toast.error(
"Please wait for code generation to complete before viewing an older version."
)
: revertToVersion(index)
: setHead(item.hash)
}
>
<div className="flex gap-x-1 truncate">
<h2 className="text-sm truncate">{item.summary}</h2>
{item.parentVersion !== null && (
<h2 className="text-sm">
(parent: {item.parentVersion})
(parent: v{item.parentVersion})
</h2>
)}
</div>

View File

@ -1,45 +0,0 @@
export type HistoryItemType = "ai_create" | "ai_edit" | "code_create";
type CommonHistoryItem = {
parentIndex: null | number;
code: string;
};
export type HistoryItem =
| ({
type: "ai_create";
inputs: AiCreateInputs | AiCreateInputsText;
} & CommonHistoryItem)
| ({
type: "ai_edit";
inputs: AiEditInputs;
} & CommonHistoryItem)
| ({
type: "code_create";
inputs: CodeCreateInputs;
} & CommonHistoryItem);
export type AiCreateInputs = {
image_url: string;
};
export type AiCreateInputsText = {
text: string;
};
export type AiEditInputs = {
prompt: string;
};
export type CodeCreateInputs = {
code: string;
};
export type History = HistoryItem[];
export type RenderedHistoryItem = {
type: string;
summary: string;
parentVersion: string | null;
isActive: boolean;
};

View File

@ -1,91 +1,125 @@
import { extractHistoryTree, renderHistory } from "./utils";
import type { History } from "./history_types";
import { extractHistory, renderHistory } from "./utils";
import { Commit, CommitHash } from "../commits/types";
const basicLinearHistory: History = [
{
const basicLinearHistory: Record<CommitHash, Commit> = {
"0": {
hash: "0",
dateCreated: new Date(),
isCommitted: false,
type: "ai_create",
parentIndex: null,
code: "<html>1. create</html>",
parentHash: null,
variants: [{ code: "<html>1. create</html>" }],
selectedVariantIndex: 0,
inputs: {
image_url: "",
},
},
{
"1": {
hash: "1",
dateCreated: new Date(),
isCommitted: false,
type: "ai_edit",
parentIndex: 0,
code: "<html>2. edit with better icons</html>",
parentHash: "0",
variants: [{ code: "<html>2. edit with better icons</html>" }],
selectedVariantIndex: 0,
inputs: {
prompt: "use better icons",
},
},
{
"2": {
hash: "2",
dateCreated: new Date(),
isCommitted: false,
type: "ai_edit",
parentIndex: 1,
code: "<html>3. edit with better icons and red text</html>",
parentHash: "1",
variants: [{ code: "<html>3. edit with better icons and red text</html>" }],
selectedVariantIndex: 0,
inputs: {
prompt: "make text red",
},
},
];
};
const basicLinearHistoryWithCode: History = [
{
const basicLinearHistoryWithCode: Record<CommitHash, Commit> = {
"0": {
hash: "0",
dateCreated: new Date(),
isCommitted: false,
type: "code_create",
parentIndex: null,
code: "<html>1. create</html>",
inputs: {
code: "<html>1. create</html>",
parentHash: null,
variants: [{ code: "<html>1. create</html>" }],
selectedVariantIndex: 0,
inputs: null,
},
},
...basicLinearHistory.slice(1),
];
...Object.fromEntries(Object.entries(basicLinearHistory).slice(1)),
};
const basicBranchingHistory: History = [
const basicBranchingHistory: Record<CommitHash, Commit> = {
...basicLinearHistory,
{
"3": {
hash: "3",
dateCreated: new Date(),
isCommitted: false,
type: "ai_edit",
parentIndex: 1,
code: "<html>4. edit with better icons and green text</html>",
parentHash: "1",
variants: [
{ code: "<html>4. edit with better icons and green text</html>" },
],
selectedVariantIndex: 0,
inputs: {
prompt: "make text green",
},
},
];
};
const longerBranchingHistory: History = [
const longerBranchingHistory: Record<CommitHash, Commit> = {
...basicBranchingHistory,
{
"4": {
hash: "4",
dateCreated: new Date(),
isCommitted: false,
type: "ai_edit",
parentIndex: 3,
code: "<html>5. edit with better icons and green, bold text</html>",
parentHash: "3",
variants: [
{ code: "<html>5. edit with better icons and green, bold text</html>" },
],
selectedVariantIndex: 0,
inputs: {
prompt: "make text bold",
},
},
];
};
const basicBadHistory: History = [
{
const basicBadHistory: Record<CommitHash, Commit> = {
"0": {
hash: "0",
dateCreated: new Date(),
isCommitted: false,
type: "ai_create",
parentIndex: null,
code: "<html>1. create</html>",
parentHash: null,
variants: [{ code: "<html>1. create</html>" }],
selectedVariantIndex: 0,
inputs: {
image_url: "",
},
},
{
"1": {
hash: "1",
dateCreated: new Date(),
isCommitted: false,
type: "ai_edit",
parentIndex: 2, // <- Bad parent index
code: "<html>2. edit with better icons</html>",
parentHash: "2", // <- Bad parent hash
variants: [{ code: "<html>2. edit with better icons</html>" }],
selectedVariantIndex: 0,
inputs: {
prompt: "use better icons",
},
},
];
};
describe("History Utils", () => {
test("should correctly extract the history tree", () => {
expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([
expect(extractHistory("2", basicLinearHistory)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
@ -93,12 +127,12 @@ describe("History Utils", () => {
"<html>3. edit with better icons and red text</html>",
]);
expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([
expect(extractHistory("0", basicLinearHistory)).toEqual([
"<html>1. create</html>",
]);
// Test branching
expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([
expect(extractHistory("3", basicBranchingHistory)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
@ -106,7 +140,7 @@ describe("History Utils", () => {
"<html>4. edit with better icons and green text</html>",
]);
expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([
expect(extractHistory("4", longerBranchingHistory)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
@ -116,7 +150,7 @@ describe("History Utils", () => {
"<html>5. edit with better icons and green, bold text</html>",
]);
expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([
expect(extractHistory("2", longerBranchingHistory)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
@ -126,105 +160,82 @@ describe("History Utils", () => {
// Errors
// Bad index
expect(() => extractHistoryTree(basicLinearHistory, 100)).toThrow();
expect(() => extractHistoryTree(basicLinearHistory, -2)).toThrow();
// Bad hash
expect(() => extractHistory("100", basicLinearHistory)).toThrow();
// Bad tree
expect(() => extractHistoryTree(basicBadHistory, 1)).toThrow();
expect(() => extractHistory("1", basicBadHistory)).toThrow();
});
test("should correctly render the history tree", () => {
expect(renderHistory(basicLinearHistory, 2)).toEqual([
expect(renderHistory(Object.values(basicLinearHistory))).toEqual([
{
isActive: false,
parentVersion: null,
summary: "Create",
...basicLinearHistory["0"],
type: "Create",
},
{
isActive: false,
parentVersion: null,
summary: "use better icons",
type: "Edit",
},
{
isActive: true,
parentVersion: null,
summary: "make text red",
type: "Edit",
},
]);
// Current version is the first version
expect(renderHistory(basicLinearHistory, 0)).toEqual([
{
isActive: true,
parentVersion: null,
summary: "Create",
type: "Create",
parentVersion: null,
},
{
isActive: false,
parentVersion: null,
...basicLinearHistory["1"],
type: "Edit",
summary: "use better icons",
type: "Edit",
parentVersion: null,
},
{
isActive: false,
parentVersion: null,
summary: "make text red",
...basicLinearHistory["2"],
type: "Edit",
summary: "make text red",
parentVersion: null,
},
]);
// Render a history with code
expect(renderHistory(basicLinearHistoryWithCode, 0)).toEqual([
expect(renderHistory(Object.values(basicLinearHistoryWithCode))).toEqual([
{
isActive: true,
parentVersion: null,
summary: "Imported from code",
...basicLinearHistoryWithCode["0"],
type: "Imported from code",
summary: "Imported from code",
parentVersion: null,
},
{
isActive: false,
parentVersion: null,
...basicLinearHistoryWithCode["1"],
type: "Edit",
summary: "use better icons",
type: "Edit",
parentVersion: null,
},
{
isActive: false,
parentVersion: null,
summary: "make text red",
...basicLinearHistoryWithCode["2"],
type: "Edit",
summary: "make text red",
parentVersion: null,
},
]);
// Render a non-linear history
expect(renderHistory(basicBranchingHistory, 3)).toEqual([
expect(renderHistory(Object.values(basicBranchingHistory))).toEqual([
{
isActive: false,
parentVersion: null,
summary: "Create",
...basicBranchingHistory["0"],
type: "Create",
summary: "Create",
parentVersion: null,
},
{
isActive: false,
parentVersion: null,
...basicBranchingHistory["1"],
type: "Edit",
summary: "use better icons",
type: "Edit",
},
{
isActive: false,
parentVersion: null,
summary: "make text red",
type: "Edit",
},
{
isActive: true,
parentVersion: "v2",
summary: "make text green",
...basicBranchingHistory["2"],
type: "Edit",
summary: "make text red",
parentVersion: null,
},
{
...basicBranchingHistory["3"],
type: "Edit",
summary: "make text green",
parentVersion: 2,
},
]);
});

View File

@ -1,33 +1,25 @@
import {
History,
HistoryItem,
HistoryItemType,
RenderedHistoryItem,
} from "./history_types";
import { Commit, CommitHash, CommitType } from "../commits/types";
export function extractHistoryTree(
history: History,
version: number
export function extractHistory(
hash: CommitHash,
commits: Record<CommitHash, Commit>
): string[] {
const flatHistory: string[] = [];
let currentIndex: number | null = version;
while (currentIndex !== null) {
const item: HistoryItem = history[currentIndex];
let currentCommitHash: CommitHash | null = hash;
while (currentCommitHash !== null) {
const commit: Commit | null = commits[currentCommitHash];
if (item) {
if (item.type === "ai_create") {
// Don't include the image for ai_create
flatHistory.unshift(item.code);
} else if (item.type === "ai_edit") {
flatHistory.unshift(item.code);
flatHistory.unshift(item.inputs.prompt);
} else if (item.type === "code_create") {
flatHistory.unshift(item.code);
if (commit) {
flatHistory.unshift(commit.variants[commit.selectedVariantIndex].code);
// For edits, add the prompt to the history
if (commit.type === "ai_edit") {
flatHistory.unshift(commit.inputs.prompt);
}
// Move to the parent of the current item
currentIndex = item.parentIndex;
currentCommitHash = commit.parentHash;
} else {
throw new Error("Malformed history: missing parent index");
}
@ -36,7 +28,7 @@ export function extractHistoryTree(
return flatHistory;
}
function displayHistoryItemType(itemType: HistoryItemType) {
function displayHistoryItemType(itemType: CommitType) {
switch (itemType) {
case "ai_create":
return "Create";
@ -51,44 +43,48 @@ function displayHistoryItemType(itemType: HistoryItemType) {
}
}
function summarizeHistoryItem(item: HistoryItem) {
const itemType = item.type;
switch (itemType) {
const setParentVersion = (commit: Commit, history: Commit[]) => {
// If the commit has no parent, return null
if (!commit.parentHash) return null;
const parentIndex = history.findIndex(
(item) => item.hash === commit.parentHash
);
const currentIndex = history.findIndex((item) => item.hash === commit.hash);
// Only set parent version if the parent is not the previous commit
// and parent exists
return parentIndex !== -1 && parentIndex != currentIndex - 1
? parentIndex + 1
: null;
};
export function summarizeHistoryItem(commit: Commit) {
const commitType = commit.type;
switch (commitType) {
case "ai_create":
return "Create";
case "ai_edit":
return item.inputs.prompt;
return commit.inputs.prompt;
case "code_create":
return "Imported from code";
default: {
const exhaustiveCheck: never = itemType;
const exhaustiveCheck: never = commitType;
throw new Error(`Unhandled case: ${exhaustiveCheck}`);
}
}
}
export const renderHistory = (
history: History,
currentVersion: number | null
) => {
const renderedHistory: RenderedHistoryItem[] = [];
export const renderHistory = (history: Commit[]) => {
const renderedHistory = [];
for (let i = 0; i < history.length; i++) {
const item = history[i];
// Only show the parent version if it's not the previous version
// (i.e. it's the branching point) and if it's not the first version
const parentVersion =
item.parentIndex !== null && item.parentIndex !== i - 1
? `v${(item.parentIndex || 0) + 1}`
: null;
const type = displayHistoryItemType(item.type);
const isActive = i === currentVersion;
const summary = summarizeHistoryItem(item);
const commit = history[i];
renderedHistory.push({
isActive,
summary: summary,
parentVersion,
type,
...commit,
type: displayHistoryItemType(commit.type),
summary: summarizeHistoryItem(commit),
parentVersion: setParentVersion(commit, history),
});
}

View File

@ -23,12 +23,17 @@ interface Props {
function PreviewPane({ doUpdate, reset, settings }: Props) {
const { appState } = useAppStore();
const { inputMode, generatedCode, setGeneratedCode } = useProjectStore();
const { inputMode, head, commits } = useProjectStore();
const currentCommit = head && commits[head] ? commits[head] : "";
const currentCode = currentCommit
? currentCommit.variants[currentCommit.selectedVariantIndex].code
: "";
const previewCode =
inputMode === "video" && appState === AppState.CODING
? extractHtml(generatedCode)
: generatedCode;
? extractHtml(currentCode)
: currentCode;
return (
<div className="ml-4">
@ -45,7 +50,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) {
Reset
</Button>
<Button
onClick={() => downloadCode(generatedCode)}
onClick={() => downloadCode(previewCode)}
variant="secondary"
className="flex items-center gap-x-2 mr-4 dark:text-white dark:bg-gray-700 download-btn"
>
@ -84,11 +89,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) {
/>
</TabsContent>
<TabsContent value="code">
<CodeTab
code={previewCode}
setCode={setGeneratedCode}
settings={settings}
/>
<CodeTab code={previewCode} setCode={() => {}} settings={settings} />
</TabsContent>
</Tabs>
</div>

View File

@ -12,6 +12,7 @@ import { Button } from "../ui/button";
import { Textarea } from "../ui/textarea";
import { useEffect, useRef } from "react";
import HistoryDisplay from "../history/HistoryDisplay";
import Variants from "../variants/Variants";
interface SidebarProps {
showSelectAndEditFeature: boolean;
@ -35,9 +36,18 @@ function Sidebar({
shouldIncludeResultImage,
setShouldIncludeResultImage,
} = useAppStore();
const { inputMode, generatedCode, referenceImages, executionConsole } =
const { inputMode, referenceImages, executionConsoles, head, commits } =
useProjectStore();
const viewedCode =
head && commits[head]
? commits[head].variants[commits[head].selectedVariantIndex].code
: "";
const executionConsole =
(head && executionConsoles[commits[head].selectedVariantIndex]) || [];
// When coding is complete, focus on the update instruction textarea
useEffect(() => {
if (appState === AppState.CODE_READY && textareaRef.current) {
@ -47,6 +57,8 @@ function Sidebar({
return (
<>
<Variants />
{/* Show code preview only when coding */}
{appState === AppState.CODING && (
<div className="flex flex-col">
@ -66,7 +78,7 @@ function Sidebar({
{executionConsole.slice(-1)[0]}
</div>
<CodePreview code={generatedCode} />
<CodePreview code={viewedCode} />
<div className="flex w-full">
<Button
@ -158,15 +170,22 @@ function Sidebar({
)}
<div className="bg-gray-400 px-4 py-2 rounded text-sm hidden">
<h2 className="text-lg mb-4 border-b border-gray-800">Console</h2>
{executionConsole.map((line, index) => (
{Object.entries(executionConsoles).map(([index, lines]) => (
<div key={index}>
{lines.map((line, lineIndex) => (
<div
key={index}
key={`${index}-${lineIndex}`}
className="border-b border-gray-400 mb-2 text-gray-600 font-mono"
>
<span className="font-bold mr-2">{`${index}:${
lineIndex + 1
}`}</span>
{line}
</div>
))}
</div>
))}
</div>
</div>
<HistoryDisplay shouldDisableReverts={appState === AppState.CODING} />

View File

@ -0,0 +1,42 @@
import { useProjectStore } from "../../store/project-store";
function Variants() {
const { inputMode, head, commits, updateSelectedVariantIndex } =
useProjectStore();
// If there is no head, don't show the variants
if (head === null) {
return null;
}
const commit = commits[head];
const variants = commit.variants;
const selectedVariantIndex = commit.selectedVariantIndex;
// If there is only one variant or the commit is already committed, don't show the variants
if (variants.length <= 1 || commit.isCommitted || inputMode === "video") {
return <div className="mt-2"></div>;
}
return (
<div className="mt-4 mb-4">
<div className="grid grid-cols-2 gap-2">
{variants.map((_, index) => (
<div
key={index}
className={`p-2 border rounded-md cursor-pointer ${
index === selectedVariantIndex
? "bg-blue-100 dark:bg-blue-900"
: "bg-gray-50 dark:bg-gray-800 hover:bg-gray-100 dark:hover:bg-gray-700"
}`}
onClick={() => updateSelectedVariantIndex(head, index)}
>
<h3 className="font-medium mb-1">Option {index + 1}</h3>
</div>
))}
</div>
</div>
);
}
export default Variants;

View File

@ -11,12 +11,18 @@ const ERROR_MESSAGE =
const CANCEL_MESSAGE = "Code generation cancelled";
type WebSocketResponse = {
type: "chunk" | "status" | "setCode" | "error";
value: string;
variantIndex: number;
};
export function generateCode(
wsRef: React.MutableRefObject<WebSocket | null>,
params: FullGenerationSettings,
onChange: (chunk: string) => void,
onSetCode: (code: string) => void,
onStatusUpdate: (status: string) => void,
onChange: (chunk: string, variantIndex: number) => void,
onSetCode: (code: string, variantIndex: number) => void,
onStatusUpdate: (status: string, variantIndex: number) => void,
onCancel: () => void,
onComplete: () => void
) {
@ -31,13 +37,13 @@ export function generateCode(
});
ws.addEventListener("message", async (event: MessageEvent) => {
const response = JSON.parse(event.data);
const response = JSON.parse(event.data) as WebSocketResponse;
if (response.type === "chunk") {
onChange(response.value);
onChange(response.value, response.variantIndex);
} else if (response.type === "status") {
onStatusUpdate(response.value);
onStatusUpdate(response.value, response.variantIndex);
} else if (response.type === "setCode") {
onSetCode(response.value);
onSetCode(response.value, response.variantIndex);
} else if (response.type === "error") {
console.error("Error generating code", response.value);
toast.error(response.value);

View File

@ -1,5 +1,5 @@
import { create } from "zustand";
import { History } from "../components/history/history_types";
import { Commit, CommitHash } from "../components/commits/types";
// Store for app-wide state
interface ProjectStore {
@ -11,25 +11,28 @@ interface ProjectStore {
referenceImages: string[];
setReferenceImages: (images: string[]) => void;
// Outputs and other state
generatedCode: string;
setGeneratedCode: (
updater: string | ((currentCode: string) => string)
) => void;
executionConsole: string[];
setExecutionConsole: (
updater: string[] | ((currentConsole: string[]) => string[])
) => void;
// Outputs
commits: Record<string, Commit>;
head: CommitHash | null;
// Tracks the currently shown version from app history
// TODO: might want to move to appStore
currentVersion: number | null;
setCurrentVersion: (version: number | null) => void;
addCommit: (commit: Commit) => void;
removeCommit: (hash: CommitHash) => void;
resetCommits: () => void;
appHistory: History;
setAppHistory: (
updater: History | ((currentHistory: History) => History)
appendCommitCode: (
hash: CommitHash,
numVariant: number,
code: string
) => void;
setCommitCode: (hash: CommitHash, numVariant: number, code: string) => void;
updateSelectedVariantIndex: (hash: CommitHash, index: number) => void;
setHead: (hash: CommitHash) => void;
resetHead: () => void;
executionConsoles: { [key: number]: string[] };
appendExecutionConsole: (variantIndex: number, line: string) => void;
resetExecutionConsoles: () => void;
}
export const useProjectStore = create<ProjectStore>((set) => ({
@ -41,28 +44,106 @@ export const useProjectStore = create<ProjectStore>((set) => ({
referenceImages: [],
setReferenceImages: (images) => set({ referenceImages: images }),
// Outputs and other state
generatedCode: "",
setGeneratedCode: (updater) =>
set((state) => ({
generatedCode:
typeof updater === "function" ? updater(state.generatedCode) : updater,
})),
executionConsole: [],
setExecutionConsole: (updater) =>
set((state) => ({
executionConsole:
typeof updater === "function"
? updater(state.executionConsole)
: updater,
})),
// Outputs
commits: {},
head: null,
currentVersion: null,
setCurrentVersion: (version) => set({ currentVersion: version }),
appHistory: [],
setAppHistory: (updater) =>
addCommit: (commit: Commit) => {
// When adding a new commit, make sure all existing commits are marked as committed
set((state) => ({
appHistory:
typeof updater === "function" ? updater(state.appHistory) : updater,
})),
commits: {
...Object.fromEntries(
Object.entries(state.commits).map(([hash, existingCommit]) => [
hash,
{ ...existingCommit, isCommitted: true },
])
),
[commit.hash]: commit,
},
}));
},
removeCommit: (hash: CommitHash) => {
set((state) => {
const newCommits = { ...state.commits };
delete newCommits[hash];
return { commits: newCommits };
});
},
resetCommits: () => set({ commits: {} }),
appendCommitCode: (hash: CommitHash, numVariant: number, code: string) =>
set((state) => {
const commit = state.commits[hash];
// Don't update if the commit is already committed
if (commit.isCommitted) {
throw new Error("Attempted to append code to a committed commit");
}
return {
commits: {
...state.commits,
[hash]: {
...commit,
variants: commit.variants.map((variant, index) =>
index === numVariant
? { ...variant, code: variant.code + code }
: variant
),
},
},
};
}),
setCommitCode: (hash: CommitHash, numVariant: number, code: string) =>
set((state) => {
const commit = state.commits[hash];
// Don't update if the commit is already committed
if (commit.isCommitted) {
throw new Error("Attempted to set code of a committed commit");
}
return {
commits: {
...state.commits,
[hash]: {
...commit,
variants: commit.variants.map((variant, index) =>
index === numVariant ? { ...variant, code } : variant
),
},
},
};
}),
updateSelectedVariantIndex: (hash: CommitHash, index: number) =>
set((state) => {
const commit = state.commits[hash];
// Don't update if the commit is already committed
if (commit.isCommitted) {
throw new Error(
"Attempted to update selected variant index of a committed commit"
);
}
return {
commits: {
...state.commits,
[hash]: {
...commit,
selectedVariantIndex: index,
},
},
};
}),
setHead: (hash: CommitHash) => set({ head: hash }),
resetHead: () => set({ head: null }),
executionConsoles: {},
appendExecutionConsole: (variantIndex: number, line: string) =>
set((state) => ({
executionConsoles: {
...state.executionConsoles,
[variantIndex]: [
...(state.executionConsoles[variantIndex] || []),
line,
],
},
})),
resetExecutionConsoles: () => set({ executionConsoles: {} }),
}));

View File

@ -4582,6 +4582,11 @@ nanoid@^3.3.6, nanoid@^3.3.7:
resolved "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz"
integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==
nanoid@^5.0.7:
version "5.0.7"
resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-5.0.7.tgz#6452e8c5a816861fd9d2b898399f7e5fd6944cc6"
integrity sha512-oLxFY2gd2IqnjcYyOXD8XGCftpGtZP2AbHbOkthDkvRywH5ayNtPVy9YlOPcHckXzbLTCHpkb7FB+yuxKV13pQ==
natural-compare@^1.4.0:
version "1.4.0"
resolved "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz"