From 6a28ee2d3cac362d8367c4efbf1d07f98c88743e Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Sat, 9 Dec 2023 15:34:16 -0500 Subject: [PATCH] strictly type python backend --- .vscode/settings.json | 3 +++ backend/image_generation.py | 20 +++++++++++--------- backend/llm.py | 13 +++++++------ backend/main.py | 27 +++++++++++++++++++-------- backend/{mock.py => mock_llm.py} | 3 ++- backend/prompts.py | 13 ++++++++++--- backend/routes/screenshot.py | 4 +++- backend/utils.py | 20 +++++++++++--------- 8 files changed, 66 insertions(+), 37 deletions(-) create mode 100644 .vscode/settings.json rename backend/{mock.py => mock_llm.py} (98%) diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..d6e2638 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.typeCheckingMode": "strict" +} diff --git a/backend/image_generation.py b/backend/image_generation.py index bb272f8..d3e71b1 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -1,15 +1,15 @@ import asyncio -import os import re +from typing import Dict, List, Union from openai import AsyncOpenAI from bs4 import BeautifulSoup -async def process_tasks(prompts, api_key, base_url): +async def process_tasks(prompts: List[str], api_key: str, base_url: str): tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) - processed_results = [] + processed_results: List[Union[str, None]] = [] for result in results: if isinstance(result, Exception): print(f"An exception occurred: {result}") @@ -20,9 +20,9 @@ async def process_tasks(prompts, api_key, base_url): return processed_results -async def generate_image(prompt, api_key, base_url): +async def generate_image(prompt: str, api_key: str, base_url: str): client = AsyncOpenAI(api_key=api_key, base_url=base_url) - image_params = { + image_params: Dict[str, Union[str, int]] = { "model": "dall-e-3", "quality": "standard", "style": "natural", @@ -35,7 +35,7 @@ async def generate_image(prompt, api_key, base_url): return res.data[0].url -def extract_dimensions(url): +def extract_dimensions(url: str): # Regular expression to match numbers in the format '300x200' matches = re.findall(r"(\d+)x(\d+)", url) @@ -48,11 +48,11 @@ def extract_dimensions(url): return (100, 100) -def create_alt_url_mapping(code): +def create_alt_url_mapping(code: str) -> Dict[str, str]: soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") - mapping = {} + mapping: Dict[str, str] = {} for image in images: if not image["src"].startswith("https://placehold.co"): @@ -61,7 +61,9 @@ def create_alt_url_mapping(code): return mapping -async def generate_images(code, api_key, base_url, image_cache): +async def generate_images( + code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str] +): # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") diff --git a/backend/llm.py b/backend/llm.py index e2b41c4..66e3a47 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -1,16 +1,16 @@ -import os -from typing import Awaitable, Callable +from typing import Awaitable, Callable, List from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk MODEL_GPT_4_VISION = "gpt-4-vision-preview" async def stream_openai_response( - messages, + messages: List[ChatCompletionMessageParam], api_key: str, base_url: str | None, callback: Callable[[str], Awaitable[None]], -): +) -> str: client = AsyncOpenAI(api_key=api_key, base_url=base_url) model = MODEL_GPT_4_VISION @@ -23,9 +23,10 @@ async def stream_openai_response( params["max_tokens"] = 4096 params["temperature"] = 0 - completion = await client.chat.completions.create(**params) + stream = await client.chat.completions.create(**params) # type: ignore full_response = "" - async for chunk in completion: + async for chunk in stream: # type: ignore + assert isinstance(chunk, ChatCompletionChunk) content = chunk.choices[0].delta.content or "" full_response += content await callback(content) diff --git a/backend/main.py b/backend/main.py index 517eef7..593ec3a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,8 +14,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse import openai from llm import stream_openai_response -from mock import mock_completion +from openai.types.chat import ChatCompletionMessageParam +from mock_llm import mock_completion from utils import pprint_prompt +from typing import Dict, List from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_prompt from routes import screenshot @@ -53,7 +55,7 @@ async def get_status(): ) -def write_logs(prompt_messages, completion): +def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str): # Get the logs path from environment, default to the current working directory logs_path = os.environ.get("LOGS_PATH", os.getcwd()) @@ -84,7 +86,8 @@ async def stream_code(websocket: WebSocket): await websocket.send_json({"type": "error", "value": message}) await websocket.close() - params = await websocket.receive_json() + # TODO: Are the values always strings? + params: Dict[str, str] = await websocket.receive_json() print("Received params") @@ -154,7 +157,7 @@ async def stream_code(websocket: WebSocket): print("generating code...") await websocket.send_json({"type": "status", "value": "Generating code..."}) - async def process_chunk(content): + async def process_chunk(content: str): await websocket.send_json({"type": "chunk", "value": content}) # Assemble the prompt @@ -176,15 +179,23 @@ async def stream_code(websocket: WebSocket): return # Image cache for updates so that we don't have to regenerate images - image_cache = {} + image_cache: Dict[str, str] = {} if params["generationType"] == "update": # Transform into message format # TODO: Move this to frontend for index, text in enumerate(params["history"]): - prompt_messages += [ - {"role": "assistant" if index % 2 == 0 else "user", "content": text} - ] + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) image_cache = create_alt_url_mapping(params["history"][-2]) if SHOULD_MOCK_AI_RESPONSE: diff --git a/backend/mock.py b/backend/mock_llm.py similarity index 98% rename from backend/mock.py rename to backend/mock_llm.py index 90dc7d3..0102bad 100644 --- a/backend/mock.py +++ b/backend/mock_llm.py @@ -1,7 +1,8 @@ import asyncio +from typing import Awaitable, Callable -async def mock_completion(process_chunk): +async def mock_completion(process_chunk: Callable[[str], Awaitable[None]]) -> str: code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE for i in range(0, len(code_to_return), 10): diff --git a/backend/prompts.py b/backend/prompts.py index c9e48cb..f52c195 100644 --- a/backend/prompts.py +++ b/backend/prompts.py @@ -1,3 +1,8 @@ +from typing import List, Union + +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam + + TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer You take screenshots of a reference web page from the user, and then build single page apps @@ -117,8 +122,10 @@ Generate code for a web page that looks exactly like this. def assemble_prompt( - image_data_url, generated_code_config: str, result_image_data_url=None -): + image_data_url: str, + generated_code_config: str, + result_image_data_url: Union[str, None] = None, +) -> List[ChatCompletionMessageParam]: # Set the system prompt based on the output settings system_content = TAILWIND_SYSTEM_PROMPT if generated_code_config == "html_tailwind": @@ -132,7 +139,7 @@ def assemble_prompt( else: raise Exception("Code config is not one of available options") - user_content = [ + user_content: List[ChatCompletionContentPartParam] = [ { "type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}, diff --git a/backend/routes/screenshot.py b/backend/routes/screenshot.py index 7efcfb8..258cd7e 100644 --- a/backend/routes/screenshot.py +++ b/backend/routes/screenshot.py @@ -11,7 +11,9 @@ def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str: return f"data:{mime_type};base64,{base64_image}" -async def capture_screenshot(target_url, api_key, device="desktop") -> bytes: +async def capture_screenshot( + target_url: str, api_key: str, device: str = "desktop" +) -> bytes: api_base_url = "https://api.screenshotone.com/take" params = { diff --git a/backend/utils.py b/backend/utils.py index 17d6423..6c28e14 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -1,28 +1,30 @@ import copy import json +from typing import List +from openai.types.chat import ChatCompletionMessageParam -def pprint_prompt(prompt_messages): +def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]): print(json.dumps(truncate_data_strings(prompt_messages), indent=4)) -def truncate_data_strings(data): +def truncate_data_strings(data: List[ChatCompletionMessageParam]): # type: ignore # Deep clone the data to avoid modifying the original object cloned_data = copy.deepcopy(data) if isinstance(cloned_data, dict): - for key, value in cloned_data.items(): + for key, value in cloned_data.items(): # type: ignore # Recursively call the function if the value is a dictionary or a list if isinstance(value, (dict, list)): - cloned_data[key] = truncate_data_strings(value) + cloned_data[key] = truncate_data_strings(value) # type: ignore # Truncate the string if it it's long and add ellipsis and length elif isinstance(value, str): - cloned_data[key] = value[:40] + cloned_data[key] = value[:40] # type: ignore if len(value) > 40: - cloned_data[key] += "..." + f" ({len(value)} chars)" + cloned_data[key] += "..." + f" ({len(value)} chars)" # type: ignore - elif isinstance(cloned_data, list): + elif isinstance(cloned_data, list): # type: ignore # Process each item in the list - cloned_data = [truncate_data_strings(item) for item in cloned_data] + cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore - return cloned_data + return cloned_data # type: ignore