Type the backend properly to avoid code duplication and ensure type errors when a stack configuration is not properly added

This commit is contained in:
Abi Raja 2024-01-08 15:30:53 -08:00
parent 15dc74a328
commit 1aeb8c4e14
5 changed files with 67 additions and 71 deletions

View File

@ -1,22 +1,10 @@
from typing import List, Literal, NoReturn, Union
from typing import List, NoReturn, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
from prompts.imported_code_prompts import (
IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_SVG_SYSTEM_PROMPT,
)
from prompts.screenshot_system_prompts import (
BOOTSTRAP_SYSTEM_PROMPT,
IONIC_TAILWIND_SYSTEM_PROMPT,
REACT_TAILWIND_SYSTEM_PROMPT,
TAILWIND_SYSTEM_PROMPT,
SVG_SYSTEM_PROMPT,
VUE_TAILWIND_SYSTEM_PROMPT,
)
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
from prompts.types import Stack
USER_PROMPT = """
@ -27,39 +15,11 @@ SVG_USER_PROMPT = """
Generate code for a SVG that looks exactly like this.
"""
Stack = Literal[
"html_tailwind",
"react_tailwind",
"bootstrap",
"ionic_tailwind",
"vue_tailwind",
"svg",
]
def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError(f"Stack is not one of available options: {x}")
def assemble_imported_code_prompt(
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
) -> List[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
if stack == "html_tailwind":
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
elif stack == "react_tailwind":
system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
elif stack == "bootstrap":
system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
elif stack == "ionic_tailwind":
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "vue_tailwind":
# TODO: Fix this prompt to be vue tailwind
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "svg":
system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT
else:
assert_never(stack)
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
user_content = (
"Here is the code of the app: " + code
@ -81,27 +41,11 @@ def assemble_imported_code_prompt(
def assemble_prompt(
image_data_url: str,
generated_code_config: Stack,
stack: Stack,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
# Set the system prompt based on the output settings
system_content = TAILWIND_SYSTEM_PROMPT
if generated_code_config == "html_tailwind":
system_content = TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "react_tailwind":
system_content = REACT_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "bootstrap":
system_content = BOOTSTRAP_SYSTEM_PROMPT
elif generated_code_config == "ionic_tailwind":
system_content = IONIC_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "vue_tailwind":
system_content = VUE_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "svg":
system_content = SVG_SYSTEM_PROMPT
else:
assert_never(generated_code_config)
user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT
system_content = SYSTEM_PROMPTS[stack]
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
user_content: List[ChatCompletionContentPartParam] = [
{

View File

@ -1,3 +1,6 @@
from prompts.types import SystemPrompts
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer.
@ -90,3 +93,13 @@ You are an expert at building SVGs.
Return only the full code in <svg></svg> tags.
Do not include markdown "```" or "```svg" at the start or end.
"""
IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts(
html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
# TODO: Fix this prompt to actually be Vue
vue_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT,
)

View File

@ -1,4 +1,7 @@
TAILWIND_SYSTEM_PROMPT = """
from prompts.types import SystemPrompts
HTML_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer
You take screenshots of a reference web page from the user, and then build single page apps
using Tailwind, HTML and JS.
@ -170,3 +173,13 @@ padding, margin, border, etc. Match the colors and sizes exactly.
Return only the full code in <svg></svg> tags.
Do not include markdown "```" or "```svg" at the start or end.
"""
SYSTEM_PROMPTS = SystemPrompts(
html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT,
react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT,
bootstrap=BOOTSTRAP_SYSTEM_PROMPT,
ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT,
vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT,
svg=SVG_SYSTEM_PROMPT,
)

20
backend/prompts/types.py Normal file
View File

@ -0,0 +1,20 @@
from typing import Literal, TypedDict
class SystemPrompts(TypedDict):
html_tailwind: str
react_tailwind: str
bootstrap: str
ionic_tailwind: str
vue_tailwind: str
svg: str
Stack = Literal[
"html_tailwind",
"react_tailwind",
"bootstrap",
"ionic_tailwind",
"vue_tailwind",
"svg",
]

View File

@ -6,12 +6,13 @@ from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from typing import Dict, List
from typing import Dict, List, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from access_token import validate_access_token
from datetime import datetime
import json
from prompts.types import Stack
from utils import pprint_prompt # type: ignore
@ -96,6 +97,13 @@ async def stream_code(websocket: WebSocket):
)
return
# Validate the generated code config
if not generated_code_config in get_args(Stack):
await throw_error(f"Invalid generated code config: {generated_code_config}")
return
# Cast the variable to the Stack type
valid_stack = cast(Stack, generated_code_config)
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None
# Disable user-specified OpenAI Base URL in prod
@ -131,7 +139,7 @@ async def stream_code(websocket: WebSocket):
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
original_imported_code = params["history"][0]
prompt_messages = assemble_imported_code_prompt(
original_imported_code, generated_code_config
original_imported_code, valid_stack
)
for index, text in enumerate(params["history"][1:]):
if index % 2 == 0:
@ -150,12 +158,10 @@ async def stream_code(websocket: WebSocket):
try:
if params.get("resultImage") and params["resultImage"]:
prompt_messages = assemble_prompt(
params["image"], generated_code_config, params["resultImage"]
params["image"], valid_stack, params["resultImage"]
)
else:
prompt_messages = assemble_prompt(
params["image"], generated_code_config
)
prompt_messages = assemble_prompt(params["image"], valid_stack)
except:
await websocket.send_json(
{