diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index d436e83..ced9655 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -23,7 +23,7 @@ from prompts import assemble_imported_code_prompt, assemble_prompt
 from prompts.claude_prompts import VIDEO_PROMPT
 from prompts.types import Stack
 
-# from utils import pprint_prompt
+from utils import pprint_prompt
 from video.utils import assemble_claude_prompt_video
 from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
 
@@ -31,6 +31,64 @@ from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
 
 router = APIRouter()
 
+async def create_prompt(
+    params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
+) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]:
+
+    image_cache: Dict[str, str] = {}
+
+    # If this generation started off with imported code, we need to assemble the prompt differently
+    if params.get("isImportedFromCode"):
+        original_imported_code = params["history"][0]
+        prompt_messages = assemble_imported_code_prompt(
+            original_imported_code, stack, model
+        )
+        for index, text in enumerate(params["history"][1:]):
+            if index % 2 == 0:
+                message: ChatCompletionMessageParam = {
+                    "role": "user",
+                    "content": text,
+                }
+            else:
+                message: ChatCompletionMessageParam = {
+                    "role": "assistant",
+                    "content": text,
+                }
+            prompt_messages.append(message)
+    else:
+        # Assemble the prompt for non-imported code
+        if params.get("resultImage"):
+            prompt_messages = assemble_prompt(
+                params["image"], stack, params["resultImage"]
+            )
+        else:
+            prompt_messages = assemble_prompt(params["image"], stack)
+
+        if params["generationType"] == "update":
+            # Transform the history tree into message format
+            # TODO: Move this to frontend
+            for index, text in enumerate(params["history"]):
+                if index % 2 == 0:
+                    message: ChatCompletionMessageParam = {
+                        "role": "assistant",
+                        "content": text,
+                    }
+                else:
+                    message: ChatCompletionMessageParam = {
+                        "role": "user",
+                        "content": text,
+                    }
+                prompt_messages.append(message)
+
+            image_cache = create_alt_url_mapping(params["history"][-2])
+
+    if input_mode == "video":
+        video_data_url = params["image"]
+        prompt_messages = await assemble_claude_prompt_video(video_data_url)
+
+    return prompt_messages, image_cache
+
+
 # Auto-upgrade usage of older models
 def auto_upgrade_model(code_generation_model: Llm) -> Llm:
     if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
@@ -185,76 +243,34 @@ async def stream_code(websocket: WebSocket):
     should_generate_images = bool(params.get("isImageGenerationEnabled", True))
 
     print("generating code...")
+
+    # TODO(*): Print with send_message instead of print statements
     await send_message("status", "Generating code...", 0)
     await send_message("status", "Generating code...", 1)
 
+    # TODO(*): Move down
     async def process_chunk(content: str, variantIndex: int):
         await send_message("chunk", content, variantIndex)
 
     # Image cache for updates so that we don't have to regenerate images
     image_cache: Dict[str, str] = {}
 
-    # If this generation started off with imported code, we need to assemble the prompt differently
-    if params.get("isImportedFromCode") and params["isImportedFromCode"]:
-        original_imported_code = params["history"][0]
-        prompt_messages = assemble_imported_code_prompt(
-            original_imported_code, valid_stack, code_generation_model
+    try:
+        prompt_messages, image_cache = await create_prompt(
+            params, valid_stack, code_generation_model, validated_input_mode
         )
-        for index, text in enumerate(params["history"][1:]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-    else:
-        # Assemble the prompt
-        try:
-            if params.get("resultImage") and params["resultImage"]:
-                prompt_messages = assemble_prompt(
-                    params["image"], valid_stack, params["resultImage"]
-                )
-            else:
-                prompt_messages = assemble_prompt(params["image"], valid_stack)
-        except:
-            # TODO: This should use variantIndex
-            await websocket.send_json(
-                {
-                    "type": "error",
-                    "value": "Error assembling prompt. Contact support at support@picoapps.xyz",
-                }
-            )
-            await websocket.close()
-            return
+    except:
+        # TODO(*): This should use variantIndex
+        await websocket.send_json(
+            {
+                "type": "error",
+                "value": "Error assembling prompt. Contact support at support@picoapps.xyz",
+            }
+        )
+        await websocket.close()
+        raise
 
-    if params["generationType"] == "update":
-        # Transform the history tree into message format
-        # TODO: Move this to frontend
-        for index, text in enumerate(params["history"]):
-            if index % 2 == 0:
-                message: ChatCompletionMessageParam = {
-                    "role": "assistant",
-                    "content": text,
-                }
-            else:
-                message: ChatCompletionMessageParam = {
-                    "role": "user",
-                    "content": text,
-                }
-            prompt_messages.append(message)
-
-        image_cache = create_alt_url_mapping(params["history"][-2])
-
-    if validated_input_mode == "video":
-        video_data_url = params["image"]
-        prompt_messages = await assemble_claude_prompt_video(video_data_url)
-
-    # pprint_prompt(prompt_messages)  # type: ignore
+    pprint_prompt(prompt_messages)  # type: ignore
 
     if SHOULD_MOCK_AI_RESPONSE:
         completions = [