Merge branch 'main' into hosted

This commit is contained in:
Abi Raja 2024-01-09 09:52:32 -08:00
commit a2c0ac1171
26 changed files with 396 additions and 293 deletions

View File

@ -1,6 +1,6 @@
# screenshot-to-code # screenshot-to-code
This simple app converts a screenshot to code (HTML/Tailwind CSS, or React or Bootstrap). It uses GPT-4 Vision to generate the code and DALL-E 3 to generate similar-looking images. You can now also enter a URL to clone a live website! This simple app converts a screenshot to code (HTML/Tailwind CSS, or React or Bootstrap or Vue). It uses GPT-4 Vision to generate the code and DALL-E 3 to generate similar-looking images. You can now also enter a URL to clone a live website!
https://github.com/abi/screenshot-to-code/assets/23818/6cebadae-2fe3-4986-ac6a-8fb9db030045 https://github.com/abi/screenshot-to-code/assets/23818/6cebadae-2fe3-4986-ac6a-8fb9db030045
@ -38,6 +38,12 @@ poetry shell
poetry run uvicorn main:app --reload --port 7001 poetry run uvicorn main:app --reload --port 7001
``` ```
You can also run the backend (when you're in `backend`):
```bash
poetry run pyright
```
Run the frontend: Run the frontend:
```bash ```bash
@ -58,7 +64,7 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
## Configuration ## Configuration
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog - You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
## Docker ## Docker

2
backend/.gitignore vendored
View File

@ -153,4 +153,4 @@ cython_debug/
# Temporary eval output # Temporary eval output
evals evals_data

View File

@ -1,3 +1,7 @@
# Run the type checker
poetry run pyright
# Run tests # Run tests
poetry run pytest poetry run pytest

View File

@ -1 +0,0 @@
EVALS_DIR = "./evals"

View File

1
backend/evals/config.py Normal file
View File

@ -0,0 +1 @@
EVALS_DIR = "./evals_data"

29
backend/evals/core.py Normal file
View File

@ -0,0 +1,29 @@
import os
from llm import stream_openai_response
from prompts import assemble_prompt
from prompts.types import Stack
from utils import pprint_prompt
async def generate_code_core(image_url: str, stack: Stack) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_base_url = None
pprint_prompt(prompt_messages)
async def process_chunk(content: str):
pass
if not openai_api_key:
raise Exception("OpenAI API key not found")
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
return completion

59
backend/poetry.lock generated
View File

@ -213,16 +213,31 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
] ]
[[package]]
name = "nodeenv"
version = "1.8.0"
description = "Node.js virtual environment builder"
category = "dev"
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
files = [
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
]
[package.dependencies]
setuptools = "*"
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.6.1" version = "1.7.0"
description = "The official Python library for the openai API" description = "The official Python library for the openai API"
category = "main" category = "main"
optional = false optional = false
python-versions = ">=3.7.1" python-versions = ">=3.7.1"
files = [ files = [
{file = "openai-1.6.1-py3-none-any.whl", hash = "sha256:bc9f774838d67ac29fb24cdeb2d58faf57de8b311085dcd1348f7aa02a96c7ee"}, {file = "openai-1.7.0-py3-none-any.whl", hash = "sha256:2282e8e15acb05df79cccba330c025b8e84284c7ec1f3fa31f167a8479066333"},
{file = "openai-1.6.1.tar.gz", hash = "sha256:d553ca9dbf9486b08e75b09e8671e4f638462aaadccfced632bf490fc3d75fa2"}, {file = "openai-1.7.0.tar.gz", hash = "sha256:f2a8dcb739e8620c9318a2c6304ea72aebb572ba02fa1d586344405e80d567d3"},
] ]
[package.dependencies] [package.dependencies]
@ -318,6 +333,25 @@ typing-extensions = ">=4.2.0"
dotenv = ["python-dotenv (>=0.10.4)"] dotenv = ["python-dotenv (>=0.10.4)"]
email = ["email-validator (>=1.0.3)"] email = ["email-validator (>=1.0.3)"]
[[package]]
name = "pyright"
version = "1.1.345"
description = "Command line wrapper for pyright"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.345-py3-none-any.whl", hash = "sha256:00891361baf58698aa660d9374823d65782823ceb4a65515ff5dd159b0d4d2b1"},
{file = "pyright-1.1.345.tar.gz", hash = "sha256:bb8c80671cdaeb913142b49642a741959f3fcd728c99814631c2bde3a7864938"},
]
[package.dependencies]
nodeenv = ">=1.6.0"
[package.extras]
all = ["twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "7.4.4" version = "7.4.4"
@ -403,6 +437,23 @@ starlette = ["starlette (>=0.19.1)"]
starlite = ["starlite (>=1.48)"] starlite = ["starlite (>=1.48)"]
tornado = ["tornado (>=5)"] tornado = ["tornado (>=5)"]
[[package]]
name = "setuptools"
version = "69.0.3"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
category = "dev"
optional = false
python-versions = ">=3.8"
files = [
{file = "setuptools-69.0.3-py3-none-any.whl", hash = "sha256:385eb4edd9c9d5c17540511303e39a147ce2fc04bc55289c322b9e5904fe2c05"},
{file = "setuptools-69.0.3.tar.gz", hash = "sha256:be1af57fc409f93647f2e8e4573a142ed38724b8cdd389706a867bb4efcf1e78"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]] [[package]]
name = "sniffio" name = "sniffio"
version = "1.3.0" version = "1.3.0"
@ -612,4 +663,4 @@ files = [
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "ce29c56f8cb6ba1d0480489da7ab123aa620a146d464b2904cef8b7bcef82a05" content-hash = "d69f678c50d40b06ddb2080231097b553fc7b1930dc69b3b07e18ca31ba010cf"

View File

@ -1,119 +0,0 @@
from typing import List, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
from imported_code_prompts import (
IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_SVG_SYSTEM_PROMPT,
)
from screenshot_system_prompts import (
BOOTSTRAP_SYSTEM_PROMPT,
IONIC_TAILWIND_SYSTEM_PROMPT,
REACT_TAILWIND_SYSTEM_PROMPT,
TAILWIND_SYSTEM_PROMPT,
SVG_SYSTEM_PROMPT,
VUE_TAILWIND_SYSTEM_PROMPT,
)
USER_PROMPT = """
Generate code for a web page that looks exactly like this.
"""
SVG_USER_PROMPT = """
Generate code for a SVG that looks exactly like this.
"""
def assemble_imported_code_prompt(
code: str, stack: str, result_image_data_url: Union[str, None] = None
) -> List[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
if stack == "html_tailwind":
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
elif stack == "react_tailwind":
system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
elif stack == "bootstrap":
system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
elif stack == "ionic_tailwind":
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "svg":
system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT
else:
raise Exception("Code config is not one of available options")
user_content = (
"Here is the code of the app: " + code
if stack != "svg"
else "Here is the code of the SVG: " + code
)
return [
{
"role": "system",
"content": system_content,
},
{
"role": "user",
"content": user_content,
},
]
# TODO: Use result_image_data_url
def assemble_prompt(
image_data_url: str,
generated_code_config: str,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
# Set the system prompt based on the output settings
system_content = TAILWIND_SYSTEM_PROMPT
if generated_code_config == "html_tailwind":
system_content = TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "react_tailwind":
system_content = REACT_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "bootstrap":
system_content = BOOTSTRAP_SYSTEM_PROMPT
elif generated_code_config == "ionic_tailwind":
system_content = IONIC_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "vue_tailwind":
system_content = VUE_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "svg":
system_content = SVG_SYSTEM_PROMPT
else:
raise Exception("Code config is not one of available options")
user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT
user_content: List[ChatCompletionContentPartParam] = [
{
"type": "image_url",
"image_url": {"url": image_data_url, "detail": "high"},
},
{
"type": "text",
"text": user_prompt,
},
]
# Include the result image if it exists
if result_image_data_url:
user_content.insert(
1,
{
"type": "image_url",
"image_url": {"url": result_image_data_url, "detail": "high"},
},
)
return [
{
"role": "system",
"content": system_content,
},
{
"role": "user",
"content": user_content,
},
]

View File

@ -0,0 +1,79 @@
from typing import List, NoReturn, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
from prompts.types import Stack
USER_PROMPT = """
Generate code for a web page that looks exactly like this.
"""
SVG_USER_PROMPT = """
Generate code for a SVG that looks exactly like this.
"""
def assemble_imported_code_prompt(
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
) -> List[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
user_content = (
"Here is the code of the app: " + code
if stack != "svg"
else "Here is the code of the SVG: " + code
)
return [
{
"role": "system",
"content": system_content,
},
{
"role": "user",
"content": user_content,
},
]
# TODO: Use result_image_data_url
def assemble_prompt(
image_data_url: str,
stack: Stack,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
system_content = SYSTEM_PROMPTS[stack]
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
user_content: List[ChatCompletionContentPartParam] = [
{
"type": "image_url",
"image_url": {"url": image_data_url, "detail": "high"},
},
{
"type": "text",
"text": user_prompt,
},
]
# Include the result image if it exists
if result_image_data_url:
user_content.insert(
1,
{
"type": "image_url",
"image_url": {"url": result_image_data_url, "detail": "high"},
},
)
return [
{
"role": "system",
"content": system_content,
},
{
"role": "user",
"content": user_content,
},
]

View File

@ -1,3 +1,6 @@
from prompts.types import SystemPrompts
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """ IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer. You are an expert Tailwind developer.
@ -79,6 +82,38 @@ Return only the full code in <html></html> tags.
Do not include markdown "```" or "```html" at the start or end. Do not include markdown "```" or "```html" at the start or end.
""" """
IMPORTED_CODE_VUE_TAILWIND_SYSTEM_PROMPT = """
You are an expert Vue/Tailwind developer.
- Do not add comments in the code such as "<!-- Add other navigation links as needed -->" and "<!-- ... other news items ... -->" in place of writing the full code. WRITE THE FULL CODE.
- Repeat elements as needed. For example, if there are 15 items, the code should have 15 items. DO NOT LEAVE comments like "<!-- Repeat for each news item -->" or bad things will happen.
- For images, use placeholder images from https://placehold.co and include a detailed description of the image in the alt text so that an image generation AI can generate the image later.
In terms of libraries,
- Use these script to include Vue so that it can run on a standalone page:
<script src="https://registry.npmmirror.com/vue/3.3.11/files/dist/vue.global.js"></script>
- Use Vue using the global build like so:
<div id="app">{{ message }}</div>
<script>
const { createApp, ref } = Vue
createApp({
setup() {
const message = ref('Hello vue!')
return {
message
}
}
}).mount('#app')
</script>
- Use this script to include Tailwind: <script src="https://cdn.tailwindcss.com"></script>
- You can use Google Fonts
- Font Awesome for icons: <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css"></link>
Return only the full code in <html></html> tags.
Do not include markdown "```" or "```html" at the start or end.
The return result must only include the code."""
IMPORTED_CODE_SVG_SYSTEM_PROMPT = """ IMPORTED_CODE_SVG_SYSTEM_PROMPT = """
You are an expert at building SVGs. You are an expert at building SVGs.
@ -90,3 +125,12 @@ You are an expert at building SVGs.
Return only the full code in <svg></svg> tags. Return only the full code in <svg></svg> tags.
Do not include markdown "```" or "```svg" at the start or end. Do not include markdown "```" or "```svg" at the start or end.
""" """
IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts(
html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
vue_tailwind=IMPORTED_CODE_VUE_TAILWIND_SYSTEM_PROMPT,
svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT,
)

View File

@ -1,4 +1,7 @@
TAILWIND_SYSTEM_PROMPT = """ from prompts.types import SystemPrompts
HTML_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer You are an expert Tailwind developer
You take screenshots of a reference web page from the user, and then build single page apps You take screenshots of a reference web page from the user, and then build single page apps
using Tailwind, HTML and JS. using Tailwind, HTML and JS.
@ -170,3 +173,13 @@ padding, margin, border, etc. Match the colors and sizes exactly.
Return only the full code in <svg></svg> tags. Return only the full code in <svg></svg> tags.
Do not include markdown "```" or "```svg" at the start or end. Do not include markdown "```" or "```svg" at the start or end.
""" """
SYSTEM_PROMPTS = SystemPrompts(
html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT,
react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT,
bootstrap=BOOTSTRAP_SYSTEM_PROMPT,
ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT,
vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT,
svg=SVG_SYSTEM_PROMPT,
)

View File

@ -253,6 +253,39 @@ Return only the full code in <html></html> tags.
Do not include markdown "```" or "```html" at the start or end. Do not include markdown "```" or "```html" at the start or end.
""" """
IMPORTED_CODE_VUE_TAILWIND_PROMPT = """
You are an expert Vue/Tailwind developer.
- Do not add comments in the code such as "<!-- Add other navigation links as needed -->" and "<!-- ... other news items ... -->" in place of writing the full code. WRITE THE FULL CODE.
- Repeat elements as needed. For example, if there are 15 items, the code should have 15 items. DO NOT LEAVE comments like "<!-- Repeat for each news item -->" or bad things will happen.
- For images, use placeholder images from https://placehold.co and include a detailed description of the image in the alt text so that an image generation AI can generate the image later.
In terms of libraries,
- Use these script to include Vue so that it can run on a standalone page:
<script src="https://registry.npmmirror.com/vue/3.3.11/files/dist/vue.global.js"></script>
- Use Vue using the global build like so:
<div id="app">{{ message }}</div>
<script>
const { createApp, ref } = Vue
createApp({
setup() {
const message = ref('Hello vue!')
return {
message
}
}
}).mount('#app')
</script>
- Use this script to include Tailwind: <script src="https://cdn.tailwindcss.com"></script>
- You can use Google Fonts
- Font Awesome for icons: <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css"></link>
Return only the full code in <html></html> tags.
Do not include markdown "```" or "```html" at the start or end.
The return result must only include the code."""
IMPORTED_CODE_SVG_SYSTEM_PROMPT = """ IMPORTED_CODE_SVG_SYSTEM_PROMPT = """
You are an expert at building SVGs. You are an expert at building SVGs.
@ -299,9 +332,11 @@ def test_prompts():
assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT
assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore
vue = assemble_prompt("image_data_url", "vue_tailwind", "result_image_data_url") vue_tailwind = assemble_prompt(
assert vue[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT "image_data_url", "vue_tailwind", "result_image_data_url"
assert vue[1]["content"][2]["text"] == USER_PROMPT # type: ignore )
assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT
assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore
svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url") svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url")
assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT
@ -345,6 +380,15 @@ def test_imported_code_prompts():
] ]
assert ionic_tailwind == expected_ionic_tailwind assert ionic_tailwind == expected_ionic_tailwind
vue_tailwind = assemble_imported_code_prompt(
"code", "vue_tailwind", "result_image_data_url"
)
expected_vue_tailwind = [
{"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT},
{"role": "user", "content": "Here is the code of the app: code"},
]
assert vue_tailwind == expected_vue_tailwind
svg = assemble_imported_code_prompt("code", "svg", "result_image_data_url") svg = assemble_imported_code_prompt("code", "svg", "result_image_data_url")
expected_svg = [ expected_svg = [
{"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT}, {"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT},

20
backend/prompts/types.py Normal file
View File

@ -0,0 +1,20 @@
from typing import Literal, TypedDict
class SystemPrompts(TypedDict):
html_tailwind: str
react_tailwind: str
bootstrap: str
ionic_tailwind: str
vue_tailwind: str
svg: str
Stack = Literal[
"html_tailwind",
"react_tailwind",
"bootstrap",
"ionic_tailwind",
"vue_tailwind",
"svg",
]

View File

@ -18,6 +18,7 @@ sentry-sdk = {extras = ["fastapi"], version = "^1.38.0"}
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pytest = "^7.4.3" pytest = "^7.4.3"
pyright = "^1.1.345"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

View File

@ -0,0 +1,3 @@
{
"exclude": ["image_generation.py"]
}

View File

@ -1,8 +1,8 @@
import os import os
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
from eval_utils import image_to_data_url from evals.utils import image_to_data_url
from eval_config import EVALS_DIR from evals.config import EVALS_DIR
router = APIRouter() router = APIRouter()

View File

@ -6,7 +6,7 @@ from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List from typing import Dict, List, cast, get_args
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt from prompts import assemble_imported_code_prompt, assemble_prompt
from access_token import validate_access_token from access_token import validate_access_token
@ -14,6 +14,7 @@ from datetime import datetime
import json import json
from routes.logging_utils import PaymentMethod, send_to_saas_backend from routes.logging_utils import PaymentMethod, send_to_saas_backend
from routes.saas_utils import does_user_have_subscription_credits from routes.saas_utils import does_user_have_subscription_credits
from prompts.types import Stack
from utils import pprint_prompt # type: ignore from utils import pprint_prompt # type: ignore
@ -124,6 +125,13 @@ async def stream_code(websocket: WebSocket):
) )
return return
# Validate the generated code config
if not generated_code_config in get_args(Stack):
await throw_error(f"Invalid generated code config: {generated_code_config}")
return
# Cast the variable to the Stack type
valid_stack = cast(Stack, generated_code_config)
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None openai_base_url = None
# Disable user-specified OpenAI Base URL in prod # Disable user-specified OpenAI Base URL in prod
@ -159,7 +167,7 @@ async def stream_code(websocket: WebSocket):
if params.get("isImportedFromCode") and params["isImportedFromCode"]: if params.get("isImportedFromCode") and params["isImportedFromCode"]:
original_imported_code = params["history"][0] original_imported_code = params["history"][0]
prompt_messages = assemble_imported_code_prompt( prompt_messages = assemble_imported_code_prompt(
original_imported_code, generated_code_config original_imported_code, valid_stack
) )
for index, text in enumerate(params["history"][1:]): for index, text in enumerate(params["history"][1:]):
if index % 2 == 0: if index % 2 == 0:
@ -178,12 +186,10 @@ async def stream_code(websocket: WebSocket):
try: try:
if params.get("resultImage") and params["resultImage"]: if params.get("resultImage") and params["resultImage"]:
prompt_messages = assemble_prompt( prompt_messages = assemble_prompt(
params["image"], generated_code_config, params["resultImage"] params["image"], valid_stack, params["resultImage"]
) )
else: else:
prompt_messages = assemble_prompt( prompt_messages = assemble_prompt(params["image"], valid_stack)
params["image"], generated_code_config
)
except: except:
await websocket.send_json( await websocket.send_json(
{ {

View File

@ -1,41 +1,15 @@
# Load environment variables first # Load environment variables first
from typing import Any, Coroutine
from dotenv import load_dotenv from dotenv import load_dotenv
from eval_config import EVALS_DIR
from eval_utils import image_to_data_url
load_dotenv() load_dotenv()
import os import os
from llm import stream_openai_response from typing import Any, Coroutine
from prompts import assemble_prompt
import asyncio import asyncio
from utils import pprint_prompt from evals.config import EVALS_DIR
from evals.core import generate_code_core
from evals.utils import image_to_data_url
async def generate_code_core(image_url: str, stack: str) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_base_url = None
pprint_prompt(prompt_messages)
async def process_chunk(content: str):
pass
if not openai_api_key:
raise Exception("OpenAI API key not found")
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
return completion
async def main(): async def main():

View File

@ -17,13 +17,7 @@ import { Button } from "@/components/ui/button";
import { Textarea } from "@/components/ui/textarea"; import { Textarea } from "@/components/ui/textarea";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "./components/ui/tabs"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "./components/ui/tabs";
import SettingsDialog from "./components/SettingsDialog"; import SettingsDialog from "./components/SettingsDialog";
import { import { AppState, CodeGenerationParams, EditorTheme, Settings } from "./types";
AppState,
CodeGenerationParams,
EditorTheme,
GeneratedCodeConfig,
Settings,
} from "./types";
import { IS_RUNNING_ON_CLOUD } from "./config"; import { IS_RUNNING_ON_CLOUD } from "./config";
import { PicoBadge } from "./components/PicoBadge"; import { PicoBadge } from "./components/PicoBadge";
import { OnboardingNote } from "./components/OnboardingNote"; import { OnboardingNote } from "./components/OnboardingNote";
@ -42,6 +36,7 @@ import toast from "react-hot-toast";
import ImportCodeSection from "./components/ImportCodeSection"; import ImportCodeSection from "./components/ImportCodeSection";
import { useAuth } from "@clerk/clerk-react"; import { useAuth } from "@clerk/clerk-react";
import { useStore } from "./store/store"; import { useStore } from "./store/store";
import { Stack } from "./lib/stacks/types";
const IS_OPENAI_DOWN = false; const IS_OPENAI_DOWN = false;
@ -71,7 +66,7 @@ function App({ navbarComponent }: Props) {
screenshotOneApiKey: null, screenshotOneApiKey: null,
isImageGenerationEnabled: true, isImageGenerationEnabled: true,
editorTheme: EditorTheme.COBALT, editorTheme: EditorTheme.COBALT,
generatedCodeConfig: GeneratedCodeConfig.HTML_TAILWIND, generatedCodeConfig: Stack.HTML_TAILWIND,
// Only relevant for hosted version // Only relevant for hosted version
isTermOfServiceAccepted: true, isTermOfServiceAccepted: true,
accessCode: null, accessCode: null,
@ -96,7 +91,7 @@ function App({ navbarComponent }: Props) {
if (!settings.generatedCodeConfig) { if (!settings.generatedCodeConfig) {
setSettings((prev) => ({ setSettings((prev) => ({
...prev, ...prev,
generatedCodeConfig: GeneratedCodeConfig.HTML_TAILWIND, generatedCodeConfig: Stack.HTML_TAILWIND,
})); }));
} }
}, [settings.generatedCodeConfig, setSettings]); }, [settings.generatedCodeConfig, setSettings]);
@ -310,15 +305,14 @@ function App({ navbarComponent }: Props) {
})); }));
}; };
// TODO: Rename everything to "stack" instead of "config" function setStack(stack: Stack) {
function setStack(stack: GeneratedCodeConfig) {
setSettings((prev) => ({ setSettings((prev) => ({
...prev, ...prev,
generatedCodeConfig: stack, generatedCodeConfig: stack,
})); }));
} }
function importFromCode(code: string, stack: GeneratedCodeConfig) { function importFromCode(code: string, stack: Stack) {
setIsImportedFromCode(true); setIsImportedFromCode(true);
// Set up this project // Set up this project
@ -354,8 +348,8 @@ function App({ navbarComponent }: Props) {
</div> </div>
<OutputSettingsSection <OutputSettingsSection
generatedCodeConfig={settings.generatedCodeConfig} stack={settings.generatedCodeConfig}
setGeneratedCodeConfig={(config) => setStack(config)} setStack={(config) => setStack(config)}
shouldDisableUpdates={ shouldDisableUpdates={
appState === AppState.CODING || appState === AppState.CODE_READY appState === AppState.CODING || appState === AppState.CODE_READY
} }

View File

@ -11,18 +11,16 @@ import {
} from "./ui/dialog"; } from "./ui/dialog";
import { Textarea } from "./ui/textarea"; import { Textarea } from "./ui/textarea";
import OutputSettingsSection from "./OutputSettingsSection"; import OutputSettingsSection from "./OutputSettingsSection";
import { GeneratedCodeConfig } from "../types";
import toast from "react-hot-toast"; import toast from "react-hot-toast";
import { Stack } from "../lib/stacks/types";
interface Props { interface Props {
importFromCode: (code: string, stack: GeneratedCodeConfig) => void; importFromCode: (code: string, stack: Stack) => void;
} }
function ImportCodeSection({ importFromCode }: Props) { function ImportCodeSection({ importFromCode }: Props) {
const [code, setCode] = useState(""); const [code, setCode] = useState("");
const [stack, setStack] = useState<GeneratedCodeConfig | undefined>( const [stack, setStack] = useState<Stack | undefined>(undefined);
undefined
);
const doImport = () => { const doImport = () => {
if (code === "") { if (code === "") {
@ -57,10 +55,8 @@ function ImportCodeSection({ importFromCode }: Props) {
/> />
<OutputSettingsSection <OutputSettingsSection
generatedCodeConfig={stack} stack={stack}
setGeneratedCodeConfig={(config: GeneratedCodeConfig) => setStack={(config: Stack) => setStack(config)}
setStack(config)
}
label="Stack:" label="Stack:"
shouldDisableUpdates={false} shouldDisableUpdates={false}
/> />

View File

@ -1,3 +1,4 @@
import React from "react";
import { import {
Select, Select,
SelectContent, SelectContent,
@ -5,69 +6,36 @@ import {
SelectItem, SelectItem,
SelectTrigger, SelectTrigger,
} from "./ui/select"; } from "./ui/select";
import { GeneratedCodeConfig } from "../types";
import { addEvent } from "../lib/analytics"; import { addEvent } from "../lib/analytics";
import { Badge } from "./ui/badge"; import { Badge } from "./ui/badge";
import { Stack } from "../lib/stacks/types";
import { STACK_DESCRIPTIONS } from "../lib/stacks/descriptions";
function generateDisplayComponent(config: GeneratedCodeConfig) { function generateDisplayComponent(stack: Stack) {
switch (config) { const stackComponents = STACK_DESCRIPTIONS[stack].components;
case GeneratedCodeConfig.HTML_TAILWIND:
return ( return (
<div> <div>
<span className="font-semibold">HTML</span> +{" "} {stackComponents.map((component, index) => (
<span className="font-semibold">Tailwind</span> <React.Fragment key={index}>
</div> <span className="font-semibold">{component}</span>
); {index < stackComponents.length - 1 && " + "}
case GeneratedCodeConfig.REACT_TAILWIND: </React.Fragment>
return ( ))}
<div> </div>
<span className="font-semibold">React</span> +{" "} );
<span className="font-semibold">Tailwind</span>
</div>
);
case GeneratedCodeConfig.BOOTSTRAP:
return (
<div>
<span className="font-semibold">Bootstrap</span>
</div>
);
case GeneratedCodeConfig.IONIC_TAILWIND:
return (
<div>
<span className="font-semibold">Ionic</span> +{" "}
<span className="font-semibold">Tailwind</span>
</div>
);
case GeneratedCodeConfig.VUE_TAILWIND:
return (
<div>
<span className="font-semibold">Vue</span> +{" "}
<span className="font-semibold">Tailwind</span>
</div>
);
case GeneratedCodeConfig.SVG:
return (
<div>
<span className="font-semibold">SVG</span>
</div>
);
default: {
const exhaustiveCheck: never = config;
throw new Error(`Unhandled case: ${exhaustiveCheck}`);
}
}
} }
interface Props { interface Props {
generatedCodeConfig: GeneratedCodeConfig | undefined; stack: Stack | undefined;
setGeneratedCodeConfig: (config: GeneratedCodeConfig) => void; setStack: (config: Stack) => void;
label?: string; label?: string;
shouldDisableUpdates?: boolean; shouldDisableUpdates?: boolean;
} }
function OutputSettingsSection({ function OutputSettingsSection({
generatedCodeConfig, stack,
setGeneratedCodeConfig, setStack,
label = "Generating:", label = "Generating:",
shouldDisableUpdates = false, shouldDisableUpdates = false,
}: Props) { }: Props) {
@ -76,53 +44,30 @@ function OutputSettingsSection({
<div className="grid grid-cols-3 items-center gap-4"> <div className="grid grid-cols-3 items-center gap-4">
<span>{label}</span> <span>{label}</span>
<Select <Select
value={generatedCodeConfig} value={stack}
onValueChange={(value: string) => { onValueChange={(value: string) => {
addEvent("OutputSettings", { stack: value }); addEvent("OutputSettings", { stack: value });
setGeneratedCodeConfig(value as GeneratedCodeConfig); setStack(value as Stack);
}} }}
disabled={shouldDisableUpdates} disabled={shouldDisableUpdates}
> >
<SelectTrigger className="col-span-2" id="output-settings-js"> <SelectTrigger className="col-span-2" id="output-settings-js">
{generatedCodeConfig {stack ? generateDisplayComponent(stack) : "Select a stack"}
? generateDisplayComponent(generatedCodeConfig)
: "Select a stack"}
</SelectTrigger> </SelectTrigger>
<SelectContent> <SelectContent>
<SelectGroup> <SelectGroup>
<SelectItem value={GeneratedCodeConfig.HTML_TAILWIND}> {Object.values(Stack).map((stack) => (
{generateDisplayComponent(GeneratedCodeConfig.HTML_TAILWIND)} <SelectItem value={stack}>
</SelectItem> <div className="flex items-center">
<SelectItem value={GeneratedCodeConfig.REACT_TAILWIND}> {generateDisplayComponent(stack)}
{generateDisplayComponent(GeneratedCodeConfig.REACT_TAILWIND)} {STACK_DESCRIPTIONS[stack].inBeta && (
</SelectItem> <Badge className="ml-2" variant="secondary">
<SelectItem value={GeneratedCodeConfig.BOOTSTRAP}> Beta
{generateDisplayComponent(GeneratedCodeConfig.BOOTSTRAP)} </Badge>
</SelectItem> )}
<SelectItem value={GeneratedCodeConfig.VUE_TAILWIND}> </div>
<div className="flex items-center"> </SelectItem>
{generateDisplayComponent(GeneratedCodeConfig.VUE_TAILWIND)} ))}
<Badge className="ml-2" variant="secondary">
Beta
</Badge>
</div>
</SelectItem>
<SelectItem value={GeneratedCodeConfig.IONIC_TAILWIND}>
<div className="flex items-center">
{generateDisplayComponent(GeneratedCodeConfig.IONIC_TAILWIND)}
<Badge className="ml-2" variant="secondary">
Beta
</Badge>
</div>
</SelectItem>
<SelectItem value={GeneratedCodeConfig.SVG}>
<div className="flex items-center">
{generateDisplayComponent(GeneratedCodeConfig.SVG)}
<Badge className="ml-2" variant="secondary">
Beta
</Badge>
</div>
</SelectItem>
</SelectGroup> </SelectGroup>
</SelectContent> </SelectContent>
</Select> </Select>

View File

@ -0,0 +1,12 @@
import { Stack } from "./types";
export const STACK_DESCRIPTIONS: {
[key in Stack]: { components: string[]; inBeta: boolean };
} = {
html_tailwind: { components: ["HTML", "Tailwind"], inBeta: false },
react_tailwind: { components: ["React", "Tailwind"], inBeta: false },
bootstrap: { components: ["Bootstrap"], inBeta: false },
vue_tailwind: { components: ["Vue", "Tailwind"], inBeta: true },
ionic_tailwind: { components: ["Ionic", "Tailwind"], inBeta: true },
svg: { components: ["SVG"], inBeta: true },
};

View File

@ -0,0 +1,9 @@
// Keep in sync with backend (prompts/types.py)
export enum Stack {
HTML_TAILWIND = "html_tailwind",
REACT_TAILWIND = "react_tailwind",
BOOTSTRAP = "bootstrap",
VUE_TAILWIND = "vue_tailwind",
IONIC_TAILWIND = "ionic_tailwind",
SVG = "svg",
}

View File

@ -1,25 +1,17 @@
import { Stack } from "./lib/stacks/types";
export enum EditorTheme { export enum EditorTheme {
ESPRESSO = "espresso", ESPRESSO = "espresso",
COBALT = "cobalt", COBALT = "cobalt",
} }
// Keep in sync with backend (prompts.py)
export enum GeneratedCodeConfig {
HTML_TAILWIND = "html_tailwind",
REACT_TAILWIND = "react_tailwind",
VUE_TAILWIND = "vue_tailwind",
BOOTSTRAP = "bootstrap",
IONIC_TAILWIND = "ionic_tailwind",
SVG = "svg",
}
export interface Settings { export interface Settings {
openAiApiKey: string | null; openAiApiKey: string | null;
openAiBaseURL: string | null; openAiBaseURL: string | null;
screenshotOneApiKey: string | null; screenshotOneApiKey: string | null;
isImageGenerationEnabled: boolean; isImageGenerationEnabled: boolean;
editorTheme: EditorTheme; editorTheme: EditorTheme;
generatedCodeConfig: GeneratedCodeConfig; generatedCodeConfig: Stack;
// Only relevant for hosted version // Only relevant for hosted version
isTermOfServiceAccepted: boolean; isTermOfServiceAccepted: boolean;
accessCode: string | null; accessCode: string | null;