fix imported code prompt for Claude which doesn't allow multiple 'user' role messages in a row
This commit is contained in:
parent
6be83b4a2d
commit
4288cf2088
@ -1,6 +1,7 @@
|
|||||||
from typing import List, NoReturn, Union
|
from typing import List, NoReturn, Union
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||||
|
from llm import Llm
|
||||||
|
|
||||||
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
||||||
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
||||||
@ -17,7 +18,7 @@ Generate code for a SVG that looks exactly like this.
|
|||||||
|
|
||||||
|
|
||||||
def assemble_imported_code_prompt(
|
def assemble_imported_code_prompt(
|
||||||
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
|
code: str, stack: Stack, model: Llm
|
||||||
) -> List[ChatCompletionMessageParam]:
|
) -> List[ChatCompletionMessageParam]:
|
||||||
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
||||||
|
|
||||||
@ -26,16 +27,25 @@ def assemble_imported_code_prompt(
|
|||||||
if stack != "svg"
|
if stack != "svg"
|
||||||
else "Here is the code of the SVG: " + code
|
else "Here is the code of the SVG: " + code
|
||||||
)
|
)
|
||||||
return [
|
|
||||||
{
|
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
|
||||||
"role": "system",
|
return [
|
||||||
"content": system_content,
|
{
|
||||||
},
|
"role": "system",
|
||||||
{
|
"content": system_content + "\n " + user_content,
|
||||||
"role": "user",
|
}
|
||||||
"content": user_content,
|
]
|
||||||
},
|
else:
|
||||||
]
|
return [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_content,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_content,
|
||||||
|
},
|
||||||
|
]
|
||||||
# TODO: Use result_image_data_url
|
# TODO: Use result_image_data_url
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -20,6 +20,7 @@ from datetime import datetime
|
|||||||
import json
|
import json
|
||||||
from prompts.claude_prompts import VIDEO_PROMPT
|
from prompts.claude_prompts import VIDEO_PROMPT
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
|
from utils import pprint_prompt
|
||||||
|
|
||||||
# from utils import pprint_prompt
|
# from utils import pprint_prompt
|
||||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
||||||
@ -166,7 +167,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
|
||||||
original_imported_code = params["history"][0]
|
original_imported_code = params["history"][0]
|
||||||
prompt_messages = assemble_imported_code_prompt(
|
prompt_messages = assemble_imported_code_prompt(
|
||||||
original_imported_code, valid_stack
|
original_imported_code, valid_stack, code_generation_model
|
||||||
)
|
)
|
||||||
for index, text in enumerate(params["history"][1:]):
|
for index, text in enumerate(params["history"][1:]):
|
||||||
if index % 2 == 0:
|
if index % 2 == 0:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user