fix imported code prompt for Claude which doesn't allow multiple 'user' role messages in a row
This commit is contained in:
parent
6be83b4a2d
commit
4288cf2088
@ -1,6 +1,7 @@
|
||||
from typing import List, NoReturn, Union
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||
from llm import Llm
|
||||
|
||||
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
||||
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
||||
@ -17,7 +18,7 @@ Generate code for a SVG that looks exactly like this.
|
||||
|
||||
|
||||
def assemble_imported_code_prompt(
|
||||
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
|
||||
code: str, stack: Stack, model: Llm
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
||||
|
||||
@ -26,16 +27,25 @@ def assemble_imported_code_prompt(
|
||||
if stack != "svg"
|
||||
else "Here is the code of the SVG: " + code
|
||||
)
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_content,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_content,
|
||||
},
|
||||
]
|
||||
|
||||
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_content + "\n " + user_content,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_content,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_content,
|
||||
},
|
||||
]
|
||||
# TODO: Use result_image_data_url
|
||||
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ from datetime import datetime
|
||||
import json
|
||||
from prompts.claude_prompts import VIDEO_PROMPT
|
||||
from prompts.types import Stack
|
||||
from utils import pprint_prompt
|
||||
|
||||
# from utils import pprint_prompt
|
||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
||||
@ -166,7 +167,7 @@ async def stream_code(websocket: WebSocket):
|
||||
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
||||
original_imported_code = params["history"][0]
|
||||
prompt_messages = assemble_imported_code_prompt(
|
||||
original_imported_code, valid_stack
|
||||
original_imported_code, valid_stack, code_generation_model
|
||||
)
|
||||
for index, text in enumerate(params["history"][1:]):
|
||||
if index % 2 == 0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user