improve type checking for stack on backend

This commit is contained in:
Abi Raja 2024-01-08 14:55:41 -08:00
parent adda6852f3
commit 15dc74a328

View File

@ -1,4 +1,4 @@
from typing import List, Union from typing import List, Literal, NoReturn, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
@ -27,9 +27,22 @@ SVG_USER_PROMPT = """
Generate code for a SVG that looks exactly like this. Generate code for a SVG that looks exactly like this.
""" """
Stack = Literal[
"html_tailwind",
"react_tailwind",
"bootstrap",
"ionic_tailwind",
"vue_tailwind",
"svg",
]
def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError(f"Stack is not one of available options: {x}")
def assemble_imported_code_prompt( def assemble_imported_code_prompt(
code: str, stack: str, result_image_data_url: Union[str, None] = None code: str, stack: Stack, result_image_data_url: Union[str, None] = None
) -> List[ChatCompletionMessageParam]: ) -> List[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
if stack == "html_tailwind": if stack == "html_tailwind":
@ -40,10 +53,13 @@ def assemble_imported_code_prompt(
system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
elif stack == "ionic_tailwind": elif stack == "ionic_tailwind":
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "vue_tailwind":
# TODO: Fix this prompt to be vue tailwind
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "svg": elif stack == "svg":
system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT
else: else:
raise Exception("Code config is not one of available options") assert_never(stack)
user_content = ( user_content = (
"Here is the code of the app: " + code "Here is the code of the app: " + code
@ -65,7 +81,7 @@ def assemble_imported_code_prompt(
def assemble_prompt( def assemble_prompt(
image_data_url: str, image_data_url: str,
generated_code_config: str, generated_code_config: Stack,
result_image_data_url: Union[str, None] = None, result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]: ) -> List[ChatCompletionMessageParam]:
# Set the system prompt based on the output settings # Set the system prompt based on the output settings
@ -83,7 +99,7 @@ def assemble_prompt(
elif generated_code_config == "svg": elif generated_code_config == "svg":
system_content = SVG_SYSTEM_PROMPT system_content = SVG_SYSTEM_PROMPT
else: else:
raise Exception("Code config is not one of available options") assert_never(generated_code_config)
user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT