standardize to using typed send_message
This commit is contained in:
parent
46c480931a
commit
0700de7767
@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
|
||||
|
||||
|
||||
async def mock_completion(
|
||||
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
|
||||
process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
|
||||
) -> str:
|
||||
code_to_return = (
|
||||
TALLY_FORM_VIDEO_PROMPT_MOCK
|
||||
@ -17,7 +17,7 @@ async def mock_completion(
|
||||
)
|
||||
|
||||
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
|
||||
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
|
||||
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if input_mode == "video":
|
||||
|
||||
@ -15,7 +15,7 @@ from llm import (
|
||||
)
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from mock_llm import mock_completion
|
||||
from typing import Any, Coroutine, Dict, List, Union, cast, get_args
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
|
||||
from image_generation import create_alt_url_mapping, generate_images
|
||||
from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||
from datetime import datetime
|
||||
@ -52,23 +52,20 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
|
||||
|
||||
# Generate images and return updated completions
|
||||
async def process_completion(
|
||||
websocket: WebSocket,
|
||||
completion: str,
|
||||
index: int,
|
||||
should_generate_images: bool,
|
||||
openai_api_key: str | None,
|
||||
openai_base_url: str | None,
|
||||
image_cache: dict[str, str],
|
||||
send_message: Callable[
|
||||
[Literal["chunk", "status", "setCode", "error"], str, int],
|
||||
Coroutine[Any, Any, None],
|
||||
],
|
||||
):
|
||||
try:
|
||||
if should_generate_images and openai_api_key:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": f"Generating images...",
|
||||
"variantIndex": index,
|
||||
}
|
||||
)
|
||||
await send_message("status", "Generating images...", index)
|
||||
updated_html = await generate_images(
|
||||
completion,
|
||||
api_key=openai_api_key,
|
||||
@ -78,28 +75,14 @@ async def process_completion(
|
||||
else:
|
||||
updated_html = completion
|
||||
|
||||
await websocket.send_json(
|
||||
{"type": "setCode", "value": updated_html, "variantIndex": index}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": f"Code generation complete.",
|
||||
"variantIndex": index,
|
||||
}
|
||||
)
|
||||
await send_message("setCode", updated_html, index)
|
||||
await send_message("status", "Code generation complete.", index)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Image generation failed for variant {index}", e)
|
||||
await websocket.send_json(
|
||||
{"type": "setCode", "value": completion, "variantIndex": index}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": f"Image generation failed but code is complete.",
|
||||
"variantIndex": index,
|
||||
}
|
||||
await send_message("setCode", completion, index)
|
||||
await send_message(
|
||||
"status", "Image generation failed but code is complete.", index
|
||||
)
|
||||
|
||||
|
||||
@ -115,6 +98,15 @@ async def stream_code(websocket: WebSocket):
|
||||
await websocket.send_json({"type": "error", "value": message})
|
||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||
|
||||
async def send_message(
|
||||
type: Literal["chunk", "status", "setCode", "error"],
|
||||
value: str,
|
||||
variantIndex: int,
|
||||
):
|
||||
await websocket.send_json(
|
||||
{"type": type, "value": value, "variantIndex": variantIndex}
|
||||
)
|
||||
|
||||
# TODO: Are the values always strings?
|
||||
params: Dict[str, str] = await websocket.receive_json()
|
||||
|
||||
@ -216,17 +208,11 @@ async def stream_code(websocket: WebSocket):
|
||||
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||
|
||||
print("generating code...")
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Generating code...", "variantIndex": 0}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Generating code...", "variantIndex": 1}
|
||||
)
|
||||
await send_message("status", "Generating code...", 0)
|
||||
await send_message("status", "Generating code...", 1)
|
||||
|
||||
async def process_chunk(content: str, variantIndex: int = 0):
|
||||
await websocket.send_json(
|
||||
{"type": "chunk", "value": content, "variantIndex": variantIndex}
|
||||
)
|
||||
async def process_chunk(content: str, variantIndex: int):
|
||||
await send_message("chunk", content, variantIndex)
|
||||
|
||||
# Image cache for updates so that we don't have to regenerate images
|
||||
image_cache: Dict[str, str] = {}
|
||||
@ -259,6 +245,7 @@ async def stream_code(websocket: WebSocket):
|
||||
else:
|
||||
prompt_messages = assemble_prompt(params["image"], valid_stack)
|
||||
except:
|
||||
# TODO: This should use variantIndex
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
@ -310,7 +297,7 @@ async def stream_code(websocket: WebSocket):
|
||||
system_prompt=VIDEO_PROMPT,
|
||||
messages=prompt_messages, # type: ignore
|
||||
api_key=anthropic_api_key,
|
||||
callback=lambda x: process_chunk(x),
|
||||
callback=lambda x: process_chunk(x, 0),
|
||||
model=Llm.CLAUDE_3_OPUS,
|
||||
include_thinking=True,
|
||||
)
|
||||
@ -412,13 +399,13 @@ async def stream_code(websocket: WebSocket):
|
||||
try:
|
||||
image_generation_tasks = [
|
||||
process_completion(
|
||||
websocket,
|
||||
completion,
|
||||
index,
|
||||
should_generate_images,
|
||||
openai_api_key,
|
||||
openai_base_url,
|
||||
image_cache,
|
||||
send_message,
|
||||
)
|
||||
for index, completion in enumerate(completions)
|
||||
]
|
||||
@ -426,12 +413,8 @@ async def stream_code(websocket: WebSocket):
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print("An error occurred during image generation and processing", e)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "status",
|
||||
"value": "An error occurred during image generation and processing.",
|
||||
"variantIndex": 0,
|
||||
}
|
||||
await send_message(
|
||||
"status", "An error occurred during image generation and processing.", 0
|
||||
)
|
||||
|
||||
await websocket.close()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user