abstract out prompt assembly into a separate function

Abi Raja 2024-07-31 11:07:30 -04:00
parent dd7a51dd34
commit 3591588e2b


@@ -23,7 +23,7 @@ from prompts import assemble_imported_code_prompt, assemble_prompt
 from prompts.claude_prompts import VIDEO_PROMPT
 from prompts.types import Stack
-# from utils import pprint_prompt
+from utils import pprint_prompt
 from video.utils import assemble_claude_prompt_video
 from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
@@ -31,6 +31,64 @@ from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
 router = APIRouter()
 
 
+async def create_prompt(
+    params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
+) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]:
+
+    image_cache: Dict[str, str] = {}
+
+    # If this generation started off with imported code, we need to assemble the prompt differently
+    if params.get("isImportedFromCode"):
+        original_imported_code = params["history"][0]
+        prompt_messages = assemble_imported_code_prompt(
+            original_imported_code, stack, model
+        )
+        for index, text in enumerate(params["history"][1:]):
+            if index % 2 == 0:
+                message: ChatCompletionMessageParam = {
+                    "role": "user",
+                    "content": text,
+                }
+            else:
+                message: ChatCompletionMessageParam = {
+                    "role": "assistant",
+                    "content": text,
+                }
+            prompt_messages.append(message)
+    else:
+        # Assemble the prompt for non-imported code
+        if params.get("resultImage"):
+            prompt_messages = assemble_prompt(
+                params["image"], stack, params["resultImage"]
+            )
+        else:
+            prompt_messages = assemble_prompt(params["image"], stack)
+
+        if params["generationType"] == "update":
+            # Transform the history tree into message format
+            # TODO: Move this to frontend
+            for index, text in enumerate(params["history"]):
+                if index % 2 == 0:
+                    message: ChatCompletionMessageParam = {
+                        "role": "assistant",
+                        "content": text,
+                    }
+                else:
+                    message: ChatCompletionMessageParam = {
+                        "role": "user",
+                        "content": text,
+                    }
+                prompt_messages.append(message)
+
+            image_cache = create_alt_url_mapping(params["history"][-2])
+
+    if input_mode == "video":
+        video_data_url = params["image"]
+        prompt_messages = await assemble_claude_prompt_video(video_data_url)
+
+    return prompt_messages, image_cache
+
+
 # Auto-upgrade usage of older models
 def auto_upgrade_model(code_generation_model: Llm) -> Llm:
     if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
@@ -185,44 +243,24 @@ async def stream_code(websocket: WebSocket):
     should_generate_images = bool(params.get("isImageGenerationEnabled", True))
 
     print("generating code...")
+    # TODO(*): Print with send_message instead of print statements
     await send_message("status", "Generating code...", 0)
     await send_message("status", "Generating code...", 1)
 
+    # TODO(*): Move down
     async def process_chunk(content: str, variantIndex: int):
         await send_message("chunk", content, variantIndex)
 
     # Image cache for updates so that we don't have to regenerate images
     image_cache: Dict[str, str] = {}
 
-    # If this generation started off with imported code, we need to assemble the prompt differently
-    if params.get("isImportedFromCode") and params["isImportedFromCode"]:
-        original_imported_code = params["history"][0]
-        prompt_messages = assemble_imported_code_prompt(
-            original_imported_code, valid_stack, code_generation_model
-        )
-        for index, text in enumerate(params["history"][1:]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-    else:
-        # Assemble the prompt
     try:
-        if params.get("resultImage") and params["resultImage"]:
-            prompt_messages = assemble_prompt(
-                params["image"], valid_stack, params["resultImage"]
-            )
-        else:
-            prompt_messages = assemble_prompt(params["image"], valid_stack)
+        prompt_messages, image_cache = await create_prompt(
+            params, valid_stack, code_generation_model, validated_input_mode
+        )
     except:
-        # TODO: This should use variantIndex
+        # TODO(*): This should use variantIndex
         await websocket.send_json(
             {
                 "type": "error",
@@ -230,31 +268,9 @@ async def stream_code(websocket: WebSocket):
             }
         )
         await websocket.close()
-        return
+        raise
 
-    if params["generationType"] == "update":
-        # Transform the history tree into message format
-        # TODO: Move this to frontend
-        for index, text in enumerate(params["history"]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-
-        image_cache = create_alt_url_mapping(params["history"][-2])
-
-    if validated_input_mode == "video":
-        video_data_url = params["image"]
-        prompt_messages = await assemble_claude_prompt_video(video_data_url)
-
-    # pprint_prompt(prompt_messages)  # type: ignore
+    pprint_prompt(prompt_messages)  # type: ignore
 
     if SHOULD_MOCK_AI_RESPONSE:
         completions = [
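
For reference, the role-alternation that create_prompt applies to params["history"] can be stated on its own. The sketch below is a simplified restatement of that transform, not code from this commit; the helper name history_to_messages is hypothetical:

from typing import Any, Dict, List

def history_to_messages(history: List[str], first_role: str) -> List[Dict[str, Any]]:
    # Entries at even indices take first_role; entries at odd indices take
    # the opposite role, mirroring the index % 2 branches in create_prompt.
    other_role = "assistant" if first_role == "user" else "user"
    return [
        {"role": first_role if index % 2 == 0 else other_role, "content": text}
        for index, text in enumerate(history)
    ]

# Imported-code generations treat history[0] as the original code and replay
# the rest starting with a user turn:
#     history_to_messages(params["history"][1:], "user")
# Updates replay the whole history starting with an assistant turn:
#     history_to_messages(params["history"], "assistant")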