abstract out prompt assembly into a separate function
parent dd7a51dd34
commit 3591588e2b
@@ -23,7 +23,7 @@ from prompts import assemble_imported_code_prompt, assemble_prompt
 from prompts.claude_prompts import VIDEO_PROMPT
 from prompts.types import Stack
 
-# from utils import pprint_prompt
+from utils import pprint_prompt
 from video.utils import assemble_claude_prompt_video
 from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
 
@@ -31,6 +31,64 @@ from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
 router = APIRouter()
 
 
+async def create_prompt(
+    params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
+) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]:
+
+    image_cache: Dict[str, str] = {}
+
+    # If this generation started off with imported code, we need to assemble the prompt differently
+    if params.get("isImportedFromCode"):
+        original_imported_code = params["history"][0]
+        prompt_messages = assemble_imported_code_prompt(
+            original_imported_code, stack, model
+        )
+        for index, text in enumerate(params["history"][1:]):
+            if index % 2 == 0:
+                message: ChatCompletionMessageParam = {
+                    "role": "user",
+                    "content": text,
+                }
+            else:
+                message: ChatCompletionMessageParam = {
+                    "role": "assistant",
+                    "content": text,
+                }
+            prompt_messages.append(message)
+    else:
+        # Assemble the prompt for non-imported code
+        if params.get("resultImage"):
+            prompt_messages = assemble_prompt(
+                params["image"], stack, params["resultImage"]
+            )
+        else:
+            prompt_messages = assemble_prompt(params["image"], stack)
+
+    if params["generationType"] == "update":
+        # Transform the history tree into message format
+        # TODO: Move this to frontend
+        for index, text in enumerate(params["history"]):
+            if index % 2 == 0:
+                message: ChatCompletionMessageParam = {
+                    "role": "assistant",
+                    "content": text,
+                }
+            else:
+                message: ChatCompletionMessageParam = {
+                    "role": "user",
+                    "content": text,
+                }
+            prompt_messages.append(message)
+
+        image_cache = create_alt_url_mapping(params["history"][-2])
+
+    if input_mode == "video":
+        video_data_url = params["image"]
+        prompt_messages = await assemble_claude_prompt_video(video_data_url)
+
+    return prompt_messages, image_cache
+
+
 # Auto-upgrade usage of older models
 def auto_upgrade_model(code_generation_model: Llm) -> Llm:
     if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
@@ -185,76 +243,34 @@ async def stream_code(websocket: WebSocket):
     should_generate_images = bool(params.get("isImageGenerationEnabled", True))
 
     print("generating code...")
 
+    # TODO(*): Print with send_message instead of print statements
     await send_message("status", "Generating code...", 0)
     await send_message("status", "Generating code...", 1)
 
+    # TODO(*): Move down
     async def process_chunk(content: str, variantIndex: int):
         await send_message("chunk", content, variantIndex)
 
     # Image cache for updates so that we don't have to regenerate images
     image_cache: Dict[str, str] = {}
 
-    # If this generation started off with imported code, we need to assemble the prompt differently
-    if params.get("isImportedFromCode") and params["isImportedFromCode"]:
-        original_imported_code = params["history"][0]
-        prompt_messages = assemble_imported_code_prompt(
-            original_imported_code, valid_stack, code_generation_model
-        )
-        for index, text in enumerate(params["history"][1:]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-    else:
-        # Assemble the prompt
-        try:
-            if params.get("resultImage") and params["resultImage"]:
-                prompt_messages = assemble_prompt(
-                    params["image"], valid_stack, params["resultImage"]
-                )
-            else:
-                prompt_messages = assemble_prompt(params["image"], valid_stack)
-        except:
-            # TODO: This should use variantIndex
-            await websocket.send_json(
-                {
-                    "type": "error",
-                    "value": "Error assembling prompt. Contact support at support@picoapps.xyz",
-                }
-            )
-            await websocket.close()
-            return
-
-    if params["generationType"] == "update":
-        # Transform the history tree into message format
-        # TODO: Move this to frontend
-        for index, text in enumerate(params["history"]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-
-        image_cache = create_alt_url_mapping(params["history"][-2])
-
-    if validated_input_mode == "video":
-        video_data_url = params["image"]
-        prompt_messages = await assemble_claude_prompt_video(video_data_url)
-
-    # pprint_prompt(prompt_messages) # type: ignore
+    try:
+        prompt_messages, image_cache = await create_prompt(
+            params, valid_stack, code_generation_model, validated_input_mode
+        )
+    except:
+        # TODO(*): This should use variantIndex
+        await websocket.send_json(
+            {
+                "type": "error",
+                "value": "Error assembling prompt. Contact support at support@picoapps.xyz",
+            }
+        )
+        await websocket.close()
+        raise
+
+    pprint_prompt(prompt_messages)  # type: ignore
 
     if SHOULD_MOCK_AI_RESPONSE:
         completions = [
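
For reference, the extracted create_prompt helper can now be exercised outside the WebSocket handler. A minimal sketch, assuming the imports already present in this file; the params dict, the placeholder data URL, the "html_tailwind" stack, and the "image" input mode are illustrative values, not part of this commit:

import asyncio

async def demo() -> None:
    # Hypothetical request params; in stream_code these come from the client's
    # WebSocket payload. "create" skips the history-to-messages transformation.
    params = {
        "image": "data:image/png;base64,...",  # placeholder screenshot data URL
        "generationType": "create",
    }
    # Assumed Stack, Llm, and InputMode values, chosen for illustration
    prompt_messages, image_cache = await create_prompt(
        params, "html_tailwind", Llm.GPT_4_VISION, "image"
    )
    print(len(prompt_messages), image_cache)

asyncio.run(demo())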