standardize to using typed send_message

This commit is contained in:
Abi Raja 2024-07-30 16:27:04 -04:00
parent 46c480931a
commit 0700de7767
2 changed files with 31 additions and 48 deletions

View File

@@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
async def mock_completion( async def mock_completion(
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
) -> str: ) -> str:
code_to_return = ( code_to_return = (
TALLY_FORM_VIDEO_PROMPT_MOCK TALLY_FORM_VIDEO_PROMPT_MOCK
@@ -17,7 +17,7 @@ async def mock_completion(
) )
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE): for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE]) await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
if input_mode == "video": if input_mode == "video":

View File

@@ -15,7 +15,7 @@ from llm import (
) )
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Any, Coroutine, Dict, List, Union, cast, get_args from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime from datetime import datetime
@@ -52,23 +52,20 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
# Generate images and return updated completions # Generate images and return updated completions
async def process_completion( async def process_completion(
websocket: WebSocket,
completion: str, completion: str,
index: int, index: int,
should_generate_images: bool, should_generate_images: bool,
openai_api_key: str | None, openai_api_key: str | None,
openai_base_url: str | None, openai_base_url: str | None,
image_cache: dict[str, str], image_cache: dict[str, str],
send_message: Callable[
[Literal["chunk", "status", "setCode", "error"], str, int],
Coroutine[Any, Any, None],
],
): ):
try: try:
if should_generate_images and openai_api_key: if should_generate_images and openai_api_key:
await websocket.send_json( await send_message("status", "Generating images...", index)
{
"type": "status",
"value": f"Generating images...",
"variantIndex": index,
}
)
updated_html = await generate_images( updated_html = await generate_images(
completion, completion,
api_key=openai_api_key, api_key=openai_api_key,
@@ -78,28 +75,14 @@ async def process_completion(
else: else:
updated_html = completion updated_html = completion
await websocket.send_json( await send_message("setCode", updated_html, index)
{"type": "setCode", "value": updated_html, "variantIndex": index} await send_message("status", "Code generation complete.", index)
)
await websocket.send_json(
{
"type": "status",
"value": f"Code generation complete.",
"variantIndex": index,
}
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Image generation failed for variant {index}", e) print(f"Image generation failed for variant {index}", e)
await websocket.send_json( await send_message("setCode", completion, index)
{"type": "setCode", "value": completion, "variantIndex": index} await send_message(
) "status", "Image generation failed but code is complete.", index
await websocket.send_json(
{
"type": "status",
"value": f"Image generation failed but code is complete.",
"variantIndex": index,
}
) )
@@ -115,6 +98,15 @@ async def stream_code(websocket: WebSocket):
await websocket.send_json({"type": "error", "value": message}) await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE) await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
async def send_message(
type: Literal["chunk", "status", "setCode", "error"],
value: str,
variantIndex: int,
):
await websocket.send_json(
{"type": type, "value": value, "variantIndex": variantIndex}
)
# TODO: Are the values always strings? # TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json() params: Dict[str, str] = await websocket.receive_json()
@@ -216,17 +208,11 @@ async def stream_code(websocket: WebSocket):
should_generate_images = bool(params.get("isImageGenerationEnabled", True)) should_generate_images = bool(params.get("isImageGenerationEnabled", True))
print("generating code...") print("generating code...")
await websocket.send_json( await send_message("status", "Generating code...", 0)
{"type": "status", "value": "Generating code...", "variantIndex": 0} await send_message("status", "Generating code...", 1)
)
await websocket.send_json(
{"type": "status", "value": "Generating code...", "variantIndex": 1}
)
async def process_chunk(content: str, variantIndex: int = 0): async def process_chunk(content: str, variantIndex: int):
await websocket.send_json( await send_message("chunk", content, variantIndex)
{"type": "chunk", "value": content, "variantIndex": variantIndex}
)
# Image cache for updates so that we don't have to regenerate images # Image cache for updates so that we don't have to regenerate images
image_cache: Dict[str, str] = {} image_cache: Dict[str, str] = {}
@@ -259,6 +245,7 @@ async def stream_code(websocket: WebSocket):
else: else:
prompt_messages = assemble_prompt(params["image"], valid_stack) prompt_messages = assemble_prompt(params["image"], valid_stack)
except: except:
# TODO: This should use variantIndex
await websocket.send_json( await websocket.send_json(
{ {
"type": "error", "type": "error",
@@ -310,7 +297,7 @@ async def stream_code(websocket: WebSocket):
system_prompt=VIDEO_PROMPT, system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore messages=prompt_messages, # type: ignore
api_key=anthropic_api_key, api_key=anthropic_api_key,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x, 0),
model=Llm.CLAUDE_3_OPUS, model=Llm.CLAUDE_3_OPUS,
include_thinking=True, include_thinking=True,
) )
@@ -412,13 +399,13 @@ async def stream_code(websocket: WebSocket):
try: try:
image_generation_tasks = [ image_generation_tasks = [
process_completion( process_completion(
websocket,
completion, completion,
index, index,
should_generate_images, should_generate_images,
openai_api_key, openai_api_key,
openai_base_url, openai_base_url,
image_cache, image_cache,
send_message,
) )
for index, completion in enumerate(completions) for index, completion in enumerate(completions)
] ]
@@ -426,12 +413,8 @@ async def stream_code(websocket: WebSocket):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print("An error occurred during image generation and processing", e) print("An error occurred during image generation and processing", e)
await websocket.send_json( await send_message(
{ "status", "An error occurred during image generation and processing.", 0
"type": "status",
"value": "An error occurred during image generation and processing.",
"variantIndex": 0,
}
) )
await websocket.close() await websocket.close()