diff --git a/README.md b/README.md index 77c93ed..ae59e48 100644 --- a/README.md +++ b/README.md @@ -33,10 +33,6 @@ We also just added experimental support for taking a video/screen recording of a [Follow me on Twitter for updates](https://twitter.com/_abi_). -## Sponsors - - - ## 🚀 Hosted Version [Try it live on the hosted version (paid)](https://screenshottocode.com). diff --git a/backend/.pre-commit-config.yaml b/backend/.pre-commit-config.yaml index b54da93..b27eb3a 100644 --- a/backend/.pre-commit-config.yaml +++ b/backend/.pre-commit-config.yaml @@ -7,19 +7,19 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files - - repo: local - hooks: - - id: poetry-pytest - name: Run pytest with Poetry - entry: poetry run --directory backend pytest - language: system - pass_filenames: false - always_run: true - files: ^backend/ - # - id: poetry-pyright - # name: Run pyright with Poetry - # entry: poetry run --directory backend pyright - # language: system - # pass_filenames: false - # always_run: true - # files: ^backend/ + # - repo: local + # hooks: + # - id: poetry-pytest + # name: Run pytest with Poetry + # entry: poetry run --directory backend pytest + # language: system + # pass_filenames: false + # always_run: true + # files: ^backend/ + # - id: poetry-pyright + # name: Run pyright with Poetry + # entry: poetry run --directory backend pyright + # language: system + # pass_filenames: false + # always_run: true + # files: ^backend/ diff --git a/backend/config.py b/backend/config.py index dfb6d9d..34fb3bb 100644 --- a/backend/config.py +++ b/backend/config.py @@ -3,8 +3,12 @@ # TODO: Should only be set to true when value is 'True', not any abitrary truthy value import os +NUM_VARIANTS = 2 + +# LLM-related OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None) # Image generation (optional) REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None) diff --git a/backend/fs_logging/__init__.py b/backend/fs_logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/fs_logging/core.py b/backend/fs_logging/core.py new file mode 100644 index 0000000..e89096f --- /dev/null +++ b/backend/fs_logging/core.py @@ -0,0 +1,23 @@ +from datetime import datetime +import json +import os +from openai.types.chat import ChatCompletionMessageParam + + +def write_logs(prompt_messages: list[ChatCompletionMessageParam], completion: str): + # Get the logs path from environment, default to the current working directory + logs_path = os.environ.get("LOGS_PATH", os.getcwd()) + + # Create run_logs directory if it doesn't exist within the specified logs path + logs_directory = os.path.join(logs_path, "run_logs") + if not os.path.exists(logs_directory): + os.makedirs(logs_directory) + + print("Writing to logs directory:", logs_directory) + + # Generate a unique filename using the current timestamp within the logs directory + filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") + + # Write the messages dict into a new file for each run + with open(filename, "w") as f: + f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) diff --git a/backend/image_generation/core.py b/backend/image_generation/core.py index b5f3d1b..f8ee462 100644 --- a/backend/image_generation/core.py +++ b/backend/image_generation/core.py @@ -28,7 +28,7 @@ async def process_tasks( processed_results: List[Union[str, None]] = [] for result in results: - if isinstance(result, Exception): + if isinstance(result, BaseException): print(f"An exception occurred: {result}") try: raise result diff --git a/backend/llm.py b/backend/llm.py index ab8de7a..2d202eb 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -1,3 +1,4 @@ +import copy from enum import Enum from typing import Any, Awaitable, Callable, List, cast from anthropic import AsyncAnthropic @@ -112,8 +113,12 @@ async def stream_claude_response( temperature = 0.0 # Translate OpenAI messages to Claude messages - system_prompt = cast(str, messages[0].get("content")) - claude_messages = [dict(message) for message in messages[1:]] + + # Deep copy messages to avoid modifying the original list + cloned_messages = copy.deepcopy(messages) + + system_prompt = cast(str, cloned_messages[0].get("content")) + claude_messages = [dict(message) for message in cloned_messages[1:]] for message in claude_messages: if not isinstance(message["content"], list): continue diff --git a/backend/mock_llm.py b/backend/mock_llm.py index b85b1b1..a76b906 100644 --- a/backend/mock_llm.py +++ b/backend/mock_llm.py @@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20 async def mock_completion( - process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode + process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode ) -> str: code_to_return = ( TALLY_FORM_VIDEO_PROMPT_MOCK @@ -17,7 +17,7 @@ async def mock_completion( ) for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE): - await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE]) + await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0) await asyncio.sleep(0.01) if input_mode == "video": diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index de544a7..9aaf45a 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,12 +1,13 @@ -from typing import List, NoReturn, Union - +from typing import Union from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam -from llm import Llm +from custom_types import InputMode +from image_generation.core import create_alt_url_mapping from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS from prompts.text_prompts import SYSTEM_PROMPTS as TEXT_SYSTEM_PROMPTS from prompts.types import Stack +from video.utils import assemble_claude_prompt_video USER_PROMPT = """ @@ -18,9 +19,65 @@ Generate code for a SVG that looks exactly like this. """ +async def create_prompt( + params: dict[str, str], stack: Stack, input_mode: InputMode +) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]: + + image_cache: dict[str, str] = {} + + # If this generation started off with imported code, we need to assemble the prompt differently + if params.get("isImportedFromCode"): + original_imported_code = params["history"][0] + prompt_messages = assemble_imported_code_prompt(original_imported_code, stack) + for index, text in enumerate(params["history"][1:]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + prompt_messages.append(message) + else: + # Assemble the prompt for non-imported code + if params.get("resultImage"): + prompt_messages = assemble_prompt( + params["image"], stack, params["resultImage"] + ) + else: + prompt_messages = assemble_prompt(params["image"], stack) + + if params["generationType"] == "update": + # Transform the history tree into message format + # TODO: Move this to frontend + for index, text in enumerate(params["history"]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) + + image_cache = create_alt_url_mapping(params["history"][-2]) + + if input_mode == "video": + video_data_url = params["image"] + prompt_messages = await assemble_claude_prompt_video(video_data_url) + + return prompt_messages, image_cache + + def assemble_imported_code_prompt( - code: str, stack: Stack, model: Llm -) -> List[ChatCompletionMessageParam]: + code: str, stack: Stack +) -> list[ChatCompletionMessageParam]: system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] user_content = ( @@ -29,24 +86,12 @@ def assemble_imported_code_prompt( else "Here is the code of the SVG: " + code ) - if model == Llm.CLAUDE_3_5_SONNET_2024_06_20: - return [ - { - "role": "system", - "content": system_content + "\n " + user_content, - } - ] - else: - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] + return [ + { + "role": "system", + "content": system_content + "\n " + user_content, + } + ] # TODO: Use result_image_data_url @@ -54,11 +99,11 @@ def assemble_prompt( image_data_url: str, stack: Stack, result_image_data_url: Union[str, None] = None, -) -> List[ChatCompletionMessageParam]: +) -> list[ChatCompletionMessageParam]: system_content = SYSTEM_PROMPTS[stack] user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT - user_content: List[ChatCompletionContentPartParam] = [ + user_content: list[ChatCompletionContentPartParam] = [ { "type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}, @@ -93,7 +138,7 @@ def assemble_prompt( def assemble_text_prompt( text_prompt: str, stack: Stack, -) -> List[ChatCompletionMessageParam]: +) -> list[ChatCompletionMessageParam]: system_content = TEXT_SYSTEM_PROMPTS[stack] diff --git a/backend/prompts/test_prompts.py b/backend/prompts/test_prompts.py index 9175fd8..049f9db 100644 --- a/backend/prompts/test_prompts.py +++ b/backend/prompts/test_prompts.py @@ -391,63 +391,81 @@ def test_prompts(): def test_imported_code_prompts(): - tailwind_prompt = assemble_imported_code_prompt( - "code", "html_tailwind", Llm.GPT_4O_2024_05_13 - ) + code = "Sample code" + + tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind") expected_tailwind_prompt = [ - {"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert tailwind_prompt == expected_tailwind_prompt - html_css_prompt = assemble_imported_code_prompt( - "code", "html_css", Llm.GPT_4O_2024_05_13 - ) + html_css_prompt = assemble_imported_code_prompt(code, "html_css") expected_html_css_prompt = [ - {"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert html_css_prompt == expected_html_css_prompt - react_tailwind_prompt = assemble_imported_code_prompt( - "code", "react_tailwind", Llm.GPT_4O_2024_05_13 - ) + react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind") expected_react_tailwind_prompt = [ - {"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert react_tailwind_prompt == expected_react_tailwind_prompt - bootstrap_prompt = assemble_imported_code_prompt( - "code", "bootstrap", Llm.GPT_4O_2024_05_13 - ) + bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap") expected_bootstrap_prompt = [ - {"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert bootstrap_prompt == expected_bootstrap_prompt - ionic_tailwind = assemble_imported_code_prompt( - "code", "ionic_tailwind", Llm.GPT_4O_2024_05_13 - ) + ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind") expected_ionic_tailwind = [ - {"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert ionic_tailwind == expected_ionic_tailwind - vue_tailwind = assemble_imported_code_prompt( - "code", "vue_tailwind", Llm.GPT_4O_2024_05_13 - ) + vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind") expected_vue_tailwind = [ - {"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert vue_tailwind == expected_vue_tailwind - svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13) + svg = assemble_imported_code_prompt(code, "svg") expected_svg = [ - {"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the SVG: code"}, + { + "role": "system", + "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT + + "\n Here is the code of the SVG: " + + code, + } ] assert svg == expected_svg diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 42492c1..91111c1 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -1,12 +1,15 @@ +import asyncio +from dataclasses import dataclass import os -import traceback from fastapi import APIRouter, WebSocket import openai -import sentry_sdk from codegen.utils import extract_html_content from config import ( ANTHROPIC_API_KEY, IS_PROD, + NUM_VARIANTS, + OPENAI_API_KEY, + OPENAI_BASE_URL, REPLICATE_API_KEY, SHOULD_MOCK_AI_RESPONSE, ) @@ -18,77 +21,101 @@ from llm import ( stream_claude_response_native, stream_openai_response, ) -from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Dict, List, Union, cast, get_args -from image_generation.core import create_alt_url_mapping, generate_images -from prompts import assemble_imported_code_prompt, assemble_prompt, assemble_text_prompt -from datetime import datetime -import json +from typing import Dict, List, cast, get_args +from image_generation.core import generate_images from routes.logging_utils import PaymentMethod, send_to_saas_backend from routes.saas_utils import does_user_have_subscription_credits +from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args +from image_generation.core import generate_images +from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack -from utils import pprint_prompt # from utils import pprint_prompt -from video.utils import extract_tag_content, assemble_claude_prompt_video from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore router = APIRouter() -def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str): - # Get the logs path from environment, default to the current working directory - logs_path = os.environ.get("LOGS_PATH", os.getcwd()) - - # Create run_logs directory if it doesn't exist within the specified logs path - logs_directory = os.path.join(logs_path, "run_logs") - if not os.path.exists(logs_directory): - os.makedirs(logs_directory) - - print("Writing to logs directory:", logs_directory) - - # Generate a unique filename using the current timestamp within the logs directory - filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") - - # Write the messages dict into a new file for each run - with open(filename, "w") as f: - f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) +# Auto-upgrade usage of older models +def auto_upgrade_model(code_generation_model: Llm) -> Llm: + if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}: + print( + f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13" + ) + return Llm.GPT_4O_2024_05_13 + elif code_generation_model == Llm.CLAUDE_3_SONNET: + print( + f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20" + ) + return Llm.CLAUDE_3_5_SONNET_2024_06_20 + return code_generation_model -@router.websocket("/generate-code") -async def stream_code(websocket: WebSocket): - await websocket.accept() +# Generate images, if needed +async def perform_image_generation( + completion: str, + should_generate_images: bool, + openai_api_key: str | None, + openai_base_url: str | None, + image_cache: dict[str, str], +): + replicate_api_key = REPLICATE_API_KEY + if not should_generate_images: + return completion - print("Incoming websocket connection...") + if replicate_api_key: + image_generation_model = "sdxl-lightning" + api_key = replicate_api_key + else: + if not openai_api_key: + print( + "No OpenAI API key and Replicate key found. Skipping image generation." + ) + return completion + image_generation_model = "dalle3" + api_key = openai_api_key - async def throw_error( - message: str, - ): - await websocket.send_json({"type": "error", "value": message}) - await websocket.close(APP_ERROR_WEB_SOCKET_CODE) + print("Generating images with model: ", image_generation_model) - # TODO: Are the values always strings? - params: Dict[str, str] = await websocket.receive_json() + return await generate_images( + completion, + api_key=api_key, + base_url=openai_base_url, + image_cache=image_cache, + model=image_generation_model, + ) - # Read the code config settings from the request. Fall back to default if not provided. - generated_code_config = "" - if "generatedCodeConfig" in params and params["generatedCodeConfig"]: - generated_code_config = params["generatedCodeConfig"] - if not generated_code_config in get_args(Stack): + +@dataclass +class ExtractedParams: + stack: Stack + input_mode: InputMode + code_generation_model: Llm + should_generate_images: bool + openai_api_key: str | None + anthropic_api_key: str | None + openai_base_url: str | None + payment_method: PaymentMethod + + +async def extract_params( + params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]] +) -> ExtractedParams: + # Read the code config settings (stack) from the request. + generated_code_config = params.get("generatedCodeConfig", "") + if generated_code_config not in get_args(Stack): await throw_error(f"Invalid generated code config: {generated_code_config}") - return - # Cast the variable to the Stack type - valid_stack = cast(Stack, generated_code_config) + raise ValueError(f"Invalid generated code config: {generated_code_config}") + validated_stack = cast(Stack, generated_code_config) # Validate the input mode - input_mode = params.get("inputMode", "image") - if not input_mode in get_args(InputMode): + input_mode = params.get("inputMode") + if input_mode not in get_args(InputMode): await throw_error(f"Invalid input mode: {input_mode}") - raise Exception(f"Invalid input mode: {input_mode}") - # Cast the variable to the right type + raise ValueError(f"Invalid input mode: {input_mode}") validated_input_mode = cast(InputMode, input_mode) # Read the model from the request. Fall back to default if not provided. @@ -97,42 +124,19 @@ async def stream_code(websocket: WebSocket): ) try: code_generation_model = convert_frontend_str_to_llm(code_generation_model_str) - except: + except ValueError: await throw_error(f"Invalid model: {code_generation_model_str}") - raise Exception(f"Invalid model: {code_generation_model_str}") - - # Auto-upgrade usage of older models - if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}: - print( - f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13" - ) - code_generation_model = Llm.GPT_4O_2024_05_13 - elif code_generation_model == Llm.CLAUDE_3_SONNET: - print( - f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20" - ) - code_generation_model = Llm.CLAUDE_3_5_SONNET_2024_06_20 - - exact_llm_version = None - - print( - f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." - ) - - # Track how this generation is being paid for - payment_method: PaymentMethod = PaymentMethod.UNKNOWN - # Track the OpenAI API key to use - openai_api_key = None + raise ValueError(f"Invalid model: {code_generation_model_str}") auth_token = params.get("authToken") if not auth_token: await throw_error("You need to be logged in to use screenshot to code") raise Exception("No auth token") - # Get the OpenAI key by waterfalling through the different payment methods - # 1. Subscription - # 2. User's API key from client-side settings dialog - # 3. User's API key from environment variable + openai_api_key = None + + # Track how this generation is being paid for + payment_method: PaymentMethod = PaymentMethod.UNKNOWN # If the user is a subscriber, use the platform API key # TODO: Rename does_user_have_subscription_credits @@ -157,151 +161,152 @@ async def stream_code(websocket: WebSocket): await throw_error("Unknown error occurred. Contact support.") raise Exception("Unknown error occurred when checking subscription credits") - # For non-subscribers, use the user's API key from client-side settings dialog - if not openai_api_key: - openai_api_key = params.get("openAiApiKey", None) - payment_method = PaymentMethod.OPENAI_API_KEY - print("Using OpenAI API key from client-side settings dialog") - # If we still don't have an API key, use the user's API key from environment variable if not openai_api_key: - openai_api_key = os.environ.get("OPENAI_API_KEY") + openai_api_key = get_from_settings_dialog_or_env( + params, "openAiApiKey", OPENAI_API_KEY + ) payment_method = PaymentMethod.OPENAI_API_KEY if openai_api_key: + # TODO print("Using OpenAI API key from environment variable") - if not openai_api_key and ( - code_generation_model == Llm.GPT_4_VISION - or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 - or code_generation_model == Llm.GPT_4O_2024_05_13 - ): - print("OpenAI API key not found") - await throw_error( - "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support." - ) - raise Exception("No OpenAI API key found") + # if not openai_api_key and ( + # code_generation_model == Llm.GPT_4_VISION + # or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 + # or code_generation_model == Llm.GPT_4O_2024_05_13 + # ): + # print("OpenAI API key not found") + # await throw_error( + # "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support." + # ) + # raise Exception("No OpenAI API key found") - # Get the Anthropic API key from the request. Fall back to environment variable if not provided. + # TODO: Do not allow usage of key # If neither is provided, we throw an error later only if Claude is used. - anthropic_api_key = None - if "anthropicApiKey" in params and params["anthropicApiKey"]: - anthropic_api_key = params["anthropicApiKey"] - print("Using Anthropic API key from client-side settings dialog") - else: - anthropic_api_key = ANTHROPIC_API_KEY - if anthropic_api_key: - print("Using Anthropic API key from environment variable") + anthropic_api_key = get_from_settings_dialog_or_env( + params, "anthropicApiKey", ANTHROPIC_API_KEY + ) - # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. - openai_base_url: Union[str, None] = None + # Base URL for OpenAI API + openai_base_url: str | None = None # Disable user-specified OpenAI Base URL in prod - if not os.environ.get("IS_PROD"): - if "openAiBaseURL" in params and params["openAiBaseURL"]: - openai_base_url = params["openAiBaseURL"] - print("Using OpenAI Base URL from client-side settings dialog") - else: - openai_base_url = os.environ.get("OPENAI_BASE_URL") - if openai_base_url: - print("Using OpenAI Base URL from environment variable") - + if not IS_PROD: + openai_base_url = get_from_settings_dialog_or_env( + params, "openAiBaseURL", OPENAI_BASE_URL + ) if not openai_base_url: print("Using official OpenAI URL") # Get the image generation flag from the request. Fall back to True if not provided. - should_generate_images = ( - params["isImageGenerationEnabled"] - if "isImageGenerationEnabled" in params - else True + should_generate_images = bool(params.get("isImageGenerationEnabled", True)) + + return ExtractedParams( + stack=validated_stack, + input_mode=validated_input_mode, + code_generation_model=code_generation_model, + should_generate_images=should_generate_images, + openai_api_key=openai_api_key, + anthropic_api_key=anthropic_api_key, + openai_base_url=openai_base_url, + payment_method=payment_method, ) - print("generating code...") - await websocket.send_json({"type": "status", "value": "Generating code..."}) - async def process_chunk(content: str): - await websocket.send_json({"type": "chunk", "value": content}) +def get_from_settings_dialog_or_env( + params: dict[str, str], key: str, env_var: str | None +) -> str | None: + value = params.get(key) + if value: + print(f"Using {key} from client-side settings dialog") + return value + + if env_var: + print(f"Using {key} from environment variable") + return env_var + + return None + + +@router.websocket("/generate-code") +async def stream_code(websocket: WebSocket): + await websocket.accept() + print("Incoming websocket connection...") + + ## Communication protocol setup + async def throw_error( + message: str, + ): + print(message) + await websocket.send_json({"type": "error", "value": message}) + await websocket.close(APP_ERROR_WEB_SOCKET_CODE) + + async def send_message( + type: Literal["chunk", "status", "setCode", "error"], + value: str, + variantIndex: int, + ): + # Print for debugging on the backend + if type == "error": + print(f"Error (variant {variantIndex}): {value}") + elif type == "status": + print(f"Status (variant {variantIndex}): {value}") + + await websocket.send_json( + {"type": type, "value": value, "variantIndex": variantIndex} + ) + + ## Parameter extract and validation + + # TODO: Are the values always strings? + params: dict[str, str] = await websocket.receive_json() + print("Received params") + + extracted_params = await extract_params(params, throw_error) + stack = extracted_params.stack + input_mode = extracted_params.input_mode + code_generation_model = extracted_params.code_generation_model + openai_api_key = extracted_params.openai_api_key + openai_base_url = extracted_params.openai_base_url + anthropic_api_key = extracted_params.anthropic_api_key + should_generate_images = extracted_params.should_generate_images + payment_method = extracted_params.payment_method + + # Auto-upgrade usage of older models + code_generation_model = auto_upgrade_model(code_generation_model) + + print( + f"Generating {stack} code in {input_mode} mode using {code_generation_model}..." + ) + + for i in range(NUM_VARIANTS): + await send_message("status", "Generating code...", i) + + ### Prompt creation # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} - # If this generation started off with imported code, we need to assemble the prompt differently - if params.get("isImportedFromCode") and params["isImportedFromCode"]: - original_imported_code = params["history"][0] - prompt_messages = assemble_imported_code_prompt( - original_imported_code, valid_stack, code_generation_model + try: + prompt_messages, image_cache = await create_prompt(params, stack, input_mode) + except: + await throw_error( + "Error assembling prompt. Contact support at support@picoapps.xyz" ) - for index, text in enumerate(params["history"][1:]): - # TODO: Remove after "Select and edit" is fully implemented - if "referring to this element specifically" in text: - sentry_sdk.capture_exception(Exception("Point and edit used")) - - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - prompt_messages.append(message) - else: - # Assemble the prompt - try: - if validated_input_mode == "image": - if params.get("resultImage") and params["resultImage"]: - prompt_messages = assemble_prompt( - params["image"], valid_stack, params["resultImage"] - ) - else: - prompt_messages = assemble_prompt(params["image"], valid_stack) - elif validated_input_mode == "text": - sentry_sdk.capture_exception(Exception("Text generation used")) - prompt_messages = assemble_text_prompt(params["image"], valid_stack) - else: - await throw_error("Invalid input mode") - return - except: - await websocket.send_json( - { - "type": "error", - "value": "Error assembling prompt. Contact support at support@picoapps.xyz", - } - ) - await websocket.close() - return - - # Transform the history tree into message format for updates - if params["generationType"] == "update": - # TODO: Move this to frontend - for index, text in enumerate(params["history"]): - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - prompt_messages.append(message) - - image_cache = create_alt_url_mapping(params["history"][-2]) - - if validated_input_mode == "video": - video_data_url = params["image"] - prompt_messages = await assemble_claude_prompt_video(video_data_url) + raise # pprint_prompt(prompt_messages) # type: ignore + ### Code generation + + async def process_chunk(content: str, variantIndex: int): + await send_message("chunk", content, variantIndex) + if SHOULD_MOCK_AI_RESPONSE: - completion = await mock_completion( - process_chunk, input_mode=validated_input_mode - ) + completions = [await mock_completion(process_chunk, input_mode=input_mode)] else: try: - if validated_input_mode == "video": + if input_mode == "video": if IS_PROD: raise Exception("Video mode is not supported in prod") @@ -311,48 +316,66 @@ async def stream_code(websocket: WebSocket): ) raise Exception("No Anthropic key") - completion = await stream_claude_response_native( - system_prompt=VIDEO_PROMPT, - messages=prompt_messages, # type: ignore - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - model=Llm.CLAUDE_3_OPUS, - include_thinking=True, - ) - exact_llm_version = Llm.CLAUDE_3_OPUS - elif ( - code_generation_model == Llm.CLAUDE_3_SONNET - or code_generation_model == Llm.CLAUDE_3_5_SONNET_2024_06_20 - ): - if not anthropic_api_key: - await throw_error( - "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog" + completions = [ + await stream_claude_response_native( + system_prompt=VIDEO_PROMPT, + messages=prompt_messages, # type: ignore + api_key=anthropic_api_key, + callback=lambda x: process_chunk(x, 0), + model=Llm.CLAUDE_3_OPUS, + include_thinking=True, ) - raise Exception("No Anthropic key") - - # Do not allow non-subscribers to use Claude - if payment_method != PaymentMethod.SUBSCRIPTION: - await throw_error( - "Please subscribe to a paid plan to use the Claude models" - ) - raise Exception("Not subscribed to a paid plan for Claude") - - completion = await stream_claude_response( - prompt_messages, # type: ignore - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - model=code_generation_model, - ) - exact_llm_version = code_generation_model + ] else: - completion = await stream_openai_response( - prompt_messages, # type: ignore - api_key=openai_api_key, - base_url=openai_base_url, - callback=lambda x: process_chunk(x), - model=code_generation_model, - ) - exact_llm_version = code_generation_model + + # Depending on the presence and absence of various keys, + # we decide which models to run + variant_models = [] + if openai_api_key and anthropic_api_key: + variant_models = ["openai", "anthropic"] + elif openai_api_key: + variant_models = ["openai", "openai"] + elif anthropic_api_key: + variant_models = ["anthropic", "anthropic"] + else: + await throw_error( + "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server." + ) + raise Exception("No OpenAI or Anthropic key") + + tasks: List[Coroutine[Any, Any, str]] = [] + for index, model in enumerate(variant_models): + if model == "openai": + if openai_api_key is None: + await throw_error("OpenAI API key is missing.") + raise Exception("OpenAI API key is missing.") + + tasks.append( + stream_openai_response( + prompt_messages, + api_key=openai_api_key, + base_url=openai_base_url, + callback=lambda x, i=index: process_chunk(x, i), + model=Llm.GPT_4O_2024_05_13, + ) + ) + elif model == "anthropic": + if anthropic_api_key is None: + await throw_error("Anthropic API key is missing.") + raise Exception("Anthropic API key is missing.") + + tasks.append( + stream_claude_response( + prompt_messages, + api_key=anthropic_api_key, + callback=lambda x, i=index: process_chunk(x, i), + model=Llm.CLAUDE_3_5_SONNET_2024_06_20, + ) + ) + + completions = await asyncio.gather(*tasks) + print("Models used for generation: ", variant_models) + except openai.AuthenticationError as e: print("[GENERATE_CODE] Authentication failed", e) error_message = ( @@ -388,13 +411,10 @@ async def stream_code(websocket: WebSocket): ) return await throw_error(error_message) - if validated_input_mode == "video": - completion = extract_tag_content("html", completion) - - print("Exact used model for generation: ", exact_llm_version) + ## Post-processing # Strip the completion of everything except the HTML content - completion = extract_html_content(completion) + completions = [extract_html_content(completion) for completion in completions] # Write the messages dict into a log so that we can debug later # write_logs(prompt_messages, completion) # type: ignore @@ -402,54 +422,44 @@ async def stream_code(websocket: WebSocket): if IS_PROD: # Catch any errors from sending to SaaS backend and continue try: - assert exact_llm_version is not None, "exact_llm_version is not set" + # TODO* + # assert exact_llm_version is not None, "exact_llm_version is not set" await send_to_saas_backend( prompt_messages, - completion, + # TODO*: Store both completions + completions[0], payment_method=payment_method, - llm_version=exact_llm_version, - stack=valid_stack, + # TODO* + llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20, + stack=stack, is_imported_from_code=bool(params.get("isImportedFromCode", False)), includes_result_image=bool(params.get("resultImage", False)), - input_mode=validated_input_mode, + input_mode=input_mode, auth_token=params["authToken"], ) except Exception as e: print("Error sending to SaaS backend", e) - try: - if should_generate_images: - await websocket.send_json( - {"type": "status", "value": "Generating images..."} - ) - image_generation_model = "sdxl-lightning" if REPLICATE_API_KEY else "dalle3" - print("Generating images with model: ", image_generation_model) + ## Image Generation - updated_html = await generate_images( - completion, - api_key=( - REPLICATE_API_KEY - if image_generation_model == "sdxl-lightning" - else openai_api_key - ), - base_url=openai_base_url, - image_cache=image_cache, - model=image_generation_model, - ) - else: - updated_html = completion - await websocket.send_json({"type": "setCode", "value": updated_html}) - await websocket.send_json( - {"type": "status", "value": "Code generation complete."} - ) - except Exception as e: - traceback.print_exc() - print("Image generation failed", e) - # Send set code even if image generation fails since that triggers - # the frontend to update history - await websocket.send_json({"type": "setCode", "value": completion}) - await websocket.send_json( - {"type": "status", "value": "Image generation failed but code is complete."} + for index, _ in enumerate(completions): + await send_message("status", "Generating images...", index) + + image_generation_tasks = [ + perform_image_generation( + completion, + should_generate_images, + openai_api_key, + openai_base_url, + image_cache, ) + for completion in completions + ] + + updated_completions = await asyncio.gather(*image_generation_tasks) + + for index, updated_html in enumerate(updated_completions): + await send_message("setCode", updated_html, index) + await send_message("status", "Code generation complete.", index) await websocket.close() diff --git a/frontend/package.json b/frontend/package.json index 5cdbf64..eee838e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -43,6 +43,7 @@ "copy-to-clipboard": "^3.3.3", "html2canvas": "^1.4.1", "posthog-js": "^1.128.1", + "nanoid": "^5.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 85b7e0e..c51f783 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -9,8 +9,7 @@ import { usePersistedState } from "./hooks/usePersistedState"; import TermsOfServiceDialog from "./components/TermsOfServiceDialog"; import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants"; import { addEvent } from "./lib/analytics"; -import { History } from "./components/history/history_types"; -import { extractHistoryTree } from "./components/history/utils"; +import { extractHistory } from "./components/history/utils"; import toast from "react-hot-toast"; import { useAuth } from "@clerk/clerk-react"; import { useStore } from "./store/store"; @@ -27,6 +26,8 @@ import { GenerationSettings } from "./components/settings/GenerationSettings"; import StartPane from "./components/start-pane/StartPane"; import { takeScreenshot } from "./lib/takeScreenshot"; import Sidebar from "./components/sidebar/Sidebar"; +import { Commit } from "./components/commits/types"; +import { createCommit } from "./components/commits/utils"; interface Props { navbarComponent?: JSX.Element; @@ -49,13 +50,19 @@ function App({ navbarComponent }: Props) { referenceImages, setReferenceImages, + head, + commits, + addCommit, + removeCommit, + setHead, + appendCommitCode, + setCommitCode, + resetCommits, + resetHead, + // Outputs - setGeneratedCode, - setExecutionConsole, - currentVersion, - setCurrentVersion, - appHistory, - setAppHistory, + appendExecutionConsole, + resetExecutionConsoles, } = useProjectStore(); const { @@ -119,32 +126,31 @@ function App({ navbarComponent }: Props) { // Functions const reset = () => { setAppState(AppState.INITIAL); - setGeneratedCode(""); - setReferenceImages([]); - setInitialPrompt(""); - setExecutionConsole([]); - setUpdateInstruction(""); - setIsImportedFromCode(false); - setAppHistory([]); - setCurrentVersion(null); setShouldIncludeResultImage(false); + setUpdateInstruction(""); disableInSelectAndEditMode(); + resetExecutionConsoles(); + + resetCommits(); + resetHead(); + + // Inputs + setInputMode("image"); + setReferenceImages([]); + setIsImportedFromCode(false); }; const regenerate = () => { - if (currentVersion === null) { - // This would be a error that I log to Sentry - addEvent("RegenerateCurrentVersionNull"); + if (head === null) { toast.error( - "No current version set. Please open a Github issue as this shouldn't happen." + "No current version set. Please contact support via chat or Github." ); - return; + throw new Error("Regenerate called with no head"); } // Retrieve the previous command - const previousCommand = appHistory[currentVersion]; - if (previousCommand.type !== "ai_create") { - addEvent("RegenerateNotFirstVersion"); + const currentCommit = commits[head]; + if (currentCommit.type !== "ai_create") { toast.error("Only the first version can be regenerated."); return; } @@ -165,28 +171,32 @@ function App({ navbarComponent }: Props) { addEvent("Cancel"); wsRef.current?.close?.(USER_CLOSE_WEB_SOCKET_CODE); - // make sure stop can correct the state even if the websocket is already closed - cancelCodeGenerationAndReset(); }; // Used for code generation failure as well - const cancelCodeGenerationAndReset = () => { - // When this is the first version, reset the entire app state - if (currentVersion === null) { + const cancelCodeGenerationAndReset = (commit: Commit) => { + // When the current commit is the first version, reset the entire app state + if (commit.type === "ai_create") { reset(); } else { - // Otherwise, revert to the last version - setGeneratedCode(appHistory[currentVersion].code); + // Otherwise, remove current commit from commits + removeCommit(commit.hash); + + // Revert to parent commit + const parentCommitHash = commit.parentHash; + if (parentCommitHash) { + setHead(parentCommitHash); + } else { + throw new Error("Parent commit not found"); + } + setAppState(AppState.CODE_READY); } }; - async function doGenerateCode( - params: CodeGenerationParams, - parentVersion: number | null - ) { + async function doGenerateCode(params: CodeGenerationParams) { // Reset the execution console - setExecutionConsole([]); + resetExecutionConsoles(); // Set the app state setAppState(AppState.CODING); @@ -199,69 +209,50 @@ function App({ navbarComponent }: Props) { authToken: authToken || undefined, }; + const baseCommitObject = { + variants: [{ code: "" }, { code: "" }], + }; + + const commitInputObject = + params.generationType === "create" + ? { + ...baseCommitObject, + type: "ai_create" as const, + parentHash: null, + inputs: { image_url: referenceImages[0] }, + } + : { + ...baseCommitObject, + type: "ai_edit" as const, + parentHash: head, + inputs: { + prompt: params.history + ? params.history[params.history.length - 1] + : "", + }, + }; + + // Create a new commit and set it as the head + const commit = createCommit(commitInputObject); + addCommit(commit); + setHead(commit.hash); + generateCode( wsRef, updatedParams, // On change - (token) => setGeneratedCode((prev) => prev + token), + (token, variantIndex) => { + appendCommitCode(commit.hash, variantIndex, token); + }, // On set code - (code) => { - setGeneratedCode(code); - if (params.generationType === "create") { - if (inputMode === "image" || inputMode === "video") { - setAppHistory([ - { - type: "ai_create", - parentIndex: null, - code, - inputs: { image_url: referenceImages[0] }, - }, - ]); - } else { - setAppHistory([ - { - type: "ai_create", - parentIndex: null, - code, - inputs: { text: params.image }, - }, - ]); - } - setCurrentVersion(0); - } else { - setAppHistory((prev) => { - // Validate parent version - if (parentVersion === null) { - toast.error( - "No parent version set. Contact support or open a Github issue." - ); - addEvent("ParentVersionNull"); - return prev; - } - - const newHistory: History = [ - ...prev, - { - type: "ai_edit", - parentIndex: parentVersion, - code, - inputs: { - prompt: params.history - ? params.history[params.history.length - 1] - : "", // History should never be empty when performing an edit - }, - }, - ]; - setCurrentVersion(newHistory.length - 1); - return newHistory; - }); - } + (code, variantIndex) => { + setCommitCode(commit.hash, variantIndex, code); }, // On status update - (line) => setExecutionConsole((prev) => [...prev, line]), + (line, variantIndex) => appendExecutionConsole(variantIndex, line), // On cancel () => { - cancelCodeGenerationAndReset(); + cancelCodeGenerationAndReset(commit); }, // On complete () => { @@ -285,14 +276,11 @@ function App({ navbarComponent }: Props) { // Kick off the code generation if (referenceImages.length > 0) { addEvent("Create"); - await doGenerateCode( - { - generationType: "create", - image: referenceImages[0], - inputMode, - }, - currentVersion - ); + doGenerateCode({ + generationType: "create", + image: referenceImages[0], + inputMode, + }); } } @@ -302,14 +290,11 @@ function App({ navbarComponent }: Props) { setInputMode("text"); setInitialPrompt(text); - doGenerateCode( - { - generationType: "create", - inputMode: "text", - image: text, - }, - currentVersion - ); + doGenerateCode({ + generationType: "create", + inputMode: "text", + image: text, + }); } // Subsequent updates @@ -322,23 +307,22 @@ function App({ navbarComponent }: Props) { return; } - if (currentVersion === null) { + if (head === null) { toast.error( "No current version set. Contact support or open a Github issue." ); - addEvent("CurrentVersionNull"); - return; + throw new Error("Update called with no head"); } let historyTree; try { - historyTree = extractHistoryTree(appHistory, currentVersion); + historyTree = extractHistory(head, commits); } catch { addEvent("HistoryTreeFailed"); toast.error( "Version history is invalid. This shouldn't happen. Please contact support or open a Github issue." ); - return; + throw new Error("Invalid version history"); } let modifiedUpdateInstruction = updateInstruction; @@ -352,34 +336,19 @@ function App({ navbarComponent }: Props) { } const updatedHistory = [...historyTree, modifiedUpdateInstruction]; + const resultImage = shouldIncludeResultImage + ? await takeScreenshot() + : undefined; - if (shouldIncludeResultImage) { - const resultImage = await takeScreenshot(); - await doGenerateCode( - { - generationType: "update", - inputMode, - image: referenceImages[0], - resultImage: resultImage, - history: updatedHistory, - isImportedFromCode, - }, - currentVersion - ); - } else { - await doGenerateCode( - { - generationType: "update", - inputMode, - image: inputMode === "text" ? initialPrompt : referenceImages[0], - history: updatedHistory, - isImportedFromCode, - }, - currentVersion - ); - } + doGenerateCode({ + generationType: "update", + inputMode, + image: referenceImages[0], + resultImage, + history: updatedHistory, + isImportedFromCode, + }); - setGeneratedCode(""); setUpdateInstruction(""); } @@ -402,17 +371,17 @@ function App({ navbarComponent }: Props) { setIsImportedFromCode(true); // Set up this project - setGeneratedCode(code); setStack(stack); - setAppHistory([ - { - type: "code_create", - parentIndex: null, - code, - inputs: { code }, - }, - ]); - setCurrentVersion(0); + + // Create a new commit and set it as the head + const commit = createCommit({ + type: "code_create", + parentHash: null, + variants: [{ code }], + inputs: null, + }); + addCommit(commit); + setHead(commit.hash); // Set the app state setAppState(AppState.CODE_READY); diff --git a/frontend/src/components/commits/types.ts b/frontend/src/components/commits/types.ts new file mode 100644 index 0000000..eb9ea12 --- /dev/null +++ b/frontend/src/components/commits/types.ts @@ -0,0 +1,37 @@ +export type CommitHash = string; + +export type Variant = { + code: string; +}; + +export type BaseCommit = { + hash: CommitHash; + parentHash: CommitHash | null; + dateCreated: Date; + isCommitted: boolean; + variants: Variant[]; + selectedVariantIndex: number; +}; + +export type CommitType = "ai_create" | "ai_edit" | "code_create"; + +export type AiCreateCommit = BaseCommit & { + type: "ai_create"; + inputs: { + image_url: string; + }; +}; + +export type AiEditCommit = BaseCommit & { + type: "ai_edit"; + inputs: { + prompt: string; + }; +}; + +export type CodeCreateCommit = BaseCommit & { + type: "code_create"; + inputs: null; +}; + +export type Commit = AiCreateCommit | AiEditCommit | CodeCreateCommit; diff --git a/frontend/src/components/commits/utils.ts b/frontend/src/components/commits/utils.ts new file mode 100644 index 0000000..640f83e --- /dev/null +++ b/frontend/src/components/commits/utils.ts @@ -0,0 +1,32 @@ +import { nanoid } from "nanoid"; +import { + AiCreateCommit, + AiEditCommit, + CodeCreateCommit, + Commit, +} from "./types"; + +export function createCommit( + commit: + | Omit< + AiCreateCommit, + "hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted" + > + | Omit< + AiEditCommit, + "hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted" + > + | Omit< + CodeCreateCommit, + "hash" | "dateCreated" | "selectedVariantIndex" | "isCommitted" + > +): Commit { + const hash = nanoid(); + return { + ...commit, + hash, + isCommitted: false, + dateCreated: new Date(), + selectedVariantIndex: 0, + }; +} diff --git a/frontend/src/components/history/HistoryDisplay.tsx b/frontend/src/components/history/HistoryDisplay.tsx index 6bbdc59..79d7fff 100644 --- a/frontend/src/components/history/HistoryDisplay.tsx +++ b/frontend/src/components/history/HistoryDisplay.tsx @@ -17,19 +17,16 @@ interface Props { } export default function HistoryDisplay({ shouldDisableReverts }: Props) { - const { - appHistory: history, - currentVersion, - setCurrentVersion, - setGeneratedCode, - } = useProjectStore(); - const renderedHistory = renderHistory(history, currentVersion); + const { commits, head, setHead } = useProjectStore(); - const revertToVersion = (index: number) => { - if (index < 0 || index >= history.length || !history[index]) return; - setCurrentVersion(index); - setGeneratedCode(history[index].code); - }; + // Put all commits into an array and sort by created date (oldest first) + const flatHistory = Object.values(commits).sort( + (a, b) => + new Date(a.dateCreated).getTime() - new Date(b.dateCreated).getTime() + ); + + // Annotate history items with a summary, parent version, etc. + const renderedHistory = renderHistory(flatHistory); return renderedHistory.length === 0 ? null : (
@@ -43,8 +40,8 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) { "flex items-center justify-between space-x-2 w-full pr-2", "border-b cursor-pointer", { - " hover:bg-black hover:text-white": !item.isActive, - "bg-slate-500 text-white": item.isActive, + " hover:bg-black hover:text-white": item.hash === head, + "bg-slate-500 text-white": item.hash === head, } )} > @@ -55,14 +52,14 @@ export default function HistoryDisplay({ shouldDisableReverts }: Props) { ? toast.error( "Please wait for code generation to complete before viewing an older version." ) - : revertToVersion(index) + : setHead(item.hash) } >

{item.summary}

{item.parentVersion !== null && (

- (parent: {item.parentVersion}) + (parent: v{item.parentVersion})

)}
diff --git a/frontend/src/components/history/history_types.ts b/frontend/src/components/history/history_types.ts deleted file mode 100644 index 19f2f9c..0000000 --- a/frontend/src/components/history/history_types.ts +++ /dev/null @@ -1,45 +0,0 @@ -export type HistoryItemType = "ai_create" | "ai_edit" | "code_create"; - -type CommonHistoryItem = { - parentIndex: null | number; - code: string; -}; - -export type HistoryItem = - | ({ - type: "ai_create"; - inputs: AiCreateInputs | AiCreateInputsText; - } & CommonHistoryItem) - | ({ - type: "ai_edit"; - inputs: AiEditInputs; - } & CommonHistoryItem) - | ({ - type: "code_create"; - inputs: CodeCreateInputs; - } & CommonHistoryItem); - -export type AiCreateInputs = { - image_url: string; -}; - -export type AiCreateInputsText = { - text: string; -}; - -export type AiEditInputs = { - prompt: string; -}; - -export type CodeCreateInputs = { - code: string; -}; - -export type History = HistoryItem[]; - -export type RenderedHistoryItem = { - type: string; - summary: string; - parentVersion: string | null; - isActive: boolean; -}; diff --git a/frontend/src/components/history/utils.test.ts b/frontend/src/components/history/utils.test.ts index e321bdc..7a5aaa2 100644 --- a/frontend/src/components/history/utils.test.ts +++ b/frontend/src/components/history/utils.test.ts @@ -1,91 +1,125 @@ -import { extractHistoryTree, renderHistory } from "./utils"; -import type { History } from "./history_types"; +import { extractHistory, renderHistory } from "./utils"; +import { Commit, CommitHash } from "../commits/types"; -const basicLinearHistory: History = [ - { +const basicLinearHistory: Record = { + "0": { + hash: "0", + dateCreated: new Date(), + isCommitted: false, type: "ai_create", - parentIndex: null, - code: "1. create", + parentHash: null, + variants: [{ code: "1. create" }], + selectedVariantIndex: 0, inputs: { image_url: "", }, }, - { + "1": { + hash: "1", + dateCreated: new Date(), + isCommitted: false, type: "ai_edit", - parentIndex: 0, - code: "2. edit with better icons", + parentHash: "0", + variants: [{ code: "2. edit with better icons" }], + selectedVariantIndex: 0, inputs: { prompt: "use better icons", }, }, - { + "2": { + hash: "2", + dateCreated: new Date(), + isCommitted: false, type: "ai_edit", - parentIndex: 1, - code: "3. edit with better icons and red text", + parentHash: "1", + variants: [{ code: "3. edit with better icons and red text" }], + selectedVariantIndex: 0, inputs: { prompt: "make text red", }, }, -]; +}; -const basicLinearHistoryWithCode: History = [ - { +const basicLinearHistoryWithCode: Record = { + "0": { + hash: "0", + dateCreated: new Date(), + isCommitted: false, type: "code_create", - parentIndex: null, - code: "1. create", - inputs: { - code: "1. create", - }, + parentHash: null, + variants: [{ code: "1. create" }], + selectedVariantIndex: 0, + inputs: null, }, - ...basicLinearHistory.slice(1), -]; + ...Object.fromEntries(Object.entries(basicLinearHistory).slice(1)), +}; -const basicBranchingHistory: History = [ +const basicBranchingHistory: Record = { ...basicLinearHistory, - { + "3": { + hash: "3", + dateCreated: new Date(), + isCommitted: false, type: "ai_edit", - parentIndex: 1, - code: "4. edit with better icons and green text", + parentHash: "1", + variants: [ + { code: "4. edit with better icons and green text" }, + ], + selectedVariantIndex: 0, inputs: { prompt: "make text green", }, }, -]; +}; -const longerBranchingHistory: History = [ +const longerBranchingHistory: Record = { ...basicBranchingHistory, - { + "4": { + hash: "4", + dateCreated: new Date(), + isCommitted: false, type: "ai_edit", - parentIndex: 3, - code: "5. edit with better icons and green, bold text", + parentHash: "3", + variants: [ + { code: "5. edit with better icons and green, bold text" }, + ], + selectedVariantIndex: 0, inputs: { prompt: "make text bold", }, }, -]; +}; -const basicBadHistory: History = [ - { +const basicBadHistory: Record = { + "0": { + hash: "0", + dateCreated: new Date(), + isCommitted: false, type: "ai_create", - parentIndex: null, - code: "1. create", + parentHash: null, + variants: [{ code: "1. create" }], + selectedVariantIndex: 0, inputs: { image_url: "", }, }, - { + "1": { + hash: "1", + dateCreated: new Date(), + isCommitted: false, type: "ai_edit", - parentIndex: 2, // <- Bad parent index - code: "2. edit with better icons", + parentHash: "2", // <- Bad parent hash + variants: [{ code: "2. edit with better icons" }], + selectedVariantIndex: 0, inputs: { prompt: "use better icons", }, }, -]; +}; describe("History Utils", () => { test("should correctly extract the history tree", () => { - expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([ + expect(extractHistory("2", basicLinearHistory)).toEqual([ "1. create", "use better icons", "2. edit with better icons", @@ -93,12 +127,12 @@ describe("History Utils", () => { "3. edit with better icons and red text", ]); - expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([ + expect(extractHistory("0", basicLinearHistory)).toEqual([ "1. create", ]); // Test branching - expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([ + expect(extractHistory("3", basicBranchingHistory)).toEqual([ "1. create", "use better icons", "2. edit with better icons", @@ -106,7 +140,7 @@ describe("History Utils", () => { "4. edit with better icons and green text", ]); - expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([ + expect(extractHistory("4", longerBranchingHistory)).toEqual([ "1. create", "use better icons", "2. edit with better icons", @@ -116,7 +150,7 @@ describe("History Utils", () => { "5. edit with better icons and green, bold text", ]); - expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([ + expect(extractHistory("2", longerBranchingHistory)).toEqual([ "1. create", "use better icons", "2. edit with better icons", @@ -126,105 +160,82 @@ describe("History Utils", () => { // Errors - // Bad index - expect(() => extractHistoryTree(basicLinearHistory, 100)).toThrow(); - expect(() => extractHistoryTree(basicLinearHistory, -2)).toThrow(); + // Bad hash + expect(() => extractHistory("100", basicLinearHistory)).toThrow(); // Bad tree - expect(() => extractHistoryTree(basicBadHistory, 1)).toThrow(); + expect(() => extractHistory("1", basicBadHistory)).toThrow(); }); test("should correctly render the history tree", () => { - expect(renderHistory(basicLinearHistory, 2)).toEqual([ + expect(renderHistory(Object.values(basicLinearHistory))).toEqual([ { - isActive: false, - parentVersion: null, - summary: "Create", + ...basicLinearHistory["0"], type: "Create", - }, - { - isActive: false, - parentVersion: null, - summary: "use better icons", - type: "Edit", - }, - { - isActive: true, - parentVersion: null, - summary: "make text red", - type: "Edit", - }, - ]); - - // Current version is the first version - expect(renderHistory(basicLinearHistory, 0)).toEqual([ - { - isActive: true, - parentVersion: null, summary: "Create", - type: "Create", + parentVersion: null, }, { - isActive: false, - parentVersion: null, + ...basicLinearHistory["1"], + type: "Edit", summary: "use better icons", - type: "Edit", + parentVersion: null, }, { - isActive: false, - parentVersion: null, - summary: "make text red", + ...basicLinearHistory["2"], type: "Edit", + summary: "make text red", + parentVersion: null, }, ]); // Render a history with code - expect(renderHistory(basicLinearHistoryWithCode, 0)).toEqual([ + expect(renderHistory(Object.values(basicLinearHistoryWithCode))).toEqual([ { - isActive: true, - parentVersion: null, - summary: "Imported from code", + ...basicLinearHistoryWithCode["0"], type: "Imported from code", + summary: "Imported from code", + parentVersion: null, }, { - isActive: false, - parentVersion: null, + ...basicLinearHistoryWithCode["1"], + type: "Edit", summary: "use better icons", - type: "Edit", + parentVersion: null, }, { - isActive: false, - parentVersion: null, - summary: "make text red", + ...basicLinearHistoryWithCode["2"], type: "Edit", + summary: "make text red", + parentVersion: null, }, ]); // Render a non-linear history - expect(renderHistory(basicBranchingHistory, 3)).toEqual([ + expect(renderHistory(Object.values(basicBranchingHistory))).toEqual([ { - isActive: false, - parentVersion: null, - summary: "Create", + ...basicBranchingHistory["0"], type: "Create", + summary: "Create", + parentVersion: null, }, { - isActive: false, - parentVersion: null, + ...basicBranchingHistory["1"], + type: "Edit", summary: "use better icons", - type: "Edit", - }, - { - isActive: false, parentVersion: null, - summary: "make text red", - type: "Edit", }, { - isActive: true, - parentVersion: "v2", - summary: "make text green", + ...basicBranchingHistory["2"], type: "Edit", + summary: "make text red", + parentVersion: null, + }, + { + ...basicBranchingHistory["3"], + type: "Edit", + summary: "make text green", + parentVersion: 2, }, ]); }); diff --git a/frontend/src/components/history/utils.ts b/frontend/src/components/history/utils.ts index 785c20b..16fee84 100644 --- a/frontend/src/components/history/utils.ts +++ b/frontend/src/components/history/utils.ts @@ -1,33 +1,25 @@ -import { - History, - HistoryItem, - HistoryItemType, - RenderedHistoryItem, -} from "./history_types"; +import { Commit, CommitHash, CommitType } from "../commits/types"; -export function extractHistoryTree( - history: History, - version: number +export function extractHistory( + hash: CommitHash, + commits: Record ): string[] { const flatHistory: string[] = []; - let currentIndex: number | null = version; - while (currentIndex !== null) { - const item: HistoryItem = history[currentIndex]; + let currentCommitHash: CommitHash | null = hash; + while (currentCommitHash !== null) { + const commit: Commit | null = commits[currentCommitHash]; - if (item) { - if (item.type === "ai_create") { - // Don't include the image for ai_create - flatHistory.unshift(item.code); - } else if (item.type === "ai_edit") { - flatHistory.unshift(item.code); - flatHistory.unshift(item.inputs.prompt); - } else if (item.type === "code_create") { - flatHistory.unshift(item.code); + if (commit) { + flatHistory.unshift(commit.variants[commit.selectedVariantIndex].code); + + // For edits, add the prompt to the history + if (commit.type === "ai_edit") { + flatHistory.unshift(commit.inputs.prompt); } // Move to the parent of the current item - currentIndex = item.parentIndex; + currentCommitHash = commit.parentHash; } else { throw new Error("Malformed history: missing parent index"); } @@ -36,7 +28,7 @@ export function extractHistoryTree( return flatHistory; } -function displayHistoryItemType(itemType: HistoryItemType) { +function displayHistoryItemType(itemType: CommitType) { switch (itemType) { case "ai_create": return "Create"; @@ -51,44 +43,48 @@ function displayHistoryItemType(itemType: HistoryItemType) { } } -function summarizeHistoryItem(item: HistoryItem) { - const itemType = item.type; - switch (itemType) { +const setParentVersion = (commit: Commit, history: Commit[]) => { + // If the commit has no parent, return null + if (!commit.parentHash) return null; + + const parentIndex = history.findIndex( + (item) => item.hash === commit.parentHash + ); + const currentIndex = history.findIndex((item) => item.hash === commit.hash); + + // Only set parent version if the parent is not the previous commit + // and parent exists + return parentIndex !== -1 && parentIndex != currentIndex - 1 + ? parentIndex + 1 + : null; +}; + +export function summarizeHistoryItem(commit: Commit) { + const commitType = commit.type; + switch (commitType) { case "ai_create": return "Create"; case "ai_edit": - return item.inputs.prompt; + return commit.inputs.prompt; case "code_create": return "Imported from code"; default: { - const exhaustiveCheck: never = itemType; + const exhaustiveCheck: never = commitType; throw new Error(`Unhandled case: ${exhaustiveCheck}`); } } } -export const renderHistory = ( - history: History, - currentVersion: number | null -) => { - const renderedHistory: RenderedHistoryItem[] = []; +export const renderHistory = (history: Commit[]) => { + const renderedHistory = []; for (let i = 0; i < history.length; i++) { - const item = history[i]; - // Only show the parent version if it's not the previous version - // (i.e. it's the branching point) and if it's not the first version - const parentVersion = - item.parentIndex !== null && item.parentIndex !== i - 1 - ? `v${(item.parentIndex || 0) + 1}` - : null; - const type = displayHistoryItemType(item.type); - const isActive = i === currentVersion; - const summary = summarizeHistoryItem(item); + const commit = history[i]; renderedHistory.push({ - isActive, - summary: summary, - parentVersion, - type, + ...commit, + type: displayHistoryItemType(commit.type), + summary: summarizeHistoryItem(commit), + parentVersion: setParentVersion(commit, history), }); } diff --git a/frontend/src/components/preview/PreviewPane.tsx b/frontend/src/components/preview/PreviewPane.tsx index bc4c1cf..cc58008 100644 --- a/frontend/src/components/preview/PreviewPane.tsx +++ b/frontend/src/components/preview/PreviewPane.tsx @@ -23,12 +23,17 @@ interface Props { function PreviewPane({ doUpdate, reset, settings }: Props) { const { appState } = useAppStore(); - const { inputMode, generatedCode, setGeneratedCode } = useProjectStore(); + const { inputMode, head, commits } = useProjectStore(); + + const currentCommit = head && commits[head] ? commits[head] : ""; + const currentCode = currentCommit + ? currentCommit.variants[currentCommit.selectedVariantIndex].code + : ""; const previewCode = inputMode === "video" && appState === AppState.CODING - ? extractHtml(generatedCode) - : generatedCode; + ? extractHtml(currentCode) + : currentCode; return (
@@ -45,7 +50,7 @@ function PreviewPane({ doUpdate, reset, settings }: Props) { Reset
diff --git a/frontend/src/components/sidebar/Sidebar.tsx b/frontend/src/components/sidebar/Sidebar.tsx index 5246637..913df1f 100644 --- a/frontend/src/components/sidebar/Sidebar.tsx +++ b/frontend/src/components/sidebar/Sidebar.tsx @@ -12,6 +12,7 @@ import { Button } from "../ui/button"; import { Textarea } from "../ui/textarea"; import { useEffect, useRef } from "react"; import HistoryDisplay from "../history/HistoryDisplay"; +import Variants from "../variants/Variants"; interface SidebarProps { showSelectAndEditFeature: boolean; @@ -35,9 +36,18 @@ function Sidebar({ shouldIncludeResultImage, setShouldIncludeResultImage, } = useAppStore(); - const { inputMode, generatedCode, referenceImages, executionConsole } = + + const { inputMode, referenceImages, executionConsoles, head, commits } = useProjectStore(); + const viewedCode = + head && commits[head] + ? commits[head].variants[commits[head].selectedVariantIndex].code + : ""; + + const executionConsole = + (head && executionConsoles[commits[head].selectedVariantIndex]) || []; + // When coding is complete, focus on the update instruction textarea useEffect(() => { if (appState === AppState.CODE_READY && textareaRef.current) { @@ -47,6 +57,8 @@ function Sidebar({ return ( <> + + {/* Show code preview only when coding */} {appState === AppState.CODING && (
@@ -66,7 +78,7 @@ function Sidebar({ {executionConsole.slice(-1)[0]}
- +