diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py
index fbfa5ab..4f2e329 100644
--- a/backend/prompts/__init__.py
+++ b/backend/prompts/__init__.py
@@ -1,22 +1,10 @@
-from typing import List, Literal, NoReturn, Union
+from typing import List, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
-from prompts.imported_code_prompts import (
- IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
- IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
- IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
- IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
- IMPORTED_CODE_SVG_SYSTEM_PROMPT,
-)
-from prompts.screenshot_system_prompts import (
- BOOTSTRAP_SYSTEM_PROMPT,
- IONIC_TAILWIND_SYSTEM_PROMPT,
- REACT_TAILWIND_SYSTEM_PROMPT,
- TAILWIND_SYSTEM_PROMPT,
- SVG_SYSTEM_PROMPT,
- VUE_TAILWIND_SYSTEM_PROMPT,
-)
+from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
+from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
+from prompts.types import Stack
USER_PROMPT = """
@@ -27,39 +15,11 @@ SVG_USER_PROMPT = """
Generate code for a SVG that looks exactly like this.
"""
-Stack = Literal[
- "html_tailwind",
- "react_tailwind",
- "bootstrap",
- "ionic_tailwind",
- "vue_tailwind",
- "svg",
-]
-
-
-def assert_never(x: NoReturn) -> NoReturn:
- raise AssertionError(f"Stack is not one of available options: {x}")
-
def assemble_imported_code_prompt(
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
) -> List[ChatCompletionMessageParam]:
- system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
- if stack == "html_tailwind":
- system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
- elif stack == "react_tailwind":
- system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
- elif stack == "bootstrap":
- system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
- elif stack == "ionic_tailwind":
- system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
- elif stack == "vue_tailwind":
- # TODO: Fix this prompt to be vue tailwind
- system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
- elif stack == "svg":
- system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT
- else:
- assert_never(stack)
+ system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
user_content = (
"Here is the code of the app: " + code
@@ -81,27 +41,11 @@ def assemble_imported_code_prompt(
def assemble_prompt(
image_data_url: str,
- generated_code_config: Stack,
+ stack: Stack,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
- # Set the system prompt based on the output settings
- system_content = TAILWIND_SYSTEM_PROMPT
- if generated_code_config == "html_tailwind":
- system_content = TAILWIND_SYSTEM_PROMPT
- elif generated_code_config == "react_tailwind":
- system_content = REACT_TAILWIND_SYSTEM_PROMPT
- elif generated_code_config == "bootstrap":
- system_content = BOOTSTRAP_SYSTEM_PROMPT
- elif generated_code_config == "ionic_tailwind":
- system_content = IONIC_TAILWIND_SYSTEM_PROMPT
- elif generated_code_config == "vue_tailwind":
- system_content = VUE_TAILWIND_SYSTEM_PROMPT
- elif generated_code_config == "svg":
- system_content = SVG_SYSTEM_PROMPT
- else:
- assert_never(generated_code_config)
-
- user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT
+ system_content = SYSTEM_PROMPTS[stack]
+ user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
user_content: List[ChatCompletionContentPartParam] = [
{
diff --git a/backend/prompts/imported_code_prompts.py b/backend/prompts/imported_code_prompts.py
index a8bfa6a..4cda821 100644
--- a/backend/prompts/imported_code_prompts.py
+++ b/backend/prompts/imported_code_prompts.py
@@ -1,3 +1,6 @@
+from prompts.types import SystemPrompts
+
+
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer.
@@ -90,3 +93,13 @@ You are an expert at building SVGs.
Return only the full code in tags.
Do not include markdown "```" or "```svg" at the start or end.
"""
+
+IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts(
+ html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
+ react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
+ bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
+ ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
+ # TODO: Fix this prompt to actually be Vue
+    vue_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
+ svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT,
+)
diff --git a/backend/prompts/screenshot_system_prompts.py b/backend/prompts/screenshot_system_prompts.py
index 1dfcca4..fca91ba 100644
--- a/backend/prompts/screenshot_system_prompts.py
+++ b/backend/prompts/screenshot_system_prompts.py
@@ -1,4 +1,7 @@
-TAILWIND_SYSTEM_PROMPT = """
+from prompts.types import SystemPrompts
+
+
+HTML_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer
You take screenshots of a reference web page from the user, and then build single page apps
using Tailwind, HTML and JS.
@@ -170,3 +173,13 @@ padding, margin, border, etc. Match the colors and sizes exactly.
Return only the full code in tags.
Do not include markdown "```" or "```svg" at the start or end.
"""
+
+
+SYSTEM_PROMPTS = SystemPrompts(
+ html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT,
+ react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT,
+ bootstrap=BOOTSTRAP_SYSTEM_PROMPT,
+ ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT,
+ vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT,
+ svg=SVG_SYSTEM_PROMPT,
+)
diff --git a/backend/prompts/types.py b/backend/prompts/types.py
new file mode 100644
index 0000000..9068443
--- /dev/null
+++ b/backend/prompts/types.py
@@ -0,0 +1,20 @@
+from typing import Literal, TypedDict
+
+
+class SystemPrompts(TypedDict):
+ html_tailwind: str
+ react_tailwind: str
+ bootstrap: str
+ ionic_tailwind: str
+ vue_tailwind: str
+ svg: str
+
+
+Stack = Literal[
+ "html_tailwind",
+ "react_tailwind",
+ "bootstrap",
+ "ionic_tailwind",
+ "vue_tailwind",
+ "svg",
+]
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index b260cc0..bd2e0c2 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -6,12 +6,13 @@ from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
-from typing import Dict, List
+from typing import Dict, List, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from access_token import validate_access_token
from datetime import datetime
import json
+from prompts.types import Stack
from utils import pprint_prompt # type: ignore
@@ -96,6 +97,13 @@ async def stream_code(websocket: WebSocket):
)
return
+ # Validate the generated code config
+    if generated_code_config not in get_args(Stack):
+ await throw_error(f"Invalid generated code config: {generated_code_config}")
+ return
+ # Cast the variable to the Stack type
+ valid_stack = cast(Stack, generated_code_config)
+
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None
# Disable user-specified OpenAI Base URL in prod
@@ -131,7 +139,7 @@ async def stream_code(websocket: WebSocket):
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
original_imported_code = params["history"][0]
prompt_messages = assemble_imported_code_prompt(
- original_imported_code, generated_code_config
+ original_imported_code, valid_stack
)
for index, text in enumerate(params["history"][1:]):
if index % 2 == 0:
@@ -150,12 +158,10 @@ async def stream_code(websocket: WebSocket):
try:
if params.get("resultImage") and params["resultImage"]:
prompt_messages = assemble_prompt(
- params["image"], generated_code_config, params["resultImage"]
+ params["image"], valid_stack, params["resultImage"]
)
else:
- prompt_messages = assemble_prompt(
- params["image"], generated_code_config
- )
+ prompt_messages = assemble_prompt(params["image"], valid_stack)
except:
await websocket.send_json(
{