From 4288cf208871ffecc0785a45545c24de523f53e4 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 26 Jun 2024 13:21:12 +0800 Subject: [PATCH] fix imported code prompt for Claude which doesn't allow multiple 'user' role messages in a row --- backend/prompts/__init__.py | 32 +++++++++++++++++++++----------- backend/routes/generate_code.py | 3 ++- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index 4f2e329..dc96ab9 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,6 +1,7 @@ from typing import List, NoReturn, Union from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam +from llm import Llm from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS @@ -17,7 +18,7 @@ Generate code for a SVG that looks exactly like this. def assemble_imported_code_prompt( - code: str, stack: Stack, result_image_data_url: Union[str, None] = None + code: str, stack: Stack, model: Llm ) -> List[ChatCompletionMessageParam]: system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] @@ -26,16 +27,25 @@ def assemble_imported_code_prompt( if stack != "svg" else "Here is the code of the SVG: " + code ) - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] + + if model == Llm.CLAUDE_3_5_SONNET_2024_06_20: + return [ + { + "role": "system", + "content": system_content + "\n " + user_content, + } + ] + else: + return [ + { + "role": "system", + "content": system_content, + }, + { + "role": "user", + "content": user_content, + }, + ] # TODO: Use result_image_data_url diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 362f742..d280830 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -20,6 +20,7 @@ from datetime import datetime import json from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack +from utils import pprint_prompt # from utils import pprint_prompt from video.utils import extract_tag_content, assemble_claude_prompt_video @@ -166,7 +167,7 @@ async def stream_code(websocket: WebSocket): if params.get("isImportedFromCode") and params["isImportedFromCode"]: original_imported_code = params["history"][0] prompt_messages = assemble_imported_code_prompt( - original_imported_code, valid_stack + original_imported_code, valid_stack, code_generation_model ) for index, text in enumerate(params["history"][1:]): if index % 2 == 0: