diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py
index dc96ab9..de884c5 100644
--- a/backend/prompts/__init__.py
+++ b/backend/prompts/__init__.py
@@ -1,11 +1,14 @@
-from typing import List, NoReturn, Union
+from typing import Union
 
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
 
+from custom_types import InputMode
+from image_generation import create_alt_url_mapping
 from llm import Llm
 from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
 from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
 from prompts.types import Stack
+from video.utils import assemble_claude_prompt_video
 
 
 USER_PROMPT = """
@@ -17,9 +20,67 @@ Generate code for a SVG that looks exactly like this.
 """
 
 
+async def create_prompt(
+    params: dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
+) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
+
+    image_cache: dict[str, str] = {}
+
+    # If this generation started off with imported code, we need to assemble the prompt differently
+    if params.get("isImportedFromCode"):
+        original_imported_code = params["history"][0]
+        prompt_messages = assemble_imported_code_prompt(
+            original_imported_code, stack, model
+        )
+        for index, text in enumerate(params["history"][1:]):
+            if index % 2 == 0:
+                message: ChatCompletionMessageParam = {
+                    "role": "user",
+                    "content": text,
+                }
+            else:
+                message: ChatCompletionMessageParam = {
+                    "role": "assistant",
+                    "content": text,
+                }
+            prompt_messages.append(message)
+    else:
+        # Assemble the prompt for non-imported code
+        if params.get("resultImage"):
+            prompt_messages = assemble_prompt(
+                params["image"], stack, params["resultImage"]
+            )
+        else:
+            prompt_messages = assemble_prompt(params["image"], stack)
+
+        if params["generationType"] == "update":
+            # Transform the history tree into message format
+            # TODO: Move this to frontend
+            for index, text in enumerate(params["history"]):
+                if index % 2 == 0:
+                    message: ChatCompletionMessageParam = {
+                        "role": "assistant",
+                        "content": text,
+                    }
+                else:
+                    message: ChatCompletionMessageParam = {
+                        "role": "user",
+                        "content": text,
+                    }
+                prompt_messages.append(message)
+
+            image_cache = create_alt_url_mapping(params["history"][-2])
+
+    if input_mode == "video":
+        video_data_url = params["image"]
+        prompt_messages = await assemble_claude_prompt_video(video_data_url)
+
+    return prompt_messages, image_cache
+
+
 def assemble_imported_code_prompt(
     code: str, stack: Stack, model: Llm
-) -> List[ChatCompletionMessageParam]:
+) -> list[ChatCompletionMessageParam]:
     system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
 
     user_content = (
@@ -53,11 +114,11 @@ def assemble_prompt(
     image_data_url: str,
     stack: Stack,
     result_image_data_url: Union[str, None] = None,
-) -> List[ChatCompletionMessageParam]:
+) -> list[ChatCompletionMessageParam]:
     system_content = SYSTEM_PROMPTS[stack]
     user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT
 
-    user_content: List[ChatCompletionContentPartParam] = [
+    user_content: list[ChatCompletionContentPartParam] = [
         {
             "type": "image_url",
             "image_url": {"url": image_data_url, "detail": "high"},
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index ced9655..0c6e921 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -14,81 +14,20 @@ from llm import (
     stream_claude_response_native,
     stream_openai_response,
 )
-from openai.types.chat import ChatCompletionMessageParam
 from fs_logging.core import write_logs
 from mock_llm import mock_completion
 from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
-from image_generation import create_alt_url_mapping, generate_images
-from prompts import assemble_imported_code_prompt, assemble_prompt
+from image_generation import generate_images
+from prompts import create_prompt
 from prompts.claude_prompts import VIDEO_PROMPT
 from prompts.types import Stack
-
 from utils import pprint_prompt
-from video.utils import assemble_claude_prompt_video
 from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
 
 router = APIRouter()
 
 
-async def create_prompt(
-    params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
-) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]:
-
-    image_cache: Dict[str, str] = {}
-
-    # If this generation started off with imported code, we need to assemble the prompt differently
-    if params.get("isImportedFromCode"):
-        original_imported_code = params["history"][0]
-        prompt_messages = assemble_imported_code_prompt(
-            original_imported_code, stack, model
-        )
-        for index, text in enumerate(params["history"][1:]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-    else:
-        # Assemble the prompt for non-imported code
-        if params.get("resultImage"):
-            prompt_messages = assemble_prompt(
-                params["image"], stack, params["resultImage"]
-            )
-        else:
-            prompt_messages = assemble_prompt(params["image"], stack)
-
-        if params["generationType"] == "update":
-            # Transform the history tree into message format
-            # TODO: Move this to frontend
-            for index, text in enumerate(params["history"]):
-                if index % 2 == 0:
-                    message: ChatCompletionMessageParam = {
-                        "role": "assistant",
-                        "content": text,
-                    }
-                else:
-                    message: ChatCompletionMessageParam = {
-                        "role": "user",
-                        "content": text,
-                    }
-                prompt_messages.append(message)
-
-            image_cache = create_alt_url_mapping(params["history"][-2])
-
-    if input_mode == "video":
-        video_data_url = params["image"]
-        prompt_messages = await assemble_claude_prompt_video(video_data_url)
-
-    return prompt_messages, image_cache
-
-
 # Auto-upgrade usage of older models
 def auto_upgrade_model(code_generation_model: Llm) -> Llm:
     if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
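
Taken together, the patch moves create_prompt out of backend/routes/generate_code.py into the backend/prompts package, carries its helper imports (InputMode, create_alt_url_mapping, assemble_claude_prompt_video) along with it, and switches the moved code's annotations from typing.List/typing.Dict to the built-in list/dict generics. A minimal usage sketch (not itself part of the patch) of how a caller such as the websocket route might now invoke the relocated function; the params contents and input_mode="image" are illustrative assumptions, while "svg" and Llm.GPT_4_TURBO_2024_04_09 are values that appear in the diff itself:

    from llm import Llm
    from prompts import create_prompt

    # Inside an async route handler (sketch):
    params = {
        "image": "data:image/png;base64,...",  # screenshot as a data URL (assumed shape)
        "generationType": "create",            # "update" would also replay params["history"]
    }
    prompt_messages, image_cache = await create_prompt(
        params,
        stack="svg",                        # a Stack value referenced in the diff
        model=Llm.GPT_4_TURBO_2024_04_09,   # an Llm member referenced in the diff
        input_mode="image",                 # assumed non-video mode; "video" takes the Claude path
    )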