standardize to using typed send_message
This commit is contained in:
parent
46c480931a
commit
0700de7767
@ -8,7 +8,7 @@ STREAM_CHUNK_SIZE = 20
|
|||||||
|
|
||||||
|
|
||||||
async def mock_completion(
|
async def mock_completion(
|
||||||
process_chunk: Callable[[str], Awaitable[None]], input_mode: InputMode
|
process_chunk: Callable[[str, int], Awaitable[None]], input_mode: InputMode
|
||||||
) -> str:
|
) -> str:
|
||||||
code_to_return = (
|
code_to_return = (
|
||||||
TALLY_FORM_VIDEO_PROMPT_MOCK
|
TALLY_FORM_VIDEO_PROMPT_MOCK
|
||||||
@ -17,7 +17,7 @@ async def mock_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
|
for i in range(0, len(code_to_return), STREAM_CHUNK_SIZE):
|
||||||
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE])
|
await process_chunk(code_to_return[i : i + STREAM_CHUNK_SIZE], 0)
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
if input_mode == "video":
|
if input_mode == "video":
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from llm import (
|
|||||||
)
|
)
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Any, Coroutine, Dict, List, Union, cast, get_args
|
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
|
||||||
from image_generation import create_alt_url_mapping, generate_images
|
from image_generation import create_alt_url_mapping, generate_images
|
||||||
from prompts import assemble_imported_code_prompt, assemble_prompt
|
from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -52,23 +52,20 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
|
|||||||
|
|
||||||
# Generate images and return updated completions
|
# Generate images and return updated completions
|
||||||
async def process_completion(
|
async def process_completion(
|
||||||
websocket: WebSocket,
|
|
||||||
completion: str,
|
completion: str,
|
||||||
index: int,
|
index: int,
|
||||||
should_generate_images: bool,
|
should_generate_images: bool,
|
||||||
openai_api_key: str | None,
|
openai_api_key: str | None,
|
||||||
openai_base_url: str | None,
|
openai_base_url: str | None,
|
||||||
image_cache: dict[str, str],
|
image_cache: dict[str, str],
|
||||||
|
send_message: Callable[
|
||||||
|
[Literal["chunk", "status", "setCode", "error"], str, int],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
],
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if should_generate_images and openai_api_key:
|
if should_generate_images and openai_api_key:
|
||||||
await websocket.send_json(
|
await send_message("status", "Generating images...", index)
|
||||||
{
|
|
||||||
"type": "status",
|
|
||||||
"value": f"Generating images...",
|
|
||||||
"variantIndex": index,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
updated_html = await generate_images(
|
updated_html = await generate_images(
|
||||||
completion,
|
completion,
|
||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
@ -78,28 +75,14 @@ async def process_completion(
|
|||||||
else:
|
else:
|
||||||
updated_html = completion
|
updated_html = completion
|
||||||
|
|
||||||
await websocket.send_json(
|
await send_message("setCode", updated_html, index)
|
||||||
{"type": "setCode", "value": updated_html, "variantIndex": index}
|
await send_message("status", "Code generation complete.", index)
|
||||||
)
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"type": "status",
|
|
||||||
"value": f"Code generation complete.",
|
|
||||||
"variantIndex": index,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(f"Image generation failed for variant {index}", e)
|
print(f"Image generation failed for variant {index}", e)
|
||||||
await websocket.send_json(
|
await send_message("setCode", completion, index)
|
||||||
{"type": "setCode", "value": completion, "variantIndex": index}
|
await send_message(
|
||||||
)
|
"status", "Image generation failed but code is complete.", index
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"type": "status",
|
|
||||||
"value": f"Image generation failed but code is complete.",
|
|
||||||
"variantIndex": index,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -115,6 +98,15 @@ async def stream_code(websocket: WebSocket):
|
|||||||
await websocket.send_json({"type": "error", "value": message})
|
await websocket.send_json({"type": "error", "value": message})
|
||||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
type: Literal["chunk", "status", "setCode", "error"],
|
||||||
|
value: str,
|
||||||
|
variantIndex: int,
|
||||||
|
):
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": type, "value": value, "variantIndex": variantIndex}
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Are the values always strings?
|
# TODO: Are the values always strings?
|
||||||
params: Dict[str, str] = await websocket.receive_json()
|
params: Dict[str, str] = await websocket.receive_json()
|
||||||
|
|
||||||
@ -216,17 +208,11 @@ async def stream_code(websocket: WebSocket):
|
|||||||
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||||
|
|
||||||
print("generating code...")
|
print("generating code...")
|
||||||
await websocket.send_json(
|
await send_message("status", "Generating code...", 0)
|
||||||
{"type": "status", "value": "Generating code...", "variantIndex": 0}
|
await send_message("status", "Generating code...", 1)
|
||||||
)
|
|
||||||
await websocket.send_json(
|
|
||||||
{"type": "status", "value": "Generating code...", "variantIndex": 1}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_chunk(content: str, variantIndex: int = 0):
|
async def process_chunk(content: str, variantIndex: int):
|
||||||
await websocket.send_json(
|
await send_message("chunk", content, variantIndex)
|
||||||
{"type": "chunk", "value": content, "variantIndex": variantIndex}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Image cache for updates so that we don't have to regenerate images
|
# Image cache for updates so that we don't have to regenerate images
|
||||||
image_cache: Dict[str, str] = {}
|
image_cache: Dict[str, str] = {}
|
||||||
@ -259,6 +245,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
else:
|
else:
|
||||||
prompt_messages = assemble_prompt(params["image"], valid_stack)
|
prompt_messages = assemble_prompt(params["image"], valid_stack)
|
||||||
except:
|
except:
|
||||||
|
# TODO: This should use variantIndex
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
@ -310,7 +297,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
system_prompt=VIDEO_PROMPT,
|
system_prompt=VIDEO_PROMPT,
|
||||||
messages=prompt_messages, # type: ignore
|
messages=prompt_messages, # type: ignore
|
||||||
api_key=anthropic_api_key,
|
api_key=anthropic_api_key,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x, 0),
|
||||||
model=Llm.CLAUDE_3_OPUS,
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
include_thinking=True,
|
include_thinking=True,
|
||||||
)
|
)
|
||||||
@ -412,13 +399,13 @@ async def stream_code(websocket: WebSocket):
|
|||||||
try:
|
try:
|
||||||
image_generation_tasks = [
|
image_generation_tasks = [
|
||||||
process_completion(
|
process_completion(
|
||||||
websocket,
|
|
||||||
completion,
|
completion,
|
||||||
index,
|
index,
|
||||||
should_generate_images,
|
should_generate_images,
|
||||||
openai_api_key,
|
openai_api_key,
|
||||||
openai_base_url,
|
openai_base_url,
|
||||||
image_cache,
|
image_cache,
|
||||||
|
send_message,
|
||||||
)
|
)
|
||||||
for index, completion in enumerate(completions)
|
for index, completion in enumerate(completions)
|
||||||
]
|
]
|
||||||
@ -426,12 +413,8 @@ async def stream_code(websocket: WebSocket):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print("An error occurred during image generation and processing", e)
|
print("An error occurred during image generation and processing", e)
|
||||||
await websocket.send_json(
|
await send_message(
|
||||||
{
|
"status", "An error occurred during image generation and processing.", 0
|
||||||
"type": "status",
|
|
||||||
"value": "An error occurred during image generation and processing.",
|
|
||||||
"variantIndex": 0,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user