From 5c3f915bce2be697ad027da93aebea1226e14762 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Wed, 31 Jul 2024 13:36:22 -0400 Subject: [PATCH] parallelize just image generation --- backend/routes/generate_code.py | 87 ++++++++++++++------------------- 1 file changed, 38 insertions(+), 49 deletions(-) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index bd4238b..c33dc07 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -16,7 +16,7 @@ from llm import ( ) from fs_logging.core import write_logs from mock_llm import mock_completion -from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args +from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args from image_generation import generate_images from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT @@ -43,40 +43,27 @@ def auto_upgrade_model(code_generation_model: Llm) -> Llm: return code_generation_model -# Generate images and return updated completions -async def process_completion( +# Generate images, if needed +async def perform_image_generation( completion: str, - index: int, should_generate_images: bool, openai_api_key: str | None, openai_base_url: str | None, image_cache: dict[str, str], - send_message: Callable[ - [Literal["chunk", "status", "setCode", "error"], str, int], - Coroutine[Any, Any, None], - ], ): - try: - if should_generate_images and openai_api_key: - await send_message("status", "Generating images...", index) - updated_html = await generate_images( - completion, - api_key=openai_api_key, - base_url=openai_base_url, - image_cache=image_cache, - ) - else: - updated_html = completion + if not should_generate_images: + return completion - await send_message("setCode", updated_html, index) - await send_message("status", "Code generation complete.", index) - except Exception as e: - traceback.print_exc() - print(f"Image generation failed for variant {index}", e) - await send_message("setCode", completion, index) - await send_message( - "status", "Image generation failed but code is complete.", index - ) + if not openai_api_key: + print("No OpenAI API key found. Skipping image generation.") + return completion + + return await generate_images( + completion, + api_key=openai_api_key, + base_url=openai_base_url, + image_cache=image_cache, + ) @router.websocket("/generate-code") @@ -85,7 +72,6 @@ async def stream_code(websocket: WebSocket): print("Incoming websocket connection...") - ## Communication protocol setup async def throw_error( message: str, @@ -110,7 +96,7 @@ async def stream_code(websocket: WebSocket): ) ## Parameter validation - + # TODO: Are the values always strings? params: Dict[str, str] = await websocket.receive_json() print("Received params") @@ -328,31 +314,34 @@ async def stream_code(websocket: WebSocket): # if validated_input_mode == "video": # completion = extract_tag_content("html", completions[0]) + ## Post-processing + # Strip the completion of everything except the HTML content completions = [extract_html_content(completion) for completion in completions] # Write the messages dict into a log so that we can debug later write_logs(prompt_messages, completions[0]) - try: - image_generation_tasks = [ - process_completion( - completion, - index, - should_generate_images, - openai_api_key, - openai_base_url, - image_cache, - send_message, - ) - for index, completion in enumerate(completions) - ] - await asyncio.gather(*image_generation_tasks) - except Exception as e: - traceback.print_exc() - print("An error occurred during image generation and processing", e) - await send_message( - "status", "An error occurred during image generation and processing.", 0 + ## Image Generation + + for index, _ in enumerate(completions): + await send_message("status", "Generating images...", index) + + image_generation_tasks = [ + perform_image_generation( + completion, + should_generate_images, + openai_api_key, + openai_base_url, + image_cache, ) + for completion in completions + ] + + updated_completions = await asyncio.gather(*image_generation_tasks) + + for index, updated_html in enumerate(updated_completions): + await send_message("setCode", updated_html, index) + await send_message("status", "Code generation complete.", index) await websocket.close()