parallelize just image generation
This commit is contained in:
parent
701d97ec74
commit
5c3f915bce
@ -16,7 +16,7 @@ from llm import (
|
||||
)
|
||||
from fs_logging.core import write_logs
|
||||
from mock_llm import mock_completion
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
|
||||
from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args
|
||||
from image_generation import generate_images
|
||||
from prompts import create_prompt
|
||||
from prompts.claude_prompts import VIDEO_PROMPT
|
||||
@ -43,40 +43,27 @@ def auto_upgrade_model(code_generation_model: Llm) -> Llm:
|
||||
return code_generation_model
|
||||
|
||||
|
||||
# Generate images and return updated completions
|
||||
async def process_completion(
|
||||
# Generate images, if needed
|
||||
async def perform_image_generation(
|
||||
completion: str,
|
||||
index: int,
|
||||
should_generate_images: bool,
|
||||
openai_api_key: str | None,
|
||||
openai_base_url: str | None,
|
||||
image_cache: dict[str, str],
|
||||
send_message: Callable[
|
||||
[Literal["chunk", "status", "setCode", "error"], str, int],
|
||||
Coroutine[Any, Any, None],
|
||||
],
|
||||
):
|
||||
try:
|
||||
if should_generate_images and openai_api_key:
|
||||
await send_message("status", "Generating images...", index)
|
||||
updated_html = await generate_images(
|
||||
completion,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
)
|
||||
else:
|
||||
updated_html = completion
|
||||
if not should_generate_images:
|
||||
return completion
|
||||
|
||||
await send_message("setCode", updated_html, index)
|
||||
await send_message("status", "Code generation complete.", index)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Image generation failed for variant {index}", e)
|
||||
await send_message("setCode", completion, index)
|
||||
await send_message(
|
||||
"status", "Image generation failed but code is complete.", index
|
||||
)
|
||||
if not openai_api_key:
|
||||
print("No OpenAI API key found. Skipping image generation.")
|
||||
return completion
|
||||
|
||||
return await generate_images(
|
||||
completion,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/generate-code")
|
||||
@ -85,7 +72,6 @@ async def stream_code(websocket: WebSocket):
|
||||
|
||||
print("Incoming websocket connection...")
|
||||
|
||||
|
||||
## Communication protocol setup
|
||||
async def throw_error(
|
||||
message: str,
|
||||
@ -110,7 +96,7 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
|
||||
## Parameter validation
|
||||
|
||||
|
||||
# TODO: Are the values always strings?
|
||||
params: Dict[str, str] = await websocket.receive_json()
|
||||
print("Received params")
|
||||
@ -328,31 +314,34 @@ async def stream_code(websocket: WebSocket):
|
||||
# if validated_input_mode == "video":
|
||||
# completion = extract_tag_content("html", completions[0])
|
||||
|
||||
## Post-processing
|
||||
|
||||
# Strip the completion of everything except the HTML content
|
||||
completions = [extract_html_content(completion) for completion in completions]
|
||||
|
||||
# Write the messages dict into a log so that we can debug later
|
||||
write_logs(prompt_messages, completions[0])
|
||||
|
||||
try:
|
||||
image_generation_tasks = [
|
||||
process_completion(
|
||||
completion,
|
||||
index,
|
||||
should_generate_images,
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
image_cache,
|
||||
send_message,
|
||||
)
|
||||
for index, completion in enumerate(completions)
|
||||
]
|
||||
await asyncio.gather(*image_generation_tasks)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print("An error occurred during image generation and processing", e)
|
||||
await send_message(
|
||||
"status", "An error occurred during image generation and processing.", 0
|
||||
## Image Generation
|
||||
|
||||
for index, _ in enumerate(completions):
|
||||
await send_message("status", "Generating images...", index)
|
||||
|
||||
image_generation_tasks = [
|
||||
perform_image_generation(
|
||||
completion,
|
||||
should_generate_images,
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
image_cache,
|
||||
)
|
||||
for completion in completions
|
||||
]
|
||||
|
||||
updated_completions = await asyncio.gather(*image_generation_tasks)
|
||||
|
||||
for index, updated_html in enumerate(updated_completions):
|
||||
await send_message("setCode", updated_html, index)
|
||||
await send_message("status", "Code generation complete.", index)
|
||||
|
||||
await websocket.close()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user