standardize to using typed send_message

This commit is contained in:
Abi Raja 2024-07-30 16:27:04 -04:00
parent 46c480931a
commit 0700de7767
2 changed files with 31 additions and 48 deletions

View File

@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
async def mock_completion(
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
) -> str:
code_to_return = (
TALLY_FORM_VIDEO_PROMPT_MOCK
@ -17,7 +17,7 @@ async def mock_completion(
)
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
await asyncio.sleep(0.01)
if input_mode == "video":

View File

@ -15,7 +15,7 @@ from llm import (
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from typing import Any, Coroutine, Dict, List, Union, cast, get_args
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime
@ -52,23 +52,20 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
# Generate images and return updated completions
async def process_completion(
websocket: WebSocket,
completion: str,
index: int,
should_generate_images: bool,
openai_api_key: str | None,
openai_base_url: str | None,
image_cache: dict[str, str],
send_message: Callable[
[Literal["chunk", "status", "setCode", "error"], str, int],
Coroutine[Any, Any, None],
],
):
try:
if should_generate_images and openai_api_key:
await websocket.send_json(
{
"type": "status",
"value": f"Generating images...",
"variantIndex": index,
}
)
await send_message("status", "Generating images...", index)
updated_html = await generate_images(
completion,
api_key=openai_api_key,
@ -78,28 +75,14 @@ async def process_completion(
else:
updated_html = completion
await websocket.send_json(
{"type": "setCode", "value": updated_html, "variantIndex": index}
)
await websocket.send_json(
{
"type": "status",
"value": f"Code generation complete.",
"variantIndex": index,
}
)
await send_message("setCode", updated_html, index)
await send_message("status", "Code generation complete.", index)
except Exception as e:
traceback.print_exc()
print(f"Image generation failed for variant {index}", e)
await websocket.send_json(
{"type": "setCode", "value": completion, "variantIndex": index}
)
await websocket.send_json(
{
"type": "status",
"value": f"Image generation failed but code is complete.",
"variantIndex": index,
}
await send_message("setCode", completion, index)
await send_message(
"status", "Image generation failed but code is complete.", index
)
@ -115,6 +98,15 @@ async def stream_code(websocket: WebSocket):
await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
async def send_message(
type: Literal["chunk", "status", "setCode", "error"],
value: str,
variantIndex: int,
):
await websocket.send_json(
{"type": type, "value": value, "variantIndex": variantIndex}
)
# TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json()
@ -216,17 +208,11 @@ async def stream_code(websocket: WebSocket):
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
print("generating code...")
await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 0}
)
await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 1}
)
await send_message("status", "Generating code...", 0)
await send_message("status", "Generating code...", 1)
async def process_chunk(content: str, variantIndex: int = 0):
await websocket.send_json(
{"type": "chunk", "value": content, "variantIndex": variantIndex}
)
async def process_chunk(content: str, variantIndex: int):
await send_message("chunk", content, variantIndex)
# Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {}
@ -259,6 +245,7 @@ async def stream_code(websocket: WebSocket):
else:
prompt_messages = assemble_prompt(params["image"], valid_stack)
except:
# TODO: This should use variantIndex
await websocket.send_json(
{
"type": "error",
@ -310,7 +297,7 @@ async def stream_code(websocket: WebSocket):
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
callback=lambda x: process_chunk(x, 0),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
@ -412,13 +399,13 @@ async def stream_code(websocket: WebSocket):
try:
image_generation_tasks = [
process_completion(
websocket,
completion,
index,
should_generate_images,
openai_api_key,
openai_base_url,
image_cache,
send_message,
)
for index, completion in enumerate(completions)
]
@ -426,12 +413,8 @@ async def stream_code(websocket: WebSocket):
except Exception as e:
traceback.print_exc()
print("An error occurred during image generation and processing", e)
await websocket.send_json(
{
"type": "status",
"value": "An error occurred during image generation and processing.",
"variantIndex": 0,
}
await send_message(
"status", "An error occurred during image generation and processing.", 0
)
await websocket.close()