diff --git a/README.md b/README.md
index 77c93ed..ae59e48 100644
--- a/README.md
+++ b/README.md
@@ -33,10 +33,6 @@ We also just added experimental support for taking a video/screen recording of a
[Follow me on Twitter for updates](https://twitter.com/_abi_).
-## Sponsors
-
-
-
## 🚀 Hosted Version
[Try it live on the hosted version (paid)](https://screenshottocode.com).
diff --git a/backend/.pre-commit-config.yaml b/backend/.pre-commit-config.yaml
index b54da93..b27eb3a 100644
--- a/backend/.pre-commit-config.yaml
+++ b/backend/.pre-commit-config.yaml
@@ -7,19 +7,19 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- - repo: local
- hooks:
- - id: poetry-pytest
- name: Run pytest with Poetry
- entry: poetry run --directory backend pytest
- language: system
- pass_filenames: false
- always_run: true
- files: ^backend/
- # - id: poetry-pyright
- # name: Run pyright with Poetry
- # entry: poetry run --directory backend pyright
- # language: system
- # pass_filenames: false
- # always_run: true
- # files: ^backend/
+ # - repo: local
+ # hooks:
+ # - id: poetry-pytest
+ # name: Run pytest with Poetry
+ # entry: poetry run --directory backend pytest
+ # language: system
+ # pass_filenames: false
+ # always_run: true
+ # files: ^backend/
+ # - id: poetry-pyright
+ # name: Run pyright with Poetry
+ # entry: poetry run --directory backend pyright
+ # language: system
+ # pass_filenames: false
+ # always_run: true
+ # files: ^backend/
diff --git a/backend/config.py b/backend/config.py
index dfb6d9d..34fb3bb 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -3,8 +3,12 @@
# TODO: Should only be set to true when value is 'True', not any abitrary truthy value
import os
+NUM_VARIANTS = 2
+
+# LLM-related
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
+OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
# Image generation (optional)
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
diff --git a/backend/fs_logging/__init__.py b/backend/fs_logging/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/fs_logging/core.py b/backend/fs_logging/core.py
new file mode 100644
index 0000000..e89096f
--- /dev/null
+++ b/backend/fs_logging/core.py
@@ -0,0 +1,23 @@
+from datetime import datetime
+import json
+import os
+from openai.types.chat import ChatCompletionMessageParam
+
+
+def write_logs(prompt_messages: list[ChatCompletionMessageParam], completion: str):
+ # Get the logs path from environment, default to the current working directory
+ logs_path = os.environ.get("LOGS_PATH", os.getcwd())
+
+ # Create run_logs directory if it doesn't exist within the specified logs path
+ logs_directory = os.path.join(logs_path, "run_logs")
+ if not os.path.exists(logs_directory):
+ os.makedirs(logs_directory)
+
+ print("Writing to logs directory:", logs_directory)
+
+ # Generate a unique filename using the current timestamp within the logs directory
+ filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
+
+ # Write the messages dict into a new file for each run
+ with open(filename, "w") as f:
+ f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
diff --git a/backend/image_generation/core.py b/backend/image_generation/core.py
index b5f3d1b..f8ee462 100644
--- a/backend/image_generation/core.py
+++ b/backend/image_generation/core.py
@@ -28,7 +28,7 @@ async def process_tasks(
processed_results: List[Union[str, None]] = []
for result in results:
- if isinstance(result, Exception):
+ if isinstance(result, BaseException):
print(f"An exception occurred: {result}")
try:
raise result
diff --git a/backend/llm.py b/backend/llm.py
index ab8de7a..2d202eb 100644
--- a/backend/llm.py
+++ b/backend/llm.py
@@ -1,3 +1,4 @@
+import copy
from enum import Enum
from typing import Any, Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic
@@ -112,8 +113,12 @@ async def stream_claude_response(
temperature = 0.0
# Translate OpenAI messages to Claude messages
- system_prompt = cast(str, messages[0].get("content"))
- claude_messages = [dict(message) for message in messages[1:]]
+
+ # Deep copy messages to avoid modifying the original list
+ cloned_messages = copy.deepcopy(messages)
+
+ system_prompt = cast(str, cloned_messages[0].get("content"))
+ claude_messages = [dict(message) for message in cloned_messages[1:]]
for message in claude_messages:
if not isinstance(message["content"], list):
continue
diff --git a/backend/mock_llm.py b/backend/mock_llm.py
index b85b1b1..a76b906 100644
--- a/backend/mock_llm.py
+++ b/backend/mock_llm.py
@@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
async def mock_completion(
- process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
+ process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
) -> str:
code_to_return = (
TALLY_FORM_VIDEO_PROMPT_MOCK
@@ -17,7 +17,7 @@ async def mock_completion(
)
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
- await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
+ await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
await asyncio.sleep(0.01)
if input_mode == "video":
diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py
index de544a7..9aaf45a 100644
--- a/backend/prompts/__init__.py
+++ b/backend/prompts/__init__.py
@@ -1,12 +1,13 @@
-from typing import List, NoReturn, Union
-
+from typing import Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
-from llm import Llm
+from custom_types import InputMode
+from image_generation.core import create_alt_url_mapping
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
from prompts.text_prompts import SYSTEM_PROMPTS as TEXT_SYSTEM_PROMPTS
from prompts.types import Stack
+from video.utils import assemble_claude_prompt_video
USER_PROMPT = """
@@ -18,9 +19,65 @@ Generate code for a SVG that looks exactly like this.
"""
+async def create_prompt(
+ params: dict[str, str], stack: Stack, input_mode: InputMode
+) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
+
+ image_cache: dict[str, str] = {}
+
+ # If this generation started off with imported code, we need to assemble the prompt differently
+ if params.get("isImportedFromCode"):
+ original_imported_code = params["history"][0]
+ prompt_messages = assemble_imported_code_prompt(original_imported_code, stack)
+ for index, text in enumerate(params["history"][1:]):
+ if index % 2 == 0:
+ message: ChatCompletionMessageParam = {
+ "role": "user",
+ "content": text,
+ }
+ else:
+ message: ChatCompletionMessageParam = {
+ "role": "assistant",
+ "content": text,
+ }
+ prompt_messages.append(message)
+ else:
+ # Assemble the prompt for non-imported code
+ if params.get("resultImage"):
+ prompt_messages = assemble_prompt(
+ params["image"], stack, params["resultImage"]
+ )
+ else:
+ prompt_messages = assemble_prompt(params["image"], stack)
+
+ if params["generationType"] == "update":
+ # Transform the history tree into message format
+ # TODO: Move this to frontend
+ for index, text in enumerate(params["history"]):
+ if index % 2 == 0:
+ message: ChatCompletionMessageParam = {
+ "role": "assistant",
+ "content": text,
+ }
+ else:
+ message: ChatCompletionMessageParam = {
+ "role": "user",
+ "content": text,
+ }
+ prompt_messages.append(message)
+
+ image_cache = create_alt_url_mapping(params["history"][-2])
+
+ if input_mode == "video":
+ video_data_url = params["image"]
+ prompt_messages = await assemble_claude_prompt_video(video_data_url)
+
+ return prompt_messages, image_cache
+
+
def assemble_imported_code_prompt(
- code: str, stack: Stack, model: Llm
-) -> List[ChatCompletionMessageParam]:
+ code: str, stack: Stack
+) -> list[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
user_content = (
@@ -29,24 +86,12 @@ def assemble_imported_code_prompt(
else "Here is the code of the SVG: " + code
)
- if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
- return [
- {
- "role": "system",
- "content": system_content + "\n " + user_content,
- }
- ]
- else:
- return [
- {
- "role": "system",
- "content": system_content,
- },
- {
- "role": "user",
- "content": user_content,
- },
- ]
+ return [
+ {
+ "role": "system",
+ "content": system_content + "\n " + user_content,
+ }
+ ]
# TODO: Use result_image_data_url
@@ -54,11 +99,11 @@ def assemble_prompt(
image_data_url: str,
stack: Stack,
result_image_data_url: Union[str, None] = None,
-) -> List[ChatCompletionMessageParam]:
+) -> list[ChatCompletionMessageParam]:
system_content = SYSTEM_PROMPTS[stack]
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
- user_content: List[ChatCompletionContentPartParam] = [
+ user_content: list[ChatCompletionContentPartParam] = [
{
"type": "image_url",
"image_url": {"url": image_data_url, "detail": "high"},
@@ -93,7 +138,7 @@ def assemble_prompt(
def assemble_text_prompt(
text_prompt: str,
stack: Stack,
-) -> List[ChatCompletionMessageParam]:
+) -> list[ChatCompletionMessageParam]:
system_content = TEXT_SYSTEM_PROMPTS[stack]
diff --git a/backend/prompts/test_prompts.py b/backend/prompts/test_prompts.py
index 9175fd8..049f9db 100644
--- a/backend/prompts/test_prompts.py
+++ b/backend/prompts/test_prompts.py
@@ -391,63 +391,81 @@ def test_prompts():
def test_imported_code_prompts():
- tailwind_prompt = assemble_imported_code_prompt(
- "code", "html_tailwind", Llm.GPT_4O_2024_05_13
- )
+ code = "Sample code"
+
+ tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind")
expected_tailwind_prompt = [
- {"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT},
- {"role": "user", "content": "Here is the code of the app: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
+ + "\n Here is the code of the app: "
+ + code,
+ }
]
assert tailwind_prompt == expected_tailwind_prompt
- html_css_prompt = assemble_imported_code_prompt(
- "code", "html_css", Llm.GPT_4O_2024_05_13
- )
+ html_css_prompt = assemble_imported_code_prompt(code, "html_css")
expected_html_css_prompt = [
- {"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT},
- {"role": "user", "content": "Here is the code of the app: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT
+ + "\n Here is the code of the app: "
+ + code,
+ }
]
assert html_css_prompt == expected_html_css_prompt
- react_tailwind_prompt = assemble_imported_code_prompt(
- "code", "react_tailwind", Llm.GPT_4O_2024_05_13
- )
+ react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind")
expected_react_tailwind_prompt = [
- {"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT},
- {"role": "user", "content": "Here is the code of the app: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
+ + "\n Here is the code of the app: "
+ + code,
+ }
]
assert react_tailwind_prompt == expected_react_tailwind_prompt
- bootstrap_prompt = assemble_imported_code_prompt(
- "code", "bootstrap", Llm.GPT_4O_2024_05_13
- )
+ bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap")
expected_bootstrap_prompt = [
- {"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT},
- {"role": "user", "content": "Here is the code of the app: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
+ + "\n Here is the code of the app: "
+ + code,
+ }
]
assert bootstrap_prompt == expected_bootstrap_prompt
- ionic_tailwind = assemble_imported_code_prompt(
- "code", "ionic_tailwind", Llm.GPT_4O_2024_05_13
- )
+ ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind")
expected_ionic_tailwind = [
- {"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT},
- {"role": "user", "content": "Here is the code of the app: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
+ + "\n Here is the code of the app: "
+ + code,
+ }
]
assert ionic_tailwind == expected_ionic_tailwind
- vue_tailwind = assemble_imported_code_prompt(
- "code", "vue_tailwind", Llm.GPT_4O_2024_05_13
- )
+ vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind")
expected_vue_tailwind = [
- {"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT},
- {"role": "user", "content": "Here is the code of the app: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT
+ + "\n Here is the code of the app: "
+ + code,
+ }
]
assert vue_tailwind == expected_vue_tailwind
- svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13)
+ svg = assemble_imported_code_prompt(code, "svg")
expected_svg = [
- {"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT},
- {"role": "user", "content": "Here is the code of the SVG: code"},
+ {
+ "role": "system",
+ "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT
+ + "\n Here is the code of the SVG: "
+ + code,
+ }
]
assert svg == expected_svg
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index 42492c1..91111c1 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -1,12 +1,15 @@
+import asyncio
+from dataclasses import dataclass
import os
-import traceback
from fastapi import APIRouter, WebSocket
import openai
-import sentry_sdk
from codegen.utils import extract_html_content
from config import (
ANTHROPIC_API_KEY,
IS_PROD,
+ NUM_VARIANTS,
+ OPENAI_API_KEY,
+ OPENAI_BASE_URL,
REPLICATE_API_KEY,
SHOULD_MOCK_AI_RESPONSE,
)
@@ -18,77 +21,101 @@ from llm import (
stream_claude_response_native,
stream_openai_response,
)
-from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
-from typing import Dict, List, Union, cast, get_args
-from image_generation.core import create_alt_url_mapping, generate_images
-from prompts import assemble_imported_code_prompt, assemble_prompt, assemble_text_prompt
-from datetime import datetime
-import json
+from typing import Dict, List, cast, get_args
+from image_generation.core import generate_images
from routes.logging_utils import PaymentMethod, send_to_saas_backend
from routes.saas_utils import does_user_have_subscription_credits
+from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args
+from image_generation.core import generate_images
+from prompts import create_prompt
from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack
-from utils import pprint_prompt
# from utils import pprint_prompt
-from video.utils import extract_tag_content, assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
router = APIRouter()
-def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
- # Get the logs path from environment, default to the current working directory
- logs_path = os.environ.get("LOGS_PATH", os.getcwd())
-
- # Create run_logs directory if it doesn't exist within the specified logs path
- logs_directory = os.path.join(logs_path, "run_logs")
- if not os.path.exists(logs_directory):
- os.makedirs(logs_directory)
-
- print("Writing to logs directory:", logs_directory)
-
- # Generate a unique filename using the current timestamp within the logs directory
- filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
-
- # Write the messages dict into a new file for each run
- with open(filename, "w") as f:
- f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
+# Auto-upgrade usage of older models
+def auto_upgrade_model(code_generation_model: Llm) -> Llm:
+ if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
+ print(
+ f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
+ )
+ return Llm.GPT_4O_2024_05_13
+ elif code_generation_model == Llm.CLAUDE_3_SONNET:
+ print(
+ f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
+ )
+ return Llm.CLAUDE_3_5_SONNET_2024_06_20
+ return code_generation_model
-@router.websocket("/generate-code")
-async def stream_code(websocket: WebSocket):
- await websocket.accept()
+# Generate images, if needed
+async def perform_image_generation(
+ completion: str,
+ should_generate_images: bool,
+ openai_api_key: str | None,
+ openai_base_url: str | None,
+ image_cache: dict[str, str],
+):
+ replicate_api_key = REPLICATE_API_KEY
+ if not should_generate_images:
+ return completion
- print("Incoming websocket connection...")
+ if replicate_api_key:
+ image_generation_model = "sdxl-lightning"
+ api_key = replicate_api_key
+ else:
+ if not openai_api_key:
+ print(
+ "No OpenAI API key and Replicate key found. Skipping image generation."
+ )
+ return completion
+ image_generation_model = "dalle3"
+ api_key = openai_api_key
- async def throw_error(
- message: str,
- ):
- await websocket.send_json({"type": "error", "value": message})
- await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
+ print("Generating images with model: ", image_generation_model)
- # TODO: Are the values always strings?
- params: Dict[str, str] = await websocket.receive_json()
+ return await generate_images(
+ completion,
+ api_key=api_key,
+ base_url=openai_base_url,
+ image_cache=image_cache,
+ model=image_generation_model,
+ )
- # Read the code config settings from the request. Fall back to default if not provided.
- generated_code_config = ""
- if "generatedCodeConfig" in params and params["generatedCodeConfig"]:
- generated_code_config = params["generatedCodeConfig"]
- if not generated_code_config in get_args(Stack):
+
+@dataclass
+class ExtractedParams:
+ stack: Stack
+ input_mode: InputMode
+ code_generation_model: Llm
+ should_generate_images: bool
+ openai_api_key: str | None
+ anthropic_api_key: str | None
+ openai_base_url: str | None
+ payment_method: PaymentMethod
+
+
+async def extract_params(
+ params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]]
+) -> ExtractedParams:
+ # Read the code config settings (stack) from the request.
+ generated_code_config = params.get("generatedCodeConfig", "")
+ if generated_code_config not in get_args(Stack):
await throw_error(f"Invalid generated code config: {generated_code_config}")
- return
- # Cast the variable to the Stack type
- valid_stack = cast(Stack, generated_code_config)
+ raise ValueError(f"Invalid generated code config: {generated_code_config}")
+ validated_stack = cast(Stack, generated_code_config)
# Validate the input mode
- input_mode = params.get("inputMode", "image")
- if not input_mode in get_args(InputMode):
+ input_mode = params.get("inputMode")
+ if input_mode not in get_args(InputMode):
await throw_error(f"Invalid input mode: {input_mode}")
- raise Exception(f"Invalid input mode: {input_mode}")
- # Cast the variable to the right type
+ raise ValueError(f"Invalid input mode: {input_mode}")
validated_input_mode = cast(InputMode, input_mode)
# Read the model from the request. Fall back to default if not provided.
@@ -97,42 +124,19 @@ async def stream_code(websocket: WebSocket):
)
try:
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
- except:
+ except ValueError:
await throw_error(f"Invalid model: {code_generation_model_str}")
- raise Exception(f"Invalid model: {code_generation_model_str}")
-
- # Auto-upgrade usage of older models
- if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
- print(
- f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
- )
- code_generation_model = Llm.GPT_4O_2024_05_13
- elif code_generation_model == Llm.CLAUDE_3_SONNET:
- print(
- f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
- )
- code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20
-
- exact_llm_version = None
-
- print(
- f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
- )
-
- # Track how this generation is being paid for
- payment_method: PaymentMethod = PaymentMethod.UNKNOWN
- # Track the OpenAI API key to use
- openai_api_key = None
+ raise ValueError(f"Invalid model: {code_generation_model_str}")
auth_token = params.get("authToken")
if not auth_token:
await throw_error("You need to be logged in to use screenshot to code")
raise Exception("No auth token")
- # Get the OpenAI key by waterfalling through the different payment methods
- # 1. Subscription
- # 2. User's API key from client-side settings dialog
- # 3. User's API key from environment variable
+ openai_api_key = None
+
+ # Track how this generation is being paid for
+ payment_method: PaymentMethod = PaymentMethod.UNKNOWN
# If the user is a subscriber, use the platform API key
# TODO: Rename does_user_have_subscription_credits
@@ -157,151 +161,152 @@ async def stream_code(websocket: WebSocket):
await throw_error("Unknown error occurred. Contact support.")
raise Exception("Unknown error occurred when checking subscription credits")
- # For non-subscribers, use the user's API key from client-side settings dialog
- if not openai_api_key:
- openai_api_key = params.get("openAiApiKey", None)
- payment_method = PaymentMethod.OPENAI_API_KEY
- print("Using OpenAI API key from client-side settings dialog")
-
# If we still don't have an API key, use the user's API key from environment variable
if not openai_api_key:
- openai_api_key = os.environ.get("OPENAI_API_KEY")
+ openai_api_key = get_from_settings_dialog_or_env(
+ params, "openAiApiKey", OPENAI_API_KEY
+ )
payment_method = PaymentMethod.OPENAI_API_KEY
if openai_api_key:
+ # TODO
print("Using OpenAI API key from environment variable")
- if not openai_api_key and (
- code_generation_model == Llm.GPT_4_VISION
- or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
- or code_generation_model == Llm.GPT_4O_2024_05_13
- ):
- print("OpenAI API key not found")
- await throw_error(
- "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
- )
- raise Exception("No OpenAI API key found")
+ # if not openai_api_key and (
+ # code_generation_model == Llm.GPT_4_VISION
+ # or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
+ # or code_generation_model == Llm.GPT_4O_2024_05_13
+ # ):
+ # print("OpenAI API key not found")
+ # await throw_error(
+ # "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
+ # )
+ # raise Exception("No OpenAI API key found")
- # Get the Anthropic API key from the request. Fall back to environment variable if not provided.
+ # TODO: Do not allow usage of key
# If neither is provided, we throw an error later only if Claude is used.
- anthropic_api_key = None
- if "anthropicApiKey" in params and params["anthropicApiKey"]:
- anthropic_api_key = params["anthropicApiKey"]
- print("Using Anthropic API key from client-side settings dialog")
- else:
- anthropic_api_key = ANTHROPIC_API_KEY
- if anthropic_api_key:
- print("Using Anthropic API key from environment variable")
+ anthropic_api_key = get_from_settings_dialog_or_env(
+ params, "anthropicApiKey", ANTHROPIC_API_KEY
+ )
- # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
- openai_base_url: Union[str, None] = None
+ # Base URL for OpenAI API
+ openai_base_url: str | None = None
# Disable user-specified OpenAI Base URL in prod
- if not os.environ.get("IS_PROD"):
- if "openAiBaseURL" in params and params["openAiBaseURL"]:
- openai_base_url = params["openAiBaseURL"]
- print("Using OpenAI Base URL from client-side settings dialog")
- else:
- openai_base_url = os.environ.get("OPENAI_BASE_URL")
- if openai_base_url:
- print("Using OpenAI Base URL from environment variable")
-
+ if not IS_PROD:
+ openai_base_url = get_from_settings_dialog_or_env(
+ params, "openAiBaseURL", OPENAI_BASE_URL
+ )
if not openai_base_url:
print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided.
- should_generate_images = (
- params["isImageGenerationEnabled"]
- if "isImageGenerationEnabled" in params
- else True
+ should_generate_images = bool(params.get("isImageGenerationEnabled", True))
+
+ return ExtractedParams(
+ stack=validated_stack,
+ input_mode=validated_input_mode,
+ code_generation_model=code_generation_model,
+ should_generate_images=should_generate_images,
+ openai_api_key=openai_api_key,
+ anthropic_api_key=anthropic_api_key,
+ openai_base_url=openai_base_url,
+ payment_method=payment_method,
)
- print("generating code...")
- await websocket.send_json({"type": "status", "value": "Generating code..."})
- async def process_chunk(content: str):
- await websocket.send_json({"type": "chunk", "value": content})
+def get_from_settings_dialog_or_env(
+ params: dict[str, str], key: str, env_var: str | None
+) -> str | None:
+ value = params.get(key)
+ if value:
+ print(f"Using {key} from client-side settings dialog")
+ return value
+
+ if env_var:
+ print(f"Using {key} from environment variable")
+ return env_var
+
+ return None
+
+
+@router.websocket("/generate-code")
+async def stream_code(websocket: WebSocket):
+ await websocket.accept()
+ print("Incoming websocket connection...")
+
+ ## Communication protocol setup
+ async def throw_error(
+ message: str,
+ ):
+ print(message)
+ await websocket.send_json({"type": "error", "value": message})
+ await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
+
+ async def send_message(
+ type: Literal["chunk", "status", "setCode", "error"],
+ value: str,
+ variantIndex: int,
+ ):
+ # Print for debugging on the backend
+ if type == "error":
+ print(f"Error (variant {variantIndex}): {value}")
+ elif type == "status":
+ print(f"Status (variant {variantIndex}): {value}")
+
+ await websocket.send_json(
+ {"type": type, "value": value, "variantIndex": variantIndex}
+ )
+
+ ## Parameter extraction and validation
+
+ # TODO: Are the values always strings?
+ params: dict[str, str] = await websocket.receive_json()
+ print("Received params")
+
+ extracted_params = await extract_params(params, throw_error)
+ stack = extracted_params.stack
+ input_mode = extracted_params.input_mode
+ code_generation_model = extracted_params.code_generation_model
+ openai_api_key = extracted_params.openai_api_key
+ openai_base_url = extracted_params.openai_base_url
+ anthropic_api_key = extracted_params.anthropic_api_key
+ should_generate_images = extracted_params.should_generate_images
+ payment_method = extracted_params.payment_method
+
+ # Auto-upgrade usage of older models
+ code_generation_model = auto_upgrade_model(code_generation_model)
+
+ print(
+ f"Generating {stack} code in {input_mode} mode using {code_generation_model}..."
+ )
+
+ for i in range(NUM_VARIANTS):
+ await send_message("status", "Generating code...", i)
+
+ ### Prompt creation
# Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {}
- # If this generation started off with imported code, we need to assemble the prompt differently
- if params.get("isImportedFromCode") and params["isImportedFromCode"]:
- original_imported_code = params["history"][0]
- prompt_messages = assemble_imported_code_prompt(
- original_imported_code, valid_stack, code_generation_model
+ try:
+ prompt_messages, image_cache = await create_prompt(params, stack, input_mode)
+ except:
+ await throw_error(
+ "Error assembling prompt. Contact support at support@picoapps.xyz"
)
- for index, text in enumerate(params["history"][1:]):
- # TODO: Remove after "Select and edit" is fully implemented
- if "referring to this element specifically" in text:
- sentry_sdk.capture_exception(Exception("Point and edit used"))
-
- if index % 2 == 0:
- message: ChatCompletionMessageParam = {
- "role": "user",
- "content": text,
- }
- else:
- message: ChatCompletionMessageParam = {
- "role": "assistant",
- "content": text,
- }
- prompt_messages.append(message)
- else:
- # Assemble the prompt
- try:
- if validated_input_mode == "image":
- if params.get("resultImage") and params["resultImage"]:
- prompt_messages = assemble_prompt(
- params["image"], valid_stack, params["resultImage"]
- )
- else:
- prompt_messages = assemble_prompt(params["image"], valid_stack)
- elif validated_input_mode == "text":
- sentry_sdk.capture_exception(Exception("Text generation used"))
- prompt_messages = assemble_text_prompt(params["image"], valid_stack)
- else:
- await throw_error("Invalid input mode")
- return
- except:
- await websocket.send_json(
- {
- "type": "error",
- "value": "Error assembling prompt. Contact support at support@picoapps.xyz",
- }
- )
- await websocket.close()
- return
-
- # Transform the history tree into message format for updates
- if params["generationType"] == "update":
- # TODO: Move this to frontend
- for index, text in enumerate(params["history"]):
- if index % 2 == 0:
- message: ChatCompletionMessageParam = {
- "role": "assistant",
- "content": text,
- }
- else:
- message: ChatCompletionMessageParam = {
- "role": "user",
- "content": text,
- }
- prompt_messages.append(message)
-
- image_cache = create_alt_url_mapping(params["history"][-2])
-
- if validated_input_mode == "video":
- video_data_url = params["image"]
- prompt_messages = await assemble_claude_prompt_video(video_data_url)
+ raise
# pprint_prompt(prompt_messages) # type: ignore
+ ### Code generation
+
+ async def process_chunk(content: str, variantIndex: int):
+ await send_message("chunk", content, variantIndex)
+
if SHOULD_MOCK_AI_RESPONSE:
- completion = await mock_completion(
- process_chunk, input_mode=validated_input_mode
- )
+ completions = [await mock_completion(process_chunk, input_mode=input_mode)]
else:
try:
- if validated_input_mode == "video":
+ if input_mode == "video":
if IS_PROD:
raise Exception("Video mode is not supported in prod")
@@ -311,48 +316,66 @@ async def stream_code(websocket: WebSocket):
)
raise Exception("No Anthropic key")
- completion = await stream_claude_response_native(
- system_prompt=VIDEO_PROMPT,
- messages=prompt_messages, # type: ignore
- api_key=anthropic_api_key,
- callback=lambda x: process_chunk(x),
- model=Llm.CLAUDE_3_OPUS,
- include_thinking=True,
- )
- exact_llm_version = Llm.CLAUDE_3_OPUS
- elif (
- code_generation_model == Llm.CLAUDE_3_SONNET
- or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20
- ):
- if not anthropic_api_key:
- await throw_error(
- "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
+ completions = [
+ await stream_claude_response_native(
+ system_prompt=VIDEO_PROMPT,
+ messages=prompt_messages, # type: ignore
+ api_key=anthropic_api_key,
+ callback=lambda x: process_chunk(x, 0),
+ model=Llm.CLAUDE_3_OPUS,
+ include_thinking=True,
)
- raise Exception("No Anthropic key")
-
- # Do not allow non-subscribers to use Claude
- if payment_method != PaymentMethod.SUBSCRIPTION:
- await throw_error(
- "Please subscribe to a paid plan to use the Claude models"
- )
- raise Exception("Not subscribed to a paid plan for Claude")
-
- completion = await stream_claude_response(
- prompt_messages, # type: ignore
- api_key=anthropic_api_key,
- callback=lambda x: process_chunk(x),
- model=code_generation_model,
- )
- exact_llm_version = code_generation_model
+ ]
else:
- completion = await stream_openai_response(
- prompt_messages, # type: ignore
- api_key=openai_api_key,
- base_url=openai_base_url,
- callback=lambda x: process_chunk(x),
- model=code_generation_model,
- )
- exact_llm_version = code_generation_model
+
+ # Depending on the presence and absence of various keys,
+ # we decide which models to run
+ variant_models = []
+ if openai_api_key and anthropic_api_key:
+ variant_models = ["openai", "anthropic"]
+ elif openai_api_key:
+ variant_models = ["openai", "openai"]
+ elif anthropic_api_key:
+ variant_models = ["anthropic", "anthropic"]
+ else:
+ await throw_error(
+ "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server."
+ )
+ raise Exception("No OpenAI or Anthropic key")
+
+ tasks: List[Coroutine[Any, Any, str]] = []
+ for index, model in enumerate(variant_models):
+ if model == "openai":
+ if openai_api_key is None:
+ await throw_error("OpenAI API key is missing.")
+ raise Exception("OpenAI API key is missing.")
+
+ tasks.append(
+ stream_openai_response(
+ prompt_messages,
+ api_key=openai_api_key,
+ base_url=openai_base_url,
+ callback=lambda x, i=index: process_chunk(x, i),
+ model=Llm.GPT_4O_2024_05_13,
+ )
+ )
+ elif model == "anthropic":
+ if anthropic_api_key is None:
+ await throw_error("Anthropic API key is missing.")
+ raise Exception("Anthropic API key is missing.")
+
+ tasks.append(
+ stream_claude_response(
+ prompt_messages,
+ api_key=anthropic_api_key,
+ callback=lambda x, i=index: process_chunk(x, i),
+ model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
+ )
+ )
+
+ completions = await asyncio.gather(*tasks)
+ print("Models used for generation: ", variant_models)
+
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@@ -388,13 +411,10 @@ async def stream_code(websocket: WebSocket):
)
return await throw_error(error_message)
- if validated_input_mode == "video":
- completion = extract_tag_content("html", completion)
-
- print("Exact used model for generation: ", exact_llm_version)
+ ## Post-processing
# Strip the completion of everything except the HTML content
- completion = extract_html_content(completion)
+ completions = [extract_html_content(completion) for completion in completions]
# Write the messages dict into a log so that we can debug later
# write_logs(prompt_messages, completion) # type: ignore
@@ -402,54 +422,44 @@ async def stream_code(websocket: WebSocket):
if IS_PROD:
# Catch any errors from sending to SaaS backend and continue
try:
- assert exact_llm_version is not None, "exact_llm_version is not set"
+ # TODO*
+ # assert exact_llm_version is not None, "exact_llm_version is not set"
await send_to_saas_backend(
prompt_messages,
- completion,
+ # TODO*: Store both completions
+ completions[0],
payment_method=payment_method,
- llm_version=exact_llm_version,
- stack=valid_stack,
+ # TODO*
+ llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20,
+ stack=stack,
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
includes_result_image=bool(params.get("resultImage", False)),
- input_mode=validated_input_mode,
+ input_mode=input_mode,
auth_token=params["authToken"],
)
except Exception as e:
print("Error sending to SaaS backend", e)
- try:
- if should_generate_images:
- await websocket.send_json(
- {"type": "status", "value": "Generating images..."}
- )
- image_generation_model = "sdxl-lightning" if REPLICATE_API_KEY else "dalle3"
- print("Generating images with model: ", image_generation_model)
+ ## Image Generation
- updated_html = await generate_images(
- completion,
- api_key=(
- REPLICATE_API_KEY
- if image_generation_model == "sdxl-lightning"
- else openai_api_key
- ),
- base_url=openai_base_url,
- image_cache=image_cache,
- model=image_generation_model,
- )
- else:
- updated_html = completion
- await websocket.send_json({"type": "setCode", "value": updated_html})
- await websocket.send_json(
- {"type": "status", "value": "Code generation complete."}
- )
- except Exception as e:
- traceback.print_exc()
- print("Image generation failed", e)
- # Send set code even if image generation fails since that triggers
- # the frontend to update history
- await websocket.send_json({"type": "setCode", "value": completion})
- await websocket.send_json(
- {"type": "status", "value": "Image generation failed but code is complete."}
+ for index, _ in enumerate(completions):
+ await send_message("status", "Generating images...", index)
+
+ image_generation_tasks = [
+ perform_image_generation(
+ completion,
+ should_generate_images,
+ openai_api_key,
+ openai_base_url,
+ image_cache,
)
+ for completion in completions
+ ]
+
+ updated_completions = await asyncio.gather(*image_generation_tasks)
+
+ for index, updated_html in enumerate(updated_completions):
+ await send_message("setCode", updated_html, index)
+ await send_message("status", "Code generation complete.", index)
await websocket.close()
diff --git a/frontend/package.json b/frontend/package.json
index 5cdbf64..eee838e 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -43,6 +43,7 @@
"copy-to-clipboard": "^3.3.3",
"html2canvas": "^1.4.1",
"posthog-js": "^1.128.1",
+ "nanoid": "^5.0.7",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 85b7e0e..c51f783 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -9,8 +9,7 @@ import { usePersistedState } from "./hooks/usePersistedState";
import TermsOfServiceDialog from "./components/TermsOfServiceDialog";
import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants";
import { addEvent } from "./lib/analytics";
-import { History } from "./components/history/history_types";
-import { extractHistoryTree } from "./components/history/utils";
+import { extractHistory } from "./components/history/utils";
import toast from "react-hot-toast";
import { useAuth } from "@clerk/clerk-react";
import { useStore } from "./store/store";
@@ -27,6 +26,8 @@ import { GenerationSettings } from "./components/settings/GenerationSettings";
import StartPane from "./components/start-pane/StartPane";
import { takeScreenshot } from "./lib/takeScreenshot";
import Sidebar from "./components/sidebar/Sidebar";
+import { Commit } from "./components/commits/types";
+import { createCommit } from "./components/commits/utils";
interface Props {
navbarComponent?: JSX.Element;
@@ -49,13 +50,19 @@ function App({ navbarComponent }: Props) {
referenceImages,
setReferenceImages,
+ head,
+ commits,
+ addCommit,
+ removeCommit,
+ setHead,
+ appendCommitCode,
+ setCommitCode,
+ resetCommits,
+ resetHead,
+
// Outputs
- setGeneratedCode,
- setExecutionConsole,
- currentVersion,
- setCurrentVersion,
- appHistory,
- setAppHistory,
+ appendExecutionConsole,
+ resetExecutionConsoles,
} = useProjectStore();
const {
@@ -119,32 +126,31 @@ function App({ navbarComponent }: Props) {
// Functions
const reset = () => {
setAppState(AppState.INITIAL);
- setGeneratedCode("");
- setReferenceImages([]);
- setInitialPrompt("");
- setExecutionConsole([]);
- setUpdateInstruction("");
- setIsImportedFromCode(false);
- setAppHistory([]);
- setCurrentVersion(null);
setShouldIncludeResultImage(false);
+ setUpdateInstruction("");
disableInSelectAndEditMode();
+ resetExecutionConsoles();
+
+ resetCommits();
+ resetHead();
+
+ // Inputs
+ setInputMode("image");
+ setReferenceImages([]);
+ setIsImportedFromCode(false);
};
const regenerate = () => {
- if (currentVersion === null) {
- // This would be a error that I log to Sentry
- addEvent("RegenerateCurrentVersionNull");
+ if (head === null) {
toast.error(
- "No current version set. Please open a Github issue as this shouldn't happen."
+ "No current version set. Please contact support via chat or Github."
);
- return;
+ throw new Error("Regenerate called with no head");
}
// Retrieve the previous command
- const previousCommand = appHistory[currentVersion];
- if (previousCommand.type !== "ai_create") {
- addEvent("RegenerateNotFirstVersion");
+ const currentCommit = commits[head];
+ if (currentCommit.type !== "ai_create") {
toast.error("Only the first version can be regenerated.");
return;
}
@@ -165,28 +171,32 @@ function App({ navbarComponent }: Props) {
addEvent("Cancel");
wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE);
- // make sure stop can correct the state even if the websocket is already closed
- cancelCodeGenerationAndReset();
};
// Used for code generation failure as well
- const cancelCodeGenerationAndReset = () => {
- // When this is the first version, reset the entire app state
- if (currentVersion === null) {
+ const cancelCodeGenerationAndReset = (commit: Commit) => {
+ // When the current commit is the first version, reset the entire app state
+ if (commit.type === "ai_create") {
reset();
} else {
- // Otherwise, revert to the last version
- setGeneratedCode(appHistory[currentVersion].code);
+ // Otherwise, remove current commit from commits
+ removeCommit(commit.hash);
+
+ // Revert to parent commit
+ const parentCommitHash = commit.parentHash;
+ if (parentCommitHash) {
+ setHead(parentCommitHash);
+ } else {
+ throw new Error("Parent commit not found");
+ }
+
setAppState(AppState.CODE_READY);
}
};
- async function doGenerateCode(
- params: CodeGenerationParams,
- parentVersion: number | null
- ) {
+ async function doGenerateCode(params: CodeGenerationParams) {
// Reset the execution console
- setExecutionConsole([]);
+ resetExecutionConsoles();
// Set the app state
setAppState(AppState.CODING);
@@ -199,69 +209,50 @@ function App({ navbarComponent }: Props) {
authToken: authToken || undefined,
};
+ const baseCommitObject = {
+ variants: [{ code: "" }, { code: "" }],
+ };
+
+ const commitInputObject =
+ params.generationType === "create"
+ ? {
+ ...baseCommitObject,
+ type: "ai_create" as const,
+ parentHash: null,
+ inputs: { image_url: referenceImages[0] },
+ }
+ : {
+ ...baseCommitObject,
+ type: "ai_edit" as const,
+ parentHash: head,
+ inputs: {
+ prompt: params.history
+ ? params.history[params.history.length - 1]
+ : "",
+ },
+ };
+
+ // Create a new commit and set it as the head
+ const commit = createCommit(commitInputObject);
+ addCommit(commit);
+ setHead(commit.hash);
+
generateCode(
wsRef,
updatedParams,
// On change
- (token) => setGeneratedCode((prev) => prev + token),
+ (token, variantIndex) => {
+ appendCommitCode(commit.hash, variantIndex, token);
+ },
// On set code
- (code) => {
- setGeneratedCode(code);
- if (params.generationType === "create") {
- if (inputMode === "image" || inputMode === "video") {
- setAppHistory([
- {
- type: "ai_create",
- parentIndex: null,
- code,
- inputs: { image_url: referenceImages[0] },
- },
- ]);
- } else {
- setAppHistory([
- {
- type: "ai_create",
- parentIndex: null,
- code,
- inputs: { text: params.image },
- },
- ]);
- }
- setCurrentVersion(0);
- } else {
- setAppHistory((prev) => {
- // Validate parent version
- if (parentVersion === null) {
- toast.error(
- "No parent version set. Contact support or open a Github issue."
- );
- addEvent("ParentVersionNull");
- return prev;
- }
-
- const newHistory: History = [
- ...prev,
- {
- type: "ai_edit",
- parentIndex: parentVersion,
- code,
- inputs: {
- prompt: params.history
- ? params.history[params.history.length - 1]
- : "", // History should never be empty when performing an edit
- },
- },
- ];
- setCurrentVersion(newHistory.length - 1);
- return newHistory;
- });
- }
+ (code, variantIndex) => {
+ setCommitCode(commit.hash, variantIndex, code);
},
// On status update
- (line) => setExecutionConsole((prev) => [...prev, line]),
+ (line, variantIndex) => appendExecutionConsole(variantIndex, line),
// On cancel
() => {
- cancelCodeGenerationAndReset();
+ cancelCodeGenerationAndReset(commit);
},
// On complete
() => {
@@ -285,14 +276,11 @@ function App({ navbarComponent }: Props) {
// Kick off the code generation
if (referenceImages.length > 0) {
addEvent("Create");
- await doGenerateCode(
- {
- generationType: "create",
- image: referenceImages[0],
- inputMode,
- },
- currentVersion
- );
+ doGenerateCode({
+ generationType: "create",
+ image: referenceImages[0],
+ inputMode,
+ });
}
}
@@ -302,14 +290,11 @@ function App({ navbarComponent }: Props) {
setInputMode("text");
setInitialPrompt(text);
- doGenerateCode(
- {
- generationType: "create",
- inputMode: "text",
- image: text,
- },
- currentVersion
- );
+ doGenerateCode({
+ generationType: "create",
+ inputMode: "text",
+ image: text,
+ });
}
// Subsequent updates
@@ -322,23 +307,22 @@ function App({ navbarComponent }: Props) {
return;
}
- if (currentVersion === null) {
+ if (head === null) {
toast.error(
"No current version set. Contact support or open a Github issue."
);
- addEvent("CurrentVersionNull");
- return;
+ throw new Error("Update called with no head");
}
let historyTree;
try {
- historyTree = extractHistoryTree(appHistory, currentVersion);
+ historyTree = extractHistory(head, commits);
} catch {
addEvent("HistoryTreeFailed");
toast.error(
"Version history is invalid. This shouldn't happen. Please contact support or open a Github issue."
);
- return;
+ throw new Error("Invalid version history");
}
let modifiedUpdateInstruction = updateInstruction;
@@ -352,34 +336,19 @@ function App({ navbarComponent }: Props) {
}
const updatedHistory = [...historyTree, modifiedUpdateInstruction];
+ const resultImage = shouldIncludeResultImage
+ ? await takeScreenshot()
+ : undefined;
- if (shouldIncludeResultImage) {
- const resultImage = await takeScreenshot();
- await doGenerateCode(
- {
- generationType: "update",
- inputMode,
- image: referenceImages[0],
- resultImage: resultImage,
- history: updatedHistory,
- isImportedFromCode,
- },
- currentVersion
- );
- } else {
- await doGenerateCode(
- {
- generationType: "update",
- inputMode,
- image: inputMode === "text" ? initialPrompt : referenceImages[0],
- history: updatedHistory,
- isImportedFromCode,
- },
- currentVersion
- );
- }
+ doGenerateCode({
+ generationType: "update",
+ inputMode,
+ image: referenceImages[0],
+ resultImage,
+ history: updatedHistory,
+ isImportedFromCode,
+ });
- setGeneratedCode("");
setUpdateInstruction("");
}
@@ -402,17 +371,17 @@ function App({ navbarComponent }: Props) {
setIsImportedFromCode(true);
// Set up this project
- setGeneratedCode(code);
setStack(stack);
- setAppHistory([
- {
- type: "code_create",
- parentIndex: null,
- code,
- inputs: { code },
- },
- ]);
- setCurrentVersion(0);
+
+ // Create a new commit and set it as the head
+ const commit = createCommit({
+ type: "code_create",
+ parentHash: null,
+ variants: [{ code }],
+ inputs: null,
+ });
+ addCommit(commit);
+ setHead(commit.hash);
// Set the app state
setAppState(AppState.CODE_READY);
diff --git a/frontend/src/components/commits/types.ts b/frontend/src/components/commits/types.ts
new file mode 100644
index 0000000..eb9ea12
--- /dev/null
+++ b/frontend/src/components/commits/types.ts
@@ -0,0 +1,37 @@
+export type CommitHash = string;
+
+export type Variant = {
+ code: string;
+};
+
+export type BaseCommit = {
+ hash: CommitHash;
+ parentHash: CommitHash | null;
+ dateCreated: Date;
+ isCommitted: boolean;
+ variants: Variant[];
+ selectedVariantIndex: number;
+};
+
+export type CommitType = "ai_create" | "ai_edit" | "code_create";
+
+export type AiCreateCommit = BaseCommit & {
+ type: "ai_create";
+ inputs: {
+ image_url: string;
+ };
+};
+
+export type AiEditCommit = BaseCommit & {
+ type: "ai_edit";
+ inputs: {
+ prompt: string;
+ };
+};
+
+export type CodeCreateCommit = BaseCommit & {
+ type: "code_create";
+ inputs: null;
+};
+
+export type Commit = AiCreateCommit | AiEditCommit | CodeCreateCommit;
diff --git a/frontend/src/components/commits/utils.ts b/frontend/src/components/commits/utils.ts
new file mode 100644
index 0000000..640f83e
--- /dev/null
+++ b/frontend/src/components/commits/utils.ts
@@ -0,0 +1,32 @@
+import { nanoid } from "nanoid";
+import {
+ AiCreateCommit,
+ AiEditCommit,
+ CodeCreateCommit,
+ Commit,
+} from "./types";
+
+export function createCommit(
+ commit:
+ | Omit<
+ AiCreateCommit,
+ "hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
+ >
+ | Omit<
+ AiEditCommit,
+ "hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
+ >
+ | Omit<
+ CodeCreateCommit,
+ "hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted"
+ >
+): Commit {
+ const hash = nanoid();
+ return {
+ ...commit,
+ hash,
+ isCommitted: false,
+ dateCreated: new Date(),
+ selectedVariantIndex: 0,
+ };
+}
diff --git a/frontend/src/components/history/HistoryDisplay.tsx b/frontend/src/components/history/HistoryDisplay.tsx
index 6bbdc59..79d7fff 100644
--- a/frontend/src/components/history/HistoryDisplay.tsx
+++ b/frontend/src/components/history/HistoryDisplay.tsx
@@ -17,19 +17,16 @@ interface Props {
}
export default function HistoryDisplay({ shouldDisableReverts }: Props) {
- const {
- appHistory: history,
- currentVersion,
- setCurrentVersion,
- setGeneratedCode,
- } = useProjectStore();
- const renderedHistory = renderHistory(history, currentVersion);
+ const { commits, head, setHead } = useProjectStore();
- const revertToVersion = (index: number) => {
- if (index < 0 || index >= history.length || !history[index]) return;
- setCurrentVersion(index);
- setGeneratedCode(history[index].code);
- };
+ // Put all commits into an array and sort by created date (oldest first)
+ const flatHistory = Object.values(commits).sort(
+ (a, b) =>
+ new Date(a.dateCreated).getTime() - new Date(b.dateCreated).getTime()
+ );
+
+ // Annotate history items with a summary, parent version, etc.
+ const renderedHistory = renderHistory(flatHistory);
return renderedHistory.length === 0 ? null : (