From 0700de7767790aa53f621b9a6fff1d99f3a8f554 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 30 Jul 2024 16:27:04 -0400 Subject: [PATCH] standardize to using typed send_message --- backend/mock_llm.py | 4 +- backend/routes/generate_code.py | 75 +++++++++++++-------------------- 2 files changed, 31 insertions(+), 48 deletions(-) diff --git a/backend/mock_llm.py b/backend/mock_llm.py index b85b1b1..a76b906 100644 --- a/backend/mock_llm.py +++ b/backend/mock_llm.py @@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20 async def mock_completion( - process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode + process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode ) -> str: code_to_return = ( TALLY_FORM_VIDEO_PROMPT_MOCK @@ -17,7 +17,7 @@ async def mock_completion( ) for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE): - await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE]) + await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0) await asyncio.sleep(0.01) if input_mode == "video": diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 3d0c915..ed02b0e 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -15,7 +15,7 @@ from llm import ( ) from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Any, Coroutine, Dict, List, Union, cast, get_args +from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from datetime import datetime @@ -52,23 +52,20 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st # Generate images and return updated completions async def process_completion( - websocket: WebSocket, completion: str, index: int, should_generate_images: bool, openai_api_key: str | None, openai_base_url: str | None, image_cache: dict[str, str], + send_message: Callable[ + [Literal["chunk", "status", "setCode", "error"], str, int], + Coroutine[Any, Any, None], + ], ): try: if should_generate_images and openai_api_key: - await websocket.send_json( - { - "type": "status", - "value": f"Generating images...", - "variantIndex": index, - } - ) + await send_message("status", "Generating images...", index) updated_html = await generate_images( completion, api_key=openai_api_key, @@ -78,28 +75,14 @@ async def process_completion( else: updated_html = completion - await websocket.send_json( - {"type": "setCode", "value": updated_html, "variantIndex": index} - ) - await websocket.send_json( - { - "type": "status", - "value": f"Code generation complete.", - "variantIndex": index, - } - ) + await send_message("setCode", updated_html, index) + await send_message("status", "Code generation complete.", index) except Exception as e: traceback.print_exc() print(f"Image generation failed for variant {index}", e) - await websocket.send_json( - {"type": "setCode", "value": completion, "variantIndex": index} - ) - await websocket.send_json( - { - "type": "status", - "value": f"Image generation failed but code is complete.", - "variantIndex": index, - } + await send_message("setCode", completion, index) + await send_message( + "status", "Image generation failed but code is complete.", index ) @@ -115,6 +98,15 @@ async def stream_code(websocket: WebSocket): await websocket.send_json({"type": "error", "value": message}) await websocket.close(APP_ERROR_WEB_SOCKET_CODE) + async def send_message( + type: Literal["chunk", "status", "setCode", "error"], + value: str, + variantIndex: int, + ): + await websocket.send_json( + {"type": type, "value": value, "variantIndex": variantIndex} + ) + # TODO: Are the values always strings? params: Dict[str, str] = await websocket.receive_json() @@ -216,17 +208,11 @@ async def stream_code(websocket: WebSocket): should_generate_images = bool(params.get("isImageGenerationEnabled", True)) print("generating code...") - await websocket.send_json( - {"type": "status", "value": "Generating code...", "variantIndex": 0} - ) - await websocket.send_json( - {"type": "status", "value": "Generating code...", "variantIndex": 1} - ) + await send_message("status", "Generating code...", 0) + await send_message("status", "Generating code...", 1) - async def process_chunk(content: str, variantIndex: int = 0): - await websocket.send_json( - {"type": "chunk", "value": content, "variantIndex": variantIndex} - ) + async def process_chunk(content: str, variantIndex: int): + await send_message("chunk", content, variantIndex) # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} @@ -259,6 +245,7 @@ async def stream_code(websocket: WebSocket): else: prompt_messages = assemble_prompt(params["image"], valid_stack) except: + # TODO: This should use variantIndex await websocket.send_json( { "type": "error", @@ -310,7 +297,7 @@ async def stream_code(websocket: WebSocket): system_prompt=VIDEO_PROMPT, messages=prompt_messages, # type: ignore api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), + callback=lambda x: process_chunk(x, 0), model=Llm.CLAUDE_3_OPUS, include_thinking=True, ) @@ -412,13 +399,13 @@ async def stream_code(websocket: WebSocket): try: image_generation_tasks = [ process_completion( - websocket, completion, index, should_generate_images, openai_api_key, openai_base_url, image_cache, + send_message, ) for index, completion in enumerate(completions) ] @@ -426,12 +413,8 @@ async def stream_code(websocket: WebSocket): except Exception as e: traceback.print_exc() print("An error occurred during image generation and processing", e) - await websocket.send_json( - { - "type": "status", - "value": "An error occurred during image generation and processing.", - "variantIndex": 0, - } + await send_message( + "status", "An error occurred during image generation and processing.", 0 ) await websocket.close()