diff --git a/README.md b/README.md index ad42423..bc82497 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # screenshot-to-code -This simple app converts a screenshot to code (HTML/Tailwind CSS, or React or Bootstrap). It uses GPT-4 Vision to generate the code and DALL-E 3 to generate similar-looking images. You can now also enter a URL to clone a live website! +This simple app converts a screenshot to code (HTML/Tailwind CSS, or React or Bootstrap or Vue). It uses GPT-4 Vision to generate the code and DALL-E 3 to generate similar-looking images. You can now also enter a URL to clone a live website! https://github.com/abi/screenshot-to-code/assets/23818/6cebadae-2fe3-4986-ac6a-8fb9db030045 @@ -38,6 +38,12 @@ poetry shell poetry run uvicorn main:app --reload --port 7001 ``` +You can also run the backend (when you're in `backend`): + +```bash +poetry run pyright +``` + Run the frontend: ```bash @@ -58,7 +64,7 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001 ## Configuration -* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog +- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog ## Docker diff --git a/backend/.gitignore b/backend/.gitignore index 128eab6..a42aad3 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -153,4 +153,4 @@ cython_debug/ # Temporary eval output -evals +evals_data diff --git a/backend/README.md b/backend/README.md index ee55816..155bf46 100644 --- a/backend/README.md +++ b/backend/README.md @@ -1,3 +1,7 @@ +# Run the type checker + +poetry run pyright + # Run tests poetry run pytest diff --git a/backend/eval_config.py b/backend/eval_config.py deleted file mode 100644 index 62a3b8f..0000000 --- a/backend/eval_config.py +++ /dev/null @@ -1 +0,0 @@ -EVALS_DIR = "./evals" diff --git a/backend/evals/__init__.py b/backend/evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/evals/config.py b/backend/evals/config.py new file mode 100644 index 0000000..7643027 --- /dev/null +++ b/backend/evals/config.py @@ -0,0 +1 @@ +EVALS_DIR = "./evals_data" diff --git a/backend/evals/core.py b/backend/evals/core.py new file mode 100644 index 0000000..61db1a3 --- /dev/null +++ b/backend/evals/core.py @@ -0,0 +1,29 @@ +import os + +from llm import stream_openai_response +from prompts import assemble_prompt +from prompts.types import Stack +from utils import pprint_prompt + + +async def generate_code_core(image_url: str, stack: Stack) -> str: + prompt_messages = assemble_prompt(image_url, stack) + openai_api_key = os.environ.get("OPENAI_API_KEY") + openai_base_url = None + + pprint_prompt(prompt_messages) + + async def process_chunk(content: str): + pass + + if not openai_api_key: + raise Exception("OpenAI API key not found") + + completion = await stream_openai_response( + prompt_messages, + api_key=openai_api_key, + base_url=openai_base_url, + callback=lambda x: process_chunk(x), + ) + + return completion diff --git a/backend/eval_utils.py b/backend/evals/utils.py similarity index 100% rename from backend/eval_utils.py rename to backend/evals/utils.py diff --git a/backend/poetry.lock b/backend/poetry.lock index 2ff41d4..0ac5a9a 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -213,16 +213,31 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "nodeenv" +version = "1.8.0" +description = "Node.js virtual environment builder" +category = "dev" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +files = [ + {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, + {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, +] + +[package.dependencies] +setuptools = "*" + [[package]] name = "openai" -version = "1.6.1" +version = "1.7.0" description = "The official Python library for the openai API" category = "main" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.6.1-py3-none-any.whl", hash = "sha256:bc9f774838d67ac29fb24cdeb2d58faf57de8b311085dcd1348f7aa02a96c7ee"}, - {file = "openai-1.6.1.tar.gz", hash = "sha256:d553ca9dbf9486b08e75b09e8671e4f638462aaadccfced632bf490fc3d75fa2"}, + {file = "openai-1.7.0-py3-none-any.whl", hash = "sha256:2282e8e15acb05df79cccba330c025b8e84284c7ec1f3fa31f167a8479066333"}, + {file = "openai-1.7.0.tar.gz", hash = "sha256:f2a8dcb739e8620c9318a2c6304ea72aebb572ba02fa1d586344405e80d567d3"}, ] [package.dependencies] @@ -318,6 +333,25 @@ typing-extensions = ">=4.2.0" dotenv = ["python-dotenv (>=0.10.4)"] email = ["email-validator (>=1.0.3)"] +[[package]] +name = "pyright" +version = "1.1.345" +description = "Command line wrapper for pyright" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.345-py3-none-any.whl", hash = "sha256:00891361baf58698aa660d9374823d65782823ceb4a65515ff5dd159b0d4d2b1"}, + {file = "pyright-1.1.345.tar.gz", hash = "sha256:bb8c80671cdaeb913142b49642a741959f3fcd728c99814631c2bde3a7864938"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" + +[package.extras] +all = ["twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] + [[package]] name = "pytest" version = "7.4.4" @@ -403,6 +437,23 @@ starlette = ["starlette (>=0.19.1)"] starlite = ["starlite (>=1.48)"] tornado = ["tornado (>=5)"] +[[package]] +name = "setuptools" +version = "69.0.3" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "setuptools-69.0.3-py3-none-any.whl", hash = "sha256:385eb4edd9c9d5c17540511303e39a147ce2fc04bc55289c322b9e5904fe2c05"}, + {file = "setuptools-69.0.3.tar.gz", hash = "sha256:be1af57fc409f93647f2e8e4573a142ed38724b8cdd389706a867bb4efcf1e78"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + [[package]] name = "sniffio" version = "1.3.0" @@ -612,4 +663,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ce29c56f8cb6ba1d0480489da7ab123aa620a146d464b2904cef8b7bcef82a05" +content-hash = "d69f678c50d40b06ddb2080231097b553fc7b1930dc69b3b07e18ca31ba010cf" diff --git a/backend/prompts.py b/backend/prompts.py deleted file mode 100644 index d3d3b18..0000000 --- a/backend/prompts.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import List, Union - -from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam - -from imported_code_prompts import ( - IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT, - IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT, - IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT, - IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT, - IMPORTED_CODE_SVG_SYSTEM_PROMPT, -) -from screenshot_system_prompts import ( - BOOTSTRAP_SYSTEM_PROMPT, - IONIC_TAILWIND_SYSTEM_PROMPT, - REACT_TAILWIND_SYSTEM_PROMPT, - TAILWIND_SYSTEM_PROMPT, - SVG_SYSTEM_PROMPT, - VUE_TAILWIND_SYSTEM_PROMPT, -) - - -USER_PROMPT = """ -Generate code for a web page that looks exactly like this. -""" - -SVG_USER_PROMPT = """ -Generate code for a SVG that looks exactly like this. -""" - - -def assemble_imported_code_prompt( - code: str, stack: str, result_image_data_url: Union[str, None] = None -) -> List[ChatCompletionMessageParam]: - system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT - if stack == "html_tailwind": - system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT - elif stack == "react_tailwind": - system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT - elif stack == "bootstrap": - system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT - elif stack == "ionic_tailwind": - system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT - elif stack == "svg": - system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT - else: - raise Exception("Code config is not one of available options") - - user_content = ( - "Here is the code of the app: " + code - if stack != "svg" - else "Here is the code of the SVG: " + code - ) - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] - # TODO: Use result_image_data_url - - -def assemble_prompt( - image_data_url: str, - generated_code_config: str, - result_image_data_url: Union[str, None] = None, -) -> List[ChatCompletionMessageParam]: - # Set the system prompt based on the output settings - system_content = TAILWIND_SYSTEM_PROMPT - if generated_code_config == "html_tailwind": - system_content = TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "react_tailwind": - system_content = REACT_TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "bootstrap": - system_content = BOOTSTRAP_SYSTEM_PROMPT - elif generated_code_config == "ionic_tailwind": - system_content = IONIC_TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "vue_tailwind": - system_content = VUE_TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "svg": - system_content = SVG_SYSTEM_PROMPT - else: - raise Exception("Code config is not one of available options") - - user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT - - user_content: List[ChatCompletionContentPartParam] = [ - { - "type": "image_url", - "image_url": {"url": image_data_url, "detail": "high"}, - }, - { - "type": "text", - "text": user_prompt, - }, - ] - - # Include the result image if it exists - if result_image_data_url: - user_content.insert( - 1, - { - "type": "image_url", - "image_url": {"url": result_image_data_url, "detail": "high"}, - }, - ) - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py new file mode 100644 index 0000000..4f2e329 --- /dev/null +++ b/backend/prompts/__init__.py @@ -0,0 +1,79 @@ +from typing import List, NoReturn, Union + +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam + +from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS +from prompts.screenshot_system_prompts import SYSTEM_PROMPTS +from prompts.types import Stack + + +USER_PROMPT = """ +Generate code for a web page that looks exactly like this. +""" + +SVG_USER_PROMPT = """ +Generate code for a SVG that looks exactly like this. +""" + + +def assemble_imported_code_prompt( + code: str, stack: Stack, result_image_data_url: Union[str, None] = None +) -> List[ChatCompletionMessageParam]: + system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] + + user_content = ( + "Here is the code of the app: " + code + if stack != "svg" + else "Here is the code of the SVG: " + code + ) + return [ + { + "role": "system", + "content": system_content, + }, + { + "role": "user", + "content": user_content, + }, + ] + # TODO: Use result_image_data_url + + +def assemble_prompt( + image_data_url: str, + stack: Stack, + result_image_data_url: Union[str, None] = None, +) -> List[ChatCompletionMessageParam]: + system_content = SYSTEM_PROMPTS[stack] + user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT + + user_content: List[ChatCompletionContentPartParam] = [ + { + "type": "image_url", + "image_url": {"url": image_data_url, "detail": "high"}, + }, + { + "type": "text", + "text": user_prompt, + }, + ] + + # Include the result image if it exists + if result_image_data_url: + user_content.insert( + 1, + { + "type": "image_url", + "image_url": {"url": result_image_data_url, "detail": "high"}, + }, + ) + return [ + { + "role": "system", + "content": system_content, + }, + { + "role": "user", + "content": user_content, + }, + ] diff --git a/backend/imported_code_prompts.py b/backend/prompts/imported_code_prompts.py similarity index 75% rename from backend/imported_code_prompts.py rename to backend/prompts/imported_code_prompts.py index a8bfa6a..8babf78 100644 --- a/backend/imported_code_prompts.py +++ b/backend/prompts/imported_code_prompts.py @@ -1,3 +1,6 @@ +from prompts.types import SystemPrompts + + IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer. @@ -79,6 +82,38 @@ Return only the full code in tags. Do not include markdown "```" or "```html" at the start or end. """ +IMPORTED_CODE_VUE_TAILWIND_SYSTEM_PROMPT = """ +You are an expert Vue/Tailwind developer. + +- Do not add comments in the code such as "" and "" in place of writing the full code. WRITE THE FULL CODE. +- Repeat elements as needed. For example, if there are 15 items, the code should have 15 items. DO NOT LEAVE comments like "" or bad things will happen. +- For images, use placeholder images from https://placehold.co and include a detailed description of the image in the alt text so that an image generation AI can generate the image later. + +In terms of libraries, + +- Use these script to include Vue so that it can run on a standalone page: + +- Use Vue using the global build like so: +
{{ message }}
+ +- Use this script to include Tailwind: +- You can use Google Fonts +- Font Awesome for icons: + +Return only the full code in tags. +Do not include markdown "```" or "```html" at the start or end. +The return result must only include the code.""" + IMPORTED_CODE_SVG_SYSTEM_PROMPT = """ You are an expert at building SVGs. @@ -90,3 +125,12 @@ You are an expert at building SVGs. Return only the full code in tags. Do not include markdown "```" or "```svg" at the start or end. """ + +IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts( + html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT, + react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT, + bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT, + ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT, + vue_tailwind=IMPORTED_CODE_VUE_TAILWIND_SYSTEM_PROMPT, + svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT, +) diff --git a/backend/screenshot_system_prompts.py b/backend/prompts/screenshot_system_prompts.py similarity index 96% rename from backend/screenshot_system_prompts.py rename to backend/prompts/screenshot_system_prompts.py index 1dfcca4..fca91ba 100644 --- a/backend/screenshot_system_prompts.py +++ b/backend/prompts/screenshot_system_prompts.py @@ -1,4 +1,7 @@ -TAILWIND_SYSTEM_PROMPT = """ +from prompts.types import SystemPrompts + + +HTML_TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer You take screenshots of a reference web page from the user, and then build single page apps using Tailwind, HTML and JS. @@ -170,3 +173,13 @@ padding, margin, border, etc. Match the colors and sizes exactly. Return only the full code in tags. Do not include markdown "```" or "```svg" at the start or end. """ + + +SYSTEM_PROMPTS = SystemPrompts( + html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT, + react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT, + bootstrap=BOOTSTRAP_SYSTEM_PROMPT, + ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT, + vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT, + svg=SVG_SYSTEM_PROMPT, +) diff --git a/backend/test_prompts.py b/backend/prompts/test_prompts.py similarity index 90% rename from backend/test_prompts.py rename to backend/prompts/test_prompts.py index e12a207..9e60410 100644 --- a/backend/test_prompts.py +++ b/backend/prompts/test_prompts.py @@ -253,6 +253,39 @@ Return only the full code in tags. Do not include markdown "```" or "```html" at the start or end. """ + +IMPORTED_CODE_VUE_TAILWIND_PROMPT = """ +You are an expert Vue/Tailwind developer. + +- Do not add comments in the code such as "" and "" in place of writing the full code. WRITE THE FULL CODE. +- Repeat elements as needed. For example, if there are 15 items, the code should have 15 items. DO NOT LEAVE comments like "" or bad things will happen. +- For images, use placeholder images from https://placehold.co and include a detailed description of the image in the alt text so that an image generation AI can generate the image later. + +In terms of libraries, + +- Use these script to include Vue so that it can run on a standalone page: + +- Use Vue using the global build like so: +
{{ message }}
+ +- Use this script to include Tailwind: +- You can use Google Fonts +- Font Awesome for icons: + +Return only the full code in tags. +Do not include markdown "```" or "```html" at the start or end. +The return result must only include the code.""" + IMPORTED_CODE_SVG_SYSTEM_PROMPT = """ You are an expert at building SVGs. @@ -299,9 +332,11 @@ def test_prompts(): assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore - vue = assemble_prompt("image_data_url", "vue_tailwind", "result_image_data_url") - assert vue[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT - assert vue[1]["content"][2]["text"] == USER_PROMPT # type: ignore + vue_tailwind = assemble_prompt( + "image_data_url", "vue_tailwind", "result_image_data_url" + ) + assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT + assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url") assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT @@ -345,6 +380,15 @@ def test_imported_code_prompts(): ] assert ionic_tailwind == expected_ionic_tailwind + vue_tailwind = assemble_imported_code_prompt( + "code", "vue_tailwind", "result_image_data_url" + ) + expected_vue_tailwind = [ + {"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT}, + {"role": "user", "content": "Here is the code of the app: code"}, + ] + assert vue_tailwind == expected_vue_tailwind + svg = assemble_imported_code_prompt("code", "svg", "result_image_data_url") expected_svg = [ {"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT}, diff --git a/backend/prompts/types.py b/backend/prompts/types.py new file mode 100644 index 0000000..9068443 --- /dev/null +++ b/backend/prompts/types.py @@ -0,0 +1,20 @@ +from typing import Literal, TypedDict + + +class SystemPrompts(TypedDict): + html_tailwind: str + react_tailwind: str + bootstrap: str + ionic_tailwind: str + vue_tailwind: str + svg: str + + +Stack = Literal[ + "html_tailwind", + "react_tailwind", + "bootstrap", + "ionic_tailwind", + "vue_tailwind", + "svg", +] diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a99b980..aefad67 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,6 +18,7 @@ sentry-sdk = {extras = ["fastapi"], version = "^1.38.0"} [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" +pyright = "^1.1.345" [build-system] requires = ["poetry-core"] diff --git a/backend/pyrightconfig.json b/backend/pyrightconfig.json new file mode 100644 index 0000000..6e475af --- /dev/null +++ b/backend/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "exclude": ["image_generation.py"] +} diff --git a/backend/routes/evals.py b/backend/routes/evals.py index 48a3a95..798a9d8 100644 --- a/backend/routes/evals.py +++ b/backend/routes/evals.py @@ -1,8 +1,8 @@ import os from fastapi import APIRouter from pydantic import BaseModel -from eval_utils import image_to_data_url -from eval_config import EVALS_DIR +from evals.utils import image_to_data_url +from evals.config import EVALS_DIR router = APIRouter() diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 580df4d..349afbc 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -6,7 +6,7 @@ from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE from llm import stream_openai_response from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Dict, List +from typing import Dict, List, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from access_token import validate_access_token @@ -14,6 +14,7 @@ from datetime import datetime import json from routes.logging_utils import PaymentMethod, send_to_saas_backend from routes.saas_utils import does_user_have_subscription_credits +from prompts.types import Stack from utils import pprint_prompt # type: ignore @@ -124,6 +125,13 @@ async def stream_code(websocket: WebSocket): ) return + # Validate the generated code config + if not generated_code_config in get_args(Stack): + await throw_error(f"Invalid generated code config: {generated_code_config}") + return + # Cast the variable to the Stack type + valid_stack = cast(Stack, generated_code_config) + # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. openai_base_url = None # Disable user-specified OpenAI Base URL in prod @@ -159,7 +167,7 @@ async def stream_code(websocket: WebSocket): if params.get("isImportedFromCode") and params["isImportedFromCode"]: original_imported_code = params["history"][0] prompt_messages = assemble_imported_code_prompt( - original_imported_code, generated_code_config + original_imported_code, valid_stack ) for index, text in enumerate(params["history"][1:]): if index % 2 == 0: @@ -178,12 +186,10 @@ async def stream_code(websocket: WebSocket): try: if params.get("resultImage") and params["resultImage"]: prompt_messages = assemble_prompt( - params["image"], generated_code_config, params["resultImage"] + params["image"], valid_stack, params["resultImage"] ) else: - prompt_messages = assemble_prompt( - params["image"], generated_code_config - ) + prompt_messages = assemble_prompt(params["image"], valid_stack) except: await websocket.send_json( { diff --git a/backend/eval.py b/backend/run_evals.py similarity index 58% rename from backend/eval.py rename to backend/run_evals.py index 60ef409..ddb7eaa 100644 --- a/backend/eval.py +++ b/backend/run_evals.py @@ -1,41 +1,15 @@ # Load environment variables first -from typing import Any, Coroutine from dotenv import load_dotenv -from eval_config import EVALS_DIR -from eval_utils import image_to_data_url - load_dotenv() import os -from llm import stream_openai_response -from prompts import assemble_prompt +from typing import Any, Coroutine import asyncio -from utils import pprint_prompt - - -async def generate_code_core(image_url: str, stack: str) -> str: - prompt_messages = assemble_prompt(image_url, stack) - openai_api_key = os.environ.get("OPENAI_API_KEY") - openai_base_url = None - - pprint_prompt(prompt_messages) - - async def process_chunk(content: str): - pass - - if not openai_api_key: - raise Exception("OpenAI API key not found") - - completion = await stream_openai_response( - prompt_messages, - api_key=openai_api_key, - base_url=openai_base_url, - callback=lambda x: process_chunk(x), - ) - - return completion +from evals.config import EVALS_DIR +from evals.core import generate_code_core +from evals.utils import image_to_data_url async def main(): diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 6b6ff67..dc5a100 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -17,13 +17,7 @@ import { Button } from "@/components/ui/button"; import { Textarea } from "@/components/ui/textarea"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "./components/ui/tabs"; import SettingsDialog from "./components/SettingsDialog"; -import { - AppState, - CodeGenerationParams, - EditorTheme, - GeneratedCodeConfig, - Settings, -} from "./types"; +import { AppState, CodeGenerationParams, EditorTheme, Settings } from "./types"; import { IS_RUNNING_ON_CLOUD } from "./config"; import { PicoBadge } from "./components/PicoBadge"; import { OnboardingNote } from "./components/OnboardingNote"; @@ -42,6 +36,7 @@ import toast from "react-hot-toast"; import ImportCodeSection from "./components/ImportCodeSection"; import { useAuth } from "@clerk/clerk-react"; import { useStore } from "./store/store"; +import { Stack } from "./lib/stacks/types"; const IS_OPENAI_DOWN = false; @@ -71,7 +66,7 @@ function App({ navbarComponent }: Props) { screenshotOneApiKey: null, isImageGenerationEnabled: true, editorTheme: EditorTheme.COBALT, - generatedCodeConfig: GeneratedCodeConfig.HTML_TAILWIND, + generatedCodeConfig: Stack.HTML_TAILWIND, // Only relevant for hosted version isTermOfServiceAccepted: true, accessCode: null, @@ -96,7 +91,7 @@ function App({ navbarComponent }: Props) { if (!settings.generatedCodeConfig) { setSettings((prev) => ({ ...prev, - generatedCodeConfig: GeneratedCodeConfig.HTML_TAILWIND, + generatedCodeConfig: Stack.HTML_TAILWIND, })); } }, [settings.generatedCodeConfig, setSettings]); @@ -310,15 +305,14 @@ function App({ navbarComponent }: Props) { })); }; - // TODO: Rename everything to "stack" instead of "config" - function setStack(stack: GeneratedCodeConfig) { + function setStack(stack: Stack) { setSettings((prev) => ({ ...prev, generatedCodeConfig: stack, })); } - function importFromCode(code: string, stack: GeneratedCodeConfig) { + function importFromCode(code: string, stack: Stack) { setIsImportedFromCode(true); // Set up this project @@ -354,8 +348,8 @@ function App({ navbarComponent }: Props) { setStack(config)} + stack={settings.generatedCodeConfig} + setStack={(config) => setStack(config)} shouldDisableUpdates={ appState === AppState.CODING || appState === AppState.CODE_READY } diff --git a/frontend/src/components/ImportCodeSection.tsx b/frontend/src/components/ImportCodeSection.tsx index 04b2b5a..94d3cf1 100644 --- a/frontend/src/components/ImportCodeSection.tsx +++ b/frontend/src/components/ImportCodeSection.tsx @@ -11,18 +11,16 @@ import { } from "./ui/dialog"; import { Textarea } from "./ui/textarea"; import OutputSettingsSection from "./OutputSettingsSection"; -import { GeneratedCodeConfig } from "../types"; import toast from "react-hot-toast"; +import { Stack } from "../lib/stacks/types"; interface Props { - importFromCode: (code: string, stack: GeneratedCodeConfig) => void; + importFromCode: (code: string, stack: Stack) => void; } function ImportCodeSection({ importFromCode }: Props) { const [code, setCode] = useState(""); - const [stack, setStack] = useState( - undefined - ); + const [stack, setStack] = useState(undefined); const doImport = () => { if (code === "") { @@ -57,10 +55,8 @@ function ImportCodeSection({ importFromCode }: Props) { /> - setStack(config) - } + stack={stack} + setStack={(config: Stack) => setStack(config)} label="Stack:" shouldDisableUpdates={false} /> diff --git a/frontend/src/components/OutputSettingsSection.tsx b/frontend/src/components/OutputSettingsSection.tsx index 21b5c7b..b768b7a 100644 --- a/frontend/src/components/OutputSettingsSection.tsx +++ b/frontend/src/components/OutputSettingsSection.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { Select, SelectContent, @@ -5,69 +6,36 @@ import { SelectItem, SelectTrigger, } from "./ui/select"; -import { GeneratedCodeConfig } from "../types"; import { addEvent } from "../lib/analytics"; import { Badge } from "./ui/badge"; +import { Stack } from "../lib/stacks/types"; +import { STACK_DESCRIPTIONS } from "../lib/stacks/descriptions"; -function generateDisplayComponent(config: GeneratedCodeConfig) { - switch (config) { - case GeneratedCodeConfig.HTML_TAILWIND: - return ( -
- HTML +{" "} - Tailwind -
- ); - case GeneratedCodeConfig.REACT_TAILWIND: - return ( -
- React +{" "} - Tailwind -
- ); - case GeneratedCodeConfig.BOOTSTRAP: - return ( -
- Bootstrap -
- ); - case GeneratedCodeConfig.IONIC_TAILWIND: - return ( -
- Ionic +{" "} - Tailwind -
- ); - case GeneratedCodeConfig.VUE_TAILWIND: - return ( -
- Vue +{" "} - Tailwind -
- ); - case GeneratedCodeConfig.SVG: - return ( -
- SVG -
- ); - default: { - const exhaustiveCheck: never = config; - throw new Error(`Unhandled case: ${exhaustiveCheck}`); - } - } +function generateDisplayComponent(stack: Stack) { + const stackComponents = STACK_DESCRIPTIONS[stack].components; + + return ( +
+ {stackComponents.map((component, index) => ( + + {component} + {index < stackComponents.length - 1 && " + "} + + ))} +
+ ); } interface Props { - generatedCodeConfig: GeneratedCodeConfig | undefined; - setGeneratedCodeConfig: (config: GeneratedCodeConfig) => void; + stack: Stack | undefined; + setStack: (config: Stack) => void; label?: string; shouldDisableUpdates?: boolean; } function OutputSettingsSection({ - generatedCodeConfig, - setGeneratedCodeConfig, + stack, + setStack, label = "Generating:", shouldDisableUpdates = false, }: Props) { @@ -76,53 +44,30 @@ function OutputSettingsSection({
{label} diff --git a/frontend/src/lib/stacks/descriptions.ts b/frontend/src/lib/stacks/descriptions.ts new file mode 100644 index 0000000..0f6cb31 --- /dev/null +++ b/frontend/src/lib/stacks/descriptions.ts @@ -0,0 +1,12 @@ +import { Stack } from "./types"; + +export const STACK_DESCRIPTIONS: { + [key in Stack]: { components: string[]; inBeta: boolean }; +} = { + html_tailwind: { components: ["HTML", "Tailwind"], inBeta: false }, + react_tailwind: { components: ["React", "Tailwind"], inBeta: false }, + bootstrap: { components: ["Bootstrap"], inBeta: false }, + vue_tailwind: { components: ["Vue", "Tailwind"], inBeta: true }, + ionic_tailwind: { components: ["Ionic", "Tailwind"], inBeta: true }, + svg: { components: ["SVG"], inBeta: true }, +}; diff --git a/frontend/src/lib/stacks/types.ts b/frontend/src/lib/stacks/types.ts new file mode 100644 index 0000000..7edb0b5 --- /dev/null +++ b/frontend/src/lib/stacks/types.ts @@ -0,0 +1,9 @@ +// Keep in sync with backend (prompts/types.py) +export enum Stack { + HTML_TAILWIND = "html_tailwind", + REACT_TAILWIND = "react_tailwind", + BOOTSTRAP = "bootstrap", + VUE_TAILWIND = "vue_tailwind", + IONIC_TAILWIND = "ionic_tailwind", + SVG = "svg", +} diff --git a/frontend/src/types.ts b/frontend/src/types.ts index a3b2f02..7e4eb23 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -1,25 +1,17 @@ +import { Stack } from "./lib/stacks/types"; + export enum EditorTheme { ESPRESSO = "espresso", COBALT = "cobalt", } -// Keep in sync with backend (prompts.py) -export enum GeneratedCodeConfig { - HTML_TAILWIND = "html_tailwind", - REACT_TAILWIND = "react_tailwind", - VUE_TAILWIND = "vue_tailwind", - BOOTSTRAP = "bootstrap", - IONIC_TAILWIND = "ionic_tailwind", - SVG = "svg", -} - export interface Settings { openAiApiKey: string | null; openAiBaseURL: string | null; screenshotOneApiKey: string | null; isImageGenerationEnabled: boolean; editorTheme: EditorTheme; - generatedCodeConfig: GeneratedCodeConfig; + generatedCodeConfig: Stack; // Only relevant for hosted version isTermOfServiceAccepted: boolean; accessCode: string | null;