parallelize just image generation
This commit is contained in:
parent
701d97ec74
commit
5c3f915bce
@ -16,7 +16,7 @@ from llm import (
|
|||||||
)
|
)
|
||||||
from fs_logging.core import write_logs
|
from fs_logging.core import write_logs
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
|
from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args
|
||||||
from image_generation import generate_images
|
from image_generation import generate_images
|
||||||
from prompts import create_prompt
|
from prompts import create_prompt
|
||||||
from prompts.claude_prompts import VIDEO_PROMPT
|
from prompts.claude_prompts import VIDEO_PROMPT
|
||||||
@ -43,40 +43,27 @@ def auto_upgrade_model(code_generation_model: Llm) -> Llm:
|
|||||||
return code_generation_model
|
return code_generation_model
|
||||||
|
|
||||||
|
|
||||||
# Generate images and return updated completions
|
# Generate images, if needed
|
||||||
async def process_completion(
|
async def perform_image_generation(
|
||||||
completion: str,
|
completion: str,
|
||||||
index: int,
|
|
||||||
should_generate_images: bool,
|
should_generate_images: bool,
|
||||||
openai_api_key: str | None,
|
openai_api_key: str | None,
|
||||||
openai_base_url: str | None,
|
openai_base_url: str | None,
|
||||||
image_cache: dict[str, str],
|
image_cache: dict[str, str],
|
||||||
send_message: Callable[
|
|
||||||
[Literal["chunk", "status", "setCode", "error"], str, int],
|
|
||||||
Coroutine[Any, Any, None],
|
|
||||||
],
|
|
||||||
):
|
):
|
||||||
try:
|
if not should_generate_images:
|
||||||
if should_generate_images and openai_api_key:
|
return completion
|
||||||
await send_message("status", "Generating images...", index)
|
|
||||||
updated_html = await generate_images(
|
if not openai_api_key:
|
||||||
|
print("No OpenAI API key found. Skipping image generation.")
|
||||||
|
return completion
|
||||||
|
|
||||||
|
return await generate_images(
|
||||||
completion,
|
completion,
|
||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
image_cache=image_cache,
|
image_cache=image_cache,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
updated_html = completion
|
|
||||||
|
|
||||||
await send_message("setCode", updated_html, index)
|
|
||||||
await send_message("status", "Code generation complete.", index)
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
print(f"Image generation failed for variant {index}", e)
|
|
||||||
await send_message("setCode", completion, index)
|
|
||||||
await send_message(
|
|
||||||
"status", "Image generation failed but code is complete.", index
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/generate-code")
|
@router.websocket("/generate-code")
|
||||||
@ -85,7 +72,6 @@ async def stream_code(websocket: WebSocket):
|
|||||||
|
|
||||||
print("Incoming websocket connection...")
|
print("Incoming websocket connection...")
|
||||||
|
|
||||||
|
|
||||||
## Communication protocol setup
|
## Communication protocol setup
|
||||||
async def throw_error(
|
async def throw_error(
|
||||||
message: str,
|
message: str,
|
||||||
@ -328,31 +314,34 @@ async def stream_code(websocket: WebSocket):
|
|||||||
# if validated_input_mode == "video":
|
# if validated_input_mode == "video":
|
||||||
# completion = extract_tag_content("html", completions[0])
|
# completion = extract_tag_content("html", completions[0])
|
||||||
|
|
||||||
|
## Post-processing
|
||||||
|
|
||||||
# Strip the completion of everything except the HTML content
|
# Strip the completion of everything except the HTML content
|
||||||
completions = [extract_html_content(completion) for completion in completions]
|
completions = [extract_html_content(completion) for completion in completions]
|
||||||
|
|
||||||
# Write the messages dict into a log so that we can debug later
|
# Write the messages dict into a log so that we can debug later
|
||||||
write_logs(prompt_messages, completions[0])
|
write_logs(prompt_messages, completions[0])
|
||||||
|
|
||||||
try:
|
## Image Generation
|
||||||
|
|
||||||
|
for index, _ in enumerate(completions):
|
||||||
|
await send_message("status", "Generating images...", index)
|
||||||
|
|
||||||
image_generation_tasks = [
|
image_generation_tasks = [
|
||||||
process_completion(
|
perform_image_generation(
|
||||||
completion,
|
completion,
|
||||||
index,
|
|
||||||
should_generate_images,
|
should_generate_images,
|
||||||
openai_api_key,
|
openai_api_key,
|
||||||
openai_base_url,
|
openai_base_url,
|
||||||
image_cache,
|
image_cache,
|
||||||
send_message,
|
|
||||||
)
|
)
|
||||||
for index, completion in enumerate(completions)
|
for completion in completions
|
||||||
]
|
]
|
||||||
await asyncio.gather(*image_generation_tasks)
|
|
||||||
except Exception as e:
|
updated_completions = await asyncio.gather(*image_generation_tasks)
|
||||||
traceback.print_exc()
|
|
||||||
print("An error occurred during image generation and processing", e)
|
for index, updated_html in enumerate(updated_completions):
|
||||||
await send_message(
|
await send_message("setCode", updated_html, index)
|
||||||
"status", "An error occurred during image generation and processing.", 0
|
await send_message("status", "Code generation complete.", index)
|
||||||
)
|
|
||||||
|
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user