From 1aeb8c4e1448288f8546cc688a854f85e07b3a83 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Mon, 8 Jan 2024 15:30:53 -0800 Subject: [PATCH] Type the backend properly to avoid code duplication and ensure type errors when a stack configuration is not properly added --- backend/prompts/__init__.py | 72 +++----------------- backend/prompts/imported_code_prompts.py | 13 ++++ backend/prompts/screenshot_system_prompts.py | 15 +++- backend/prompts/types.py | 20 ++++++ backend/routes/generate_code.py | 18 +++-- 5 files changed, 67 insertions(+), 71 deletions(-) create mode 100644 backend/prompts/types.py diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index fbfa5ab..4f2e329 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,22 +1,10 @@ -from typing import List, Literal, NoReturn, Union +from typing import List, NoReturn, Union from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam -from prompts.imported_code_prompts import ( - IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT, - IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT, - IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT, - IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT, - IMPORTED_CODE_SVG_SYSTEM_PROMPT, -) -from prompts.screenshot_system_prompts import ( - BOOTSTRAP_SYSTEM_PROMPT, - IONIC_TAILWIND_SYSTEM_PROMPT, - REACT_TAILWIND_SYSTEM_PROMPT, - TAILWIND_SYSTEM_PROMPT, - SVG_SYSTEM_PROMPT, - VUE_TAILWIND_SYSTEM_PROMPT, -) +from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS +from prompts.screenshot_system_prompts import SYSTEM_PROMPTS +from prompts.types import Stack USER_PROMPT = """ @@ -27,39 +15,11 @@ SVG_USER_PROMPT = """ Generate code for a SVG that looks exactly like this. """ -Stack = Literal[ - "html_tailwind", - "react_tailwind", - "bootstrap", - "ionic_tailwind", - "vue_tailwind", - "svg", -] - - -def assert_never(x: NoReturn) -> NoReturn: - raise AssertionError(f"Stack is not one of available options: {x}") - def assemble_imported_code_prompt( code: str, stack: Stack, result_image_data_url: Union[str, None] = None ) -> List[ChatCompletionMessageParam]: - system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT - if stack == "html_tailwind": - system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT - elif stack == "react_tailwind": - system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT - elif stack == "bootstrap": - system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT - elif stack == "ionic_tailwind": - system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT - elif stack == "vue_tailwind": - # TODO: Fix this prompt to be vue tailwind - system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT - elif stack == "svg": - system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT - else: - assert_never(stack) + system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] user_content = ( "Here is the code of the app: " + code @@ -81,27 +41,11 @@ def assemble_imported_code_prompt( def assemble_prompt( image_data_url: str, - generated_code_config: Stack, + stack: Stack, result_image_data_url: Union[str, None] = None, ) -> List[ChatCompletionMessageParam]: - # Set the system prompt based on the output settings - system_content = TAILWIND_SYSTEM_PROMPT - if generated_code_config == "html_tailwind": - system_content = TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "react_tailwind": - system_content = REACT_TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "bootstrap": - system_content = BOOTSTRAP_SYSTEM_PROMPT - elif generated_code_config == "ionic_tailwind": - system_content = IONIC_TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "vue_tailwind": - system_content = VUE_TAILWIND_SYSTEM_PROMPT - elif generated_code_config == "svg": - system_content = SVG_SYSTEM_PROMPT - else: - assert_never(generated_code_config) - - user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT + system_content = SYSTEM_PROMPTS[stack] + user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT user_content: List[ChatCompletionContentPartParam] = [ { diff --git a/backend/prompts/imported_code_prompts.py b/backend/prompts/imported_code_prompts.py index a8bfa6a..4cda821 100644 --- a/backend/prompts/imported_code_prompts.py +++ b/backend/prompts/imported_code_prompts.py @@ -1,3 +1,6 @@ +from prompts.types import SystemPrompts + + IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer. @@ -90,3 +93,13 @@ You are an expert at building SVGs. Return only the full code in tags. Do not include markdown "```" or "```svg" at the start or end. """ + +IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts( + html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT, + react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT, + bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT, + ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT, + # TODO: Fix this prompt to actually be Vue + vue_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT, + svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT, +) diff --git a/backend/prompts/screenshot_system_prompts.py b/backend/prompts/screenshot_system_prompts.py index 1dfcca4..fca91ba 100644 --- a/backend/prompts/screenshot_system_prompts.py +++ b/backend/prompts/screenshot_system_prompts.py @@ -1,4 +1,7 @@ -TAILWIND_SYSTEM_PROMPT = """ +from prompts.types import SystemPrompts + + +HTML_TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer You take screenshots of a reference web page from the user, and then build single page apps using Tailwind, HTML and JS. @@ -170,3 +173,13 @@ padding, margin, border, etc. Match the colors and sizes exactly. Return only the full code in tags. Do not include markdown "```" or "```svg" at the start or end. """ + + +SYSTEM_PROMPTS = SystemPrompts( + html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT, + react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT, + bootstrap=BOOTSTRAP_SYSTEM_PROMPT, + ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT, + vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT, + svg=SVG_SYSTEM_PROMPT, +) diff --git a/backend/prompts/types.py b/backend/prompts/types.py new file mode 100644 index 0000000..9068443 --- /dev/null +++ b/backend/prompts/types.py @@ -0,0 +1,20 @@ +from typing import Literal, TypedDict + + +class SystemPrompts(TypedDict): + html_tailwind: str + react_tailwind: str + bootstrap: str + ionic_tailwind: str + vue_tailwind: str + svg: str + + +Stack = Literal[ + "html_tailwind", + "react_tailwind", + "bootstrap", + "ionic_tailwind", + "vue_tailwind", + "svg", +] diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index b260cc0..bd2e0c2 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -6,12 +6,13 @@ from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE from llm import stream_openai_response from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Dict, List +from typing import Dict, List, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from access_token import validate_access_token from datetime import datetime import json +from prompts.types import Stack from utils import pprint_prompt # type: ignore @@ -96,6 +97,13 @@ async def stream_code(websocket: WebSocket): ) return + # Validate the generated code config + if not generated_code_config in get_args(Stack): + await throw_error(f"Invalid generated code config: {generated_code_config}") + return + # Cast the variable to the Stack type + valid_stack = cast(Stack, generated_code_config) + # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. openai_base_url = None # Disable user-specified OpenAI Base URL in prod @@ -131,7 +139,7 @@ async def stream_code(websocket: WebSocket): if params.get("isImportedFromCode") and params["isImportedFromCode"]: original_imported_code = params["history"][0] prompt_messages = assemble_imported_code_prompt( - original_imported_code, generated_code_config + original_imported_code, valid_stack ) for index, text in enumerate(params["history"][1:]): if index % 2 == 0: @@ -150,12 +158,10 @@ async def stream_code(websocket: WebSocket): try: if params.get("resultImage") and params["resultImage"]: prompt_messages = assemble_prompt( - params["image"], generated_code_config, params["resultImage"] + params["image"], valid_stack, params["resultImage"] ) else: - prompt_messages = assemble_prompt( - params["image"], generated_code_config - ) + prompt_messages = assemble_prompt(params["image"], valid_stack) except: await websocket.send_json( {