From b158597d7e969de715a6bcc288530553487e666a Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Mon, 5 Aug 2024 14:11:01 -0400 Subject: [PATCH] fix bug with prompt assembly for imported code with Claude which disallows multiple user messages in a row --- backend/prompts/__init__.py | 36 ++++---------- backend/prompts/test_prompts.py | 84 ++++++++++++++++++++------------- backend/routes/generate_code.py | 7 ++- 3 files changed, 64 insertions(+), 63 deletions(-) diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index d7103f9..955502b 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,10 +1,8 @@ from typing import Union - from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam + from custom_types import InputMode from image_generation.core import create_alt_url_mapping -from llm import Llm - from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS from prompts.types import Stack @@ -21,7 +19,7 @@ Generate code for a SVG that looks exactly like this. async def create_prompt( - params: dict[str, str], stack: Stack, model: Llm, input_mode: InputMode + params: dict[str, str], stack: Stack, input_mode: InputMode ) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]: image_cache: dict[str, str] = {} @@ -29,9 +27,7 @@ async def create_prompt( # If this generation started off with imported code, we need to assemble the prompt differently if params.get("isImportedFromCode"): original_imported_code = params["history"][0] - prompt_messages = assemble_imported_code_prompt( - original_imported_code, stack, model - ) + prompt_messages = assemble_imported_code_prompt(original_imported_code, stack) for index, text in enumerate(params["history"][1:]): if index % 2 == 0: message: ChatCompletionMessageParam = { @@ -79,7 +75,7 @@ async def create_prompt( def assemble_imported_code_prompt( - code: str, stack: Stack, model: Llm + code: str, stack: Stack ) -> list[ChatCompletionMessageParam]: system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] @@ -89,24 +85,12 @@ def assemble_imported_code_prompt( else "Here is the code of the SVG: " + code ) - if model == Llm.CLAUDE_3_5_SONNET_2024_06_20: - return [ - { - "role": "system", - "content": system_content + "\n " + user_content, - } - ] - else: - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] + return [ + { + "role": "system", + "content": system_content + "\n " + user_content, + } + ] # TODO: Use result_image_data_url diff --git a/backend/prompts/test_prompts.py b/backend/prompts/test_prompts.py index 9175fd8..049f9db 100644 --- a/backend/prompts/test_prompts.py +++ b/backend/prompts/test_prompts.py @@ -391,63 +391,81 @@ def test_prompts(): def test_imported_code_prompts(): - tailwind_prompt = assemble_imported_code_prompt( - "code", "html_tailwind", Llm.GPT_4O_2024_05_13 - ) + code = "Sample code" + + tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind") expected_tailwind_prompt = [ - {"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert tailwind_prompt == expected_tailwind_prompt - html_css_prompt = assemble_imported_code_prompt( - "code", "html_css", Llm.GPT_4O_2024_05_13 - ) + html_css_prompt = assemble_imported_code_prompt(code, "html_css") expected_html_css_prompt = [ - {"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert html_css_prompt == expected_html_css_prompt - react_tailwind_prompt = assemble_imported_code_prompt( - "code", "react_tailwind", Llm.GPT_4O_2024_05_13 - ) + react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind") expected_react_tailwind_prompt = [ - {"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert react_tailwind_prompt == expected_react_tailwind_prompt - bootstrap_prompt = assemble_imported_code_prompt( - "code", "bootstrap", Llm.GPT_4O_2024_05_13 - ) + bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap") expected_bootstrap_prompt = [ - {"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert bootstrap_prompt == expected_bootstrap_prompt - ionic_tailwind = assemble_imported_code_prompt( - "code", "ionic_tailwind", Llm.GPT_4O_2024_05_13 - ) + ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind") expected_ionic_tailwind = [ - {"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert ionic_tailwind == expected_ionic_tailwind - vue_tailwind = assemble_imported_code_prompt( - "code", "vue_tailwind", Llm.GPT_4O_2024_05_13 - ) + vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind") expected_vue_tailwind = [ - {"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT}, - {"role": "user", "content": "Here is the code of the app: code"}, + { + "role": "system", + "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT + + "\n Here is the code of the app: " + + code, + } ] assert vue_tailwind == expected_vue_tailwind - svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13) + svg = assemble_imported_code_prompt(code, "svg") expected_svg = [ - {"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT}, - {"role": "user", "content": "Here is the code of the SVG: code"}, + { + "role": "system", + "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT + + "\n Here is the code of the SVG: " + + code, + } ] assert svg == expected_svg diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 8b8d45a..d7f46c5 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -27,8 +27,7 @@ from image_generation.core import generate_images from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack - -# from utils import pprint_prompt +from utils import pprint_prompt from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -240,7 +239,7 @@ async def stream_code(websocket: WebSocket): try: prompt_messages, image_cache = await create_prompt( - params, valid_stack, code_generation_model, validated_input_mode + params, valid_stack, validated_input_mode ) except: await throw_error( @@ -248,7 +247,7 @@ async def stream_code(websocket: WebSocket): ) raise - # pprint_prompt(prompt_messages) # type: ignore + pprint_prompt(prompt_messages) # type: ignore ### Code generation