Type the backend properly to avoid code duplication and ensure type errors when a stack configuration is not properly added
This commit is contained in:
parent
15dc74a328
commit
1aeb8c4e14
@ -1,22 +1,10 @@
|
|||||||
from typing import List, Literal, NoReturn, Union
|
from typing import List, NoReturn, Union
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||||
|
|
||||||
from prompts.imported_code_prompts import (
|
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
||||||
IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
|
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
||||||
IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
|
from prompts.types import Stack
|
||||||
IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
|
|
||||||
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
|
|
||||||
IMPORTED_CODE_SVG_SYSTEM_PROMPT,
|
|
||||||
)
|
|
||||||
from prompts.screenshot_system_prompts import (
|
|
||||||
BOOTSTRAP_SYSTEM_PROMPT,
|
|
||||||
IONIC_TAILWIND_SYSTEM_PROMPT,
|
|
||||||
REACT_TAILWIND_SYSTEM_PROMPT,
|
|
||||||
TAILWIND_SYSTEM_PROMPT,
|
|
||||||
SVG_SYSTEM_PROMPT,
|
|
||||||
VUE_TAILWIND_SYSTEM_PROMPT,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
USER_PROMPT = """
|
USER_PROMPT = """
|
||||||
@ -27,39 +15,11 @@ SVG_USER_PROMPT = """
|
|||||||
Generate code for a SVG that looks exactly like this.
|
Generate code for a SVG that looks exactly like this.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Stack = Literal[
|
|
||||||
"html_tailwind",
|
|
||||||
"react_tailwind",
|
|
||||||
"bootstrap",
|
|
||||||
"ionic_tailwind",
|
|
||||||
"vue_tailwind",
|
|
||||||
"svg",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def assert_never(x: NoReturn) -> NoReturn:
|
|
||||||
raise AssertionError(f"Stack is not one of available options: {x}")
|
|
||||||
|
|
||||||
|
|
||||||
def assemble_imported_code_prompt(
|
def assemble_imported_code_prompt(
|
||||||
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
|
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
|
||||||
) -> List[ChatCompletionMessageParam]:
|
) -> List[ChatCompletionMessageParam]:
|
||||||
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
|
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
||||||
if stack == "html_tailwind":
|
|
||||||
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif stack == "react_tailwind":
|
|
||||||
system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif stack == "bootstrap":
|
|
||||||
system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
|
|
||||||
elif stack == "ionic_tailwind":
|
|
||||||
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif stack == "vue_tailwind":
|
|
||||||
# TODO: Fix this prompt to be vue tailwind
|
|
||||||
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif stack == "svg":
|
|
||||||
system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT
|
|
||||||
else:
|
|
||||||
assert_never(stack)
|
|
||||||
|
|
||||||
user_content = (
|
user_content = (
|
||||||
"Here is the code of the app: " + code
|
"Here is the code of the app: " + code
|
||||||
@ -81,27 +41,11 @@ def assemble_imported_code_prompt(
|
|||||||
|
|
||||||
def assemble_prompt(
|
def assemble_prompt(
|
||||||
image_data_url: str,
|
image_data_url: str,
|
||||||
generated_code_config: Stack,
|
stack: Stack,
|
||||||
result_image_data_url: Union[str, None] = None,
|
result_image_data_url: Union[str, None] = None,
|
||||||
) -> List[ChatCompletionMessageParam]:
|
) -> List[ChatCompletionMessageParam]:
|
||||||
# Set the system prompt based on the output settings
|
system_content = SYSTEM_PROMPTS[stack]
|
||||||
system_content = TAILWIND_SYSTEM_PROMPT
|
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
|
||||||
if generated_code_config == "html_tailwind":
|
|
||||||
system_content = TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif generated_code_config == "react_tailwind":
|
|
||||||
system_content = REACT_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif generated_code_config == "bootstrap":
|
|
||||||
system_content = BOOTSTRAP_SYSTEM_PROMPT
|
|
||||||
elif generated_code_config == "ionic_tailwind":
|
|
||||||
system_content = IONIC_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif generated_code_config == "vue_tailwind":
|
|
||||||
system_content = VUE_TAILWIND_SYSTEM_PROMPT
|
|
||||||
elif generated_code_config == "svg":
|
|
||||||
system_content = SVG_SYSTEM_PROMPT
|
|
||||||
else:
|
|
||||||
assert_never(generated_code_config)
|
|
||||||
|
|
||||||
user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT
|
|
||||||
|
|
||||||
user_content: List[ChatCompletionContentPartParam] = [
|
user_content: List[ChatCompletionContentPartParam] = [
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
from prompts.types import SystemPrompts
|
||||||
|
|
||||||
|
|
||||||
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """
|
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """
|
||||||
You are an expert Tailwind developer.
|
You are an expert Tailwind developer.
|
||||||
|
|
||||||
@ -90,3 +93,13 @@ You are an expert at building SVGs.
|
|||||||
Return only the full code in <svg></svg> tags.
|
Return only the full code in <svg></svg> tags.
|
||||||
Do not include markdown "```" or "```svg" at the start or end.
|
Do not include markdown "```" or "```svg" at the start or end.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts(
|
||||||
|
html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
|
||||||
|
ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
# TODO: Fix this prompt to actually be Vue
|
||||||
|
vue_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
TAILWIND_SYSTEM_PROMPT = """
|
from prompts.types import SystemPrompts
|
||||||
|
|
||||||
|
|
||||||
|
HTML_TAILWIND_SYSTEM_PROMPT = """
|
||||||
You are an expert Tailwind developer
|
You are an expert Tailwind developer
|
||||||
You take screenshots of a reference web page from the user, and then build single page apps
|
You take screenshots of a reference web page from the user, and then build single page apps
|
||||||
using Tailwind, HTML and JS.
|
using Tailwind, HTML and JS.
|
||||||
@ -170,3 +173,13 @@ padding, margin, border, etc. Match the colors and sizes exactly.
|
|||||||
Return only the full code in <svg></svg> tags.
|
Return only the full code in <svg></svg> tags.
|
||||||
Do not include markdown "```" or "```svg" at the start or end.
|
Do not include markdown "```" or "```svg" at the start or end.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
SYSTEM_PROMPTS = SystemPrompts(
|
||||||
|
html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
bootstrap=BOOTSTRAP_SYSTEM_PROMPT,
|
||||||
|
ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT,
|
||||||
|
svg=SVG_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
|||||||
20
backend/prompts/types.py
Normal file
20
backend/prompts/types.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class SystemPrompts(TypedDict):
|
||||||
|
html_tailwind: str
|
||||||
|
react_tailwind: str
|
||||||
|
bootstrap: str
|
||||||
|
ionic_tailwind: str
|
||||||
|
vue_tailwind: str
|
||||||
|
svg: str
|
||||||
|
|
||||||
|
|
||||||
|
Stack = Literal[
|
||||||
|
"html_tailwind",
|
||||||
|
"react_tailwind",
|
||||||
|
"bootstrap",
|
||||||
|
"ionic_tailwind",
|
||||||
|
"vue_tailwind",
|
||||||
|
"svg",
|
||||||
|
]
|
||||||
@ -6,12 +6,13 @@ from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
|||||||
from llm import stream_openai_response
|
from llm import stream_openai_response
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Dict, List
|
from typing import Dict, List, cast, get_args
|
||||||
from image_generation import create_alt_url_mapping, generate_images
|
from image_generation import create_alt_url_mapping, generate_images
|
||||||
from prompts import assemble_imported_code_prompt, assemble_prompt
|
from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||||
from access_token import validate_access_token
|
from access_token import validate_access_token
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
|
from prompts.types import Stack
|
||||||
|
|
||||||
from utils import pprint_prompt # type: ignore
|
from utils import pprint_prompt # type: ignore
|
||||||
|
|
||||||
@ -96,6 +97,13 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Validate the generated code config
|
||||||
|
if not generated_code_config in get_args(Stack):
|
||||||
|
await throw_error(f"Invalid generated code config: {generated_code_config}")
|
||||||
|
return
|
||||||
|
# Cast the variable to the Stack type
|
||||||
|
valid_stack = cast(Stack, generated_code_config)
|
||||||
|
|
||||||
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
|
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
|
||||||
openai_base_url = None
|
openai_base_url = None
|
||||||
# Disable user-specified OpenAI Base URL in prod
|
# Disable user-specified OpenAI Base URL in prod
|
||||||
@ -131,7 +139,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
||||||
original_imported_code = params["history"][0]
|
original_imported_code = params["history"][0]
|
||||||
prompt_messages = assemble_imported_code_prompt(
|
prompt_messages = assemble_imported_code_prompt(
|
||||||
original_imported_code, generated_code_config
|
original_imported_code, valid_stack
|
||||||
)
|
)
|
||||||
for index, text in enumerate(params["history"][1:]):
|
for index, text in enumerate(params["history"][1:]):
|
||||||
if index % 2 == 0:
|
if index % 2 == 0:
|
||||||
@ -150,12 +158,10 @@ async def stream_code(websocket: WebSocket):
|
|||||||
try:
|
try:
|
||||||
if params.get("resultImage") and params["resultImage"]:
|
if params.get("resultImage") and params["resultImage"]:
|
||||||
prompt_messages = assemble_prompt(
|
prompt_messages = assemble_prompt(
|
||||||
params["image"], generated_code_config, params["resultImage"]
|
params["image"], valid_stack, params["resultImage"]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt_messages = assemble_prompt(
|
prompt_messages = assemble_prompt(params["image"], valid_stack)
|
||||||
params["image"], generated_code_config
|
|
||||||
)
|
|
||||||
except:
|
except:
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user