diff --git a/backend/prompts/__init__.py b/backend/prompts/__init__.py index 4f2e329..dc96ab9 100644 --- a/backend/prompts/__init__.py +++ b/backend/prompts/__init__.py @@ -1,6 +1,7 @@ from typing import List, NoReturn, Union from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam +from llm import Llm from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS from prompts.screenshot_system_prompts import SYSTEM_PROMPTS @@ -17,7 +18,7 @@ Generate code for a SVG that looks exactly like this. def assemble_imported_code_prompt( - code: str, stack: Stack, result_image_data_url: Union[str, None] = None + code: str, stack: Stack, model: Llm ) -> List[ChatCompletionMessageParam]: system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack] @@ -26,16 +27,25 @@ def assemble_imported_code_prompt( if stack != "svg" else "Here is the code of the SVG: " + code ) - return [ - { - "role": "system", - "content": system_content, - }, - { - "role": "user", - "content": user_content, - }, - ] + + if model == Llm.CLAUDE_3_5_SONNET_2024_06_20: + return [ + { + "role": "system", + "content": system_content + "\n " + user_content, + } + ] + else: + return [ + { + "role": "system", + "content": system_content, + }, + { + "role": "user", + "content": user_content, + }, + ] # TODO: Use result_image_data_url diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index cb6a549..2f21b94 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -23,6 +23,7 @@ from routes.logging_utils import PaymentMethod, send_to_saas_backend from routes.saas_utils import does_user_have_subscription_credits from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack +from utils import pprint_prompt # from utils import pprint_prompt from video.utils import extract_tag_content, assemble_claude_prompt_video @@ -207,7 +208,7 @@ async def stream_code(websocket: WebSocket): if params.get("isImportedFromCode") and params["isImportedFromCode"]: original_imported_code = params["history"][0] prompt_messages = assemble_imported_code_prompt( - original_imported_code, valid_stack + original_imported_code, valid_stack, code_generation_model ) for index, text in enumerate(params["history"][1:]): # TODO: Remove after "Select and edit" is fully implemented