fix bug with prompt assembly for imported code with Claude which disallows multiple user messages in a row
This commit is contained in:
parent
fb5480b036
commit
b158597d7e
@ -1,10 +1,8 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||||
|
|
||||||
from custom_types import InputMode
|
from custom_types import InputMode
|
||||||
from image_generation.core import create_alt_url_mapping
|
from image_generation.core import create_alt_url_mapping
|
||||||
from llm import Llm
|
|
||||||
|
|
||||||
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
|
||||||
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
@ -21,7 +19,7 @@ Generate code for a SVG that looks exactly like this.
|
|||||||
|
|
||||||
|
|
||||||
async def create_prompt(
|
async def create_prompt(
|
||||||
params: dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
|
params: dict[str, str], stack: Stack, input_mode: InputMode
|
||||||
) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
|
) -> tuple[list[ChatCompletionMessageParam], dict[str, str]]:
|
||||||
|
|
||||||
image_cache: dict[str, str] = {}
|
image_cache: dict[str, str] = {}
|
||||||
@ -29,9 +27,7 @@ async def create_prompt(
|
|||||||
# If this generation started off with imported code, we need to assemble the prompt differently
|
# If this generation started off with imported code, we need to assemble the prompt differently
|
||||||
if params.get("isImportedFromCode"):
|
if params.get("isImportedFromCode"):
|
||||||
original_imported_code = params["history"][0]
|
original_imported_code = params["history"][0]
|
||||||
prompt_messages = assemble_imported_code_prompt(
|
prompt_messages = assemble_imported_code_prompt(original_imported_code, stack)
|
||||||
original_imported_code, stack, model
|
|
||||||
)
|
|
||||||
for index, text in enumerate(params["history"][1:]):
|
for index, text in enumerate(params["history"][1:]):
|
||||||
if index % 2 == 0:
|
if index % 2 == 0:
|
||||||
message: ChatCompletionMessageParam = {
|
message: ChatCompletionMessageParam = {
|
||||||
@ -79,7 +75,7 @@ async def create_prompt(
|
|||||||
|
|
||||||
|
|
||||||
def assemble_imported_code_prompt(
|
def assemble_imported_code_prompt(
|
||||||
code: str, stack: Stack, model: Llm
|
code: str, stack: Stack
|
||||||
) -> list[ChatCompletionMessageParam]:
|
) -> list[ChatCompletionMessageParam]:
|
||||||
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]
|
||||||
|
|
||||||
@ -89,24 +85,12 @@ def assemble_imported_code_prompt(
|
|||||||
else "Here is the code of the SVG: " + code
|
else "Here is the code of the SVG: " + code
|
||||||
)
|
)
|
||||||
|
|
||||||
if model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
|
return [
|
||||||
return [
|
{
|
||||||
{
|
"role": "system",
|
||||||
"role": "system",
|
"content": system_content + "\n " + user_content,
|
||||||
"content": system_content + "\n " + user_content,
|
}
|
||||||
}
|
]
|
||||||
]
|
|
||||||
else:
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": system_content,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_content,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
# TODO: Use result_image_data_url
|
# TODO: Use result_image_data_url
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -391,63 +391,81 @@ def test_prompts():
|
|||||||
|
|
||||||
|
|
||||||
def test_imported_code_prompts():
|
def test_imported_code_prompts():
|
||||||
tailwind_prompt = assemble_imported_code_prompt(
|
code = "Sample code"
|
||||||
"code", "html_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
tailwind_prompt = assemble_imported_code_prompt(code, "html_tailwind")
|
||||||
expected_tailwind_prompt = [
|
expected_tailwind_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert tailwind_prompt == expected_tailwind_prompt
|
assert tailwind_prompt == expected_tailwind_prompt
|
||||||
|
|
||||||
html_css_prompt = assemble_imported_code_prompt(
|
html_css_prompt = assemble_imported_code_prompt(code, "html_css")
|
||||||
"code", "html_css", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_html_css_prompt = [
|
expected_html_css_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_HTML_CSS_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert html_css_prompt == expected_html_css_prompt
|
assert html_css_prompt == expected_html_css_prompt
|
||||||
|
|
||||||
react_tailwind_prompt = assemble_imported_code_prompt(
|
react_tailwind_prompt = assemble_imported_code_prompt(code, "react_tailwind")
|
||||||
"code", "react_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_react_tailwind_prompt = [
|
expected_react_tailwind_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert react_tailwind_prompt == expected_react_tailwind_prompt
|
assert react_tailwind_prompt == expected_react_tailwind_prompt
|
||||||
|
|
||||||
bootstrap_prompt = assemble_imported_code_prompt(
|
bootstrap_prompt = assemble_imported_code_prompt(code, "bootstrap")
|
||||||
"code", "bootstrap", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_bootstrap_prompt = [
|
expected_bootstrap_prompt = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert bootstrap_prompt == expected_bootstrap_prompt
|
assert bootstrap_prompt == expected_bootstrap_prompt
|
||||||
|
|
||||||
ionic_tailwind = assemble_imported_code_prompt(
|
ionic_tailwind = assemble_imported_code_prompt(code, "ionic_tailwind")
|
||||||
"code", "ionic_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_ionic_tailwind = [
|
expected_ionic_tailwind = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert ionic_tailwind == expected_ionic_tailwind
|
assert ionic_tailwind == expected_ionic_tailwind
|
||||||
|
|
||||||
vue_tailwind = assemble_imported_code_prompt(
|
vue_tailwind = assemble_imported_code_prompt(code, "vue_tailwind")
|
||||||
"code", "vue_tailwind", Llm.GPT_4O_2024_05_13
|
|
||||||
)
|
|
||||||
expected_vue_tailwind = [
|
expected_vue_tailwind = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_VUE_TAILWIND_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the app: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_VUE_TAILWIND_PROMPT
|
||||||
|
+ "\n Here is the code of the app: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert vue_tailwind == expected_vue_tailwind
|
assert vue_tailwind == expected_vue_tailwind
|
||||||
|
|
||||||
svg = assemble_imported_code_prompt("code", "svg", Llm.GPT_4O_2024_05_13)
|
svg = assemble_imported_code_prompt(code, "svg")
|
||||||
expected_svg = [
|
expected_svg = [
|
||||||
{"role": "system", "content": IMPORTED_CODE_SVG_SYSTEM_PROMPT},
|
{
|
||||||
{"role": "user", "content": "Here is the code of the SVG: code"},
|
"role": "system",
|
||||||
|
"content": IMPORTED_CODE_SVG_SYSTEM_PROMPT
|
||||||
|
+ "\n Here is the code of the SVG: "
|
||||||
|
+ code,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
assert svg == expected_svg
|
assert svg == expected_svg
|
||||||
|
|||||||
@ -27,8 +27,7 @@ from image_generation.core import generate_images
|
|||||||
from prompts import create_prompt
|
from prompts import create_prompt
|
||||||
from prompts.claude_prompts import VIDEO_PROMPT
|
from prompts.claude_prompts import VIDEO_PROMPT
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
|
from utils import pprint_prompt
|
||||||
# from utils import pprint_prompt
|
|
||||||
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@ -240,7 +239,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
prompt_messages, image_cache = await create_prompt(
|
prompt_messages, image_cache = await create_prompt(
|
||||||
params, valid_stack, code_generation_model, validated_input_mode
|
params, valid_stack, validated_input_mode
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
await throw_error(
|
await throw_error(
|
||||||
@ -248,7 +247,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# pprint_prompt(prompt_messages) # type: ignore
|
pprint_prompt(prompt_messages) # type: ignore
|
||||||
|
|
||||||
### Code generation
|
### Code generation
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user