parallelize just image generation

This commit is contained in:
Abi Raja 2024-07-31 13:36:22 -04:00
parent 701d97ec74
commit 5c3f915bce

View File

@ -16,7 +16,7 @@ from llm import (
)
from fs_logging.core import write_logs
from mock_llm import mock_completion
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args
from image_generation import generate_images
from prompts import create_prompt
from prompts.claude_prompts import VIDEO_PROMPT
@ -43,40 +43,27 @@ def auto_upgrade_model(code_generation_model: Llm) -> Llm:
return code_generation_model
# Generate images and return updated completions
async def process_completion(
# Generate images, if needed
async def perform_image_generation(
completion: str,
index: int,
should_generate_images: bool,
openai_api_key: str | None,
openai_base_url: str | None,
image_cache: dict[str, str],
send_message: Callable[
[Literal["chunk", "status", "setCode", "error"], str, int],
Coroutine[Any, Any, None],
],
):
try:
if should_generate_images and openai_api_key:
await send_message("status", "Generating images...", index)
updated_html = await generate_images(
completion,
api_key=openai_api_key,
base_url=openai_base_url,
image_cache=image_cache,
)
else:
updated_html = completion
if not should_generate_images:
return completion
await send_message("setCode", updated_html, index)
await send_message("status", "Code generation complete.", index)
except Exception as e:
traceback.print_exc()
print(f"Image generation failed for variant {index}", e)
await send_message("setCode", completion, index)
await send_message(
"status", "Image generation failed but code is complete.", index
)
if not openai_api_key:
print("No OpenAI API key found. Skipping image generation.")
return completion
return await generate_images(
completion,
api_key=openai_api_key,
base_url=openai_base_url,
image_cache=image_cache,
)
@router.websocket("/generate-code")
@ -85,7 +72,6 @@ async def stream_code(websocket: WebSocket):
print("Incoming websocket connection...")
## Communication protocol setup
async def throw_error(
message: str,
@ -328,31 +314,34 @@ async def stream_code(websocket: WebSocket):
# if validated_input_mode == "video":
# completion = extract_tag_content("html", completions[0])
## Post-processing
# Strip the completion of everything except the HTML content
completions = [extract_html_content(completion) for completion in completions]
# Write the messages dict into a log so that we can debug later
write_logs(prompt_messages, completions[0])
try:
image_generation_tasks = [
process_completion(
completion,
index,
should_generate_images,
openai_api_key,
openai_base_url,
image_cache,
send_message,
)
for index, completion in enumerate(completions)
]
await asyncio.gather(*image_generation_tasks)
except Exception as e:
traceback.print_exc()
print("An error occurred during image generation and processing", e)
await send_message(
"status", "An error occurred during image generation and processing.", 0
## Image Generation
for index, _ in enumerate(completions):
await send_message("status", "Generating images...", index)
image_generation_tasks = [
perform_image_generation(
completion,
should_generate_images,
openai_api_key,
openai_base_url,
image_cache,
)
for completion in completions
]
updated_completions = await asyncio.gather(*image_generation_tasks)
for index, updated_html in enumerate(updated_completions):
await send_message("setCode", updated_html, index)
await send_message("status", "Code generation complete.", index)
await websocket.close()