fix imported code prompt for Claude which doesn't allow multiple 'user' role messages in a row

This commit is contained in:
Abi Raja 2024-06-26 13:21:12 +08:00
parent 6be83b4a2d
commit 4288cf2088
2 changed files with 23 additions and 12 deletions

View File

@ -1,6 +1,7 @@
from typing import List, NoReturn, Union from typing import List, NoReturn, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
from llm import Llm
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
@ -17,7 +18,7 @@ Generate code for a SVG that looks exactly like this.
def assemble_imported_code_prompt( def assemble_imported_code_prompt(
code: str, stack: Stack, result_image_data_url: Union[str, None] = None code: str, stack: Stack, model: Llm
) -> List[ChatCompletionMessageParam]: ) -> List[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
@ -26,6 +27,15 @@ def assemble_imported_code_prompt(
if stack != "svg" if stack != "svg"
else "Here is the code of the SVG: " + code else "Here is the code of the SVG: " + code
) )
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
return [
{
"role": "system",
"content": system_content + "\n " + user_content,
}
]
else:
return [ return [
{ {
"role": "system", "role": "system",

View File

@ -20,6 +20,7 @@ from datetime import datetime
import json import json
from prompts.claude_prompts import VIDEO_PROMPT from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack from prompts.types import Stack
from utils import pprint_prompt
# from utils import pprint_prompt # from utils import pprint_prompt
from video.utils import extract_tag_content, assemble_claude_prompt_video from video.utils import extract_tag_content, assemble_claude_prompt_video
@ -166,7 +167,7 @@ async def stream_code(websocket: WebSocket):
if params.get("isImportedFromCode") and params["isImportedFromCode"]: if params.get("isImportedFromCode") and params["isImportedFromCode"]:
original_imported_code = params["history"][0] original_imported_code = params["history"][0]
prompt_messages = assemble_imported_code_prompt( prompt_messages = assemble_imported_code_prompt(
original_imported_code, valid_stack original_imported_code, valid_stack, code_generation_model
) )
for index, text in enumerate(params["history"][1:]): for index, text in enumerate(params["history"][1:]):
if index % 2 == 0: if index % 2 == 0: