strictly type python backend

This commit is contained in:
Abi Raja 2023-12-09 15:34:16 -05:00
parent 68a8d2788d
commit 6a28ee2d3c
8 changed files with 66 additions and 37 deletions

3
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "strict"
}

View File

@ -1,15 +1,15 @@
import asyncio import asyncio
import os
import re import re
from typing import Dict, List, Union
from openai import AsyncOpenAI from openai import AsyncOpenAI
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
async def process_tasks(prompts, api_key, base_url): async def process_tasks(prompts: List[str], api_key: str, base_url: str):
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results = [] processed_results: List[Union[str, None]] = []
for result in results: for result in results:
if isinstance(result, Exception): if isinstance(result, Exception):
print(f"An exception occurred: {result}") print(f"An exception occurred: {result}")
@ -20,9 +20,9 @@ async def process_tasks(prompts, api_key, base_url):
return processed_results return processed_results
async def generate_image(prompt, api_key, base_url): async def generate_image(prompt: str, api_key: str, base_url: str):
client = AsyncOpenAI(api_key=api_key, base_url=base_url) client = AsyncOpenAI(api_key=api_key, base_url=base_url)
image_params = { image_params: Dict[str, Union[str, int]] = {
"model": "dall-e-3", "model": "dall-e-3",
"quality": "standard", "quality": "standard",
"style": "natural", "style": "natural",
@ -35,7 +35,7 @@ async def generate_image(prompt, api_key, base_url):
return res.data[0].url return res.data[0].url
def extract_dimensions(url): def extract_dimensions(url: str):
# Regular expression to match numbers in the format '300x200' # Regular expression to match numbers in the format '300x200'
matches = re.findall(r"(\d+)x(\d+)", url) matches = re.findall(r"(\d+)x(\d+)", url)
@ -48,11 +48,11 @@ def extract_dimensions(url):
return (100, 100) return (100, 100)
def create_alt_url_mapping(code): def create_alt_url_mapping(code: str) -> Dict[str, str]:
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img") images = soup.find_all("img")
mapping = {} mapping: Dict[str, str] = {}
for image in images: for image in images:
if not image["src"].startswith("https://placehold.co"): if not image["src"].startswith("https://placehold.co"):
@ -61,7 +61,9 @@ def create_alt_url_mapping(code):
return mapping return mapping
async def generate_images(code, api_key, base_url, image_cache): async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
):
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img") images = soup.find_all("img")

View File

@ -1,16 +1,16 @@
import os from typing import Awaitable, Callable, List
from typing import Awaitable, Callable
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
MODEL_GPT_4_VISION = "gpt-4-vision-preview" MODEL_GPT_4_VISION = "gpt-4-vision-preview"
async def stream_openai_response( async def stream_openai_response(
messages, messages: List[ChatCompletionMessageParam],
api_key: str, api_key: str,
base_url: str | None, base_url: str | None,
callback: Callable[[str], Awaitable[None]], callback: Callable[[str], Awaitable[None]],
): ) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url) client = AsyncOpenAI(api_key=api_key, base_url=base_url)
model = MODEL_GPT_4_VISION model = MODEL_GPT_4_VISION
@ -23,9 +23,10 @@ async def stream_openai_response(
params["max_tokens"] = 4096 params["max_tokens"] = 4096
params["temperature"] = 0 params["temperature"] = 0
completion = await client.chat.completions.create(**params) stream = await client.chat.completions.create(**params) # type: ignore
full_response = "" full_response = ""
async for chunk in completion: async for chunk in stream: # type: ignore
assert isinstance(chunk, ChatCompletionChunk)
content = chunk.choices[0].delta.content or "" content = chunk.choices[0].delta.content or ""
full_response += content full_response += content
await callback(content) await callback(content)

View File

@ -14,8 +14,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
import openai import openai
from llm import stream_openai_response from llm import stream_openai_response
from mock import mock_completion from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from utils import pprint_prompt from utils import pprint_prompt
from typing import Dict, List
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_prompt from prompts import assemble_prompt
from routes import screenshot from routes import screenshot
@ -53,7 +55,7 @@ async def get_status():
) )
def write_logs(prompt_messages, completion): def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
# Get the logs path from environment, default to the current working directory # Get the logs path from environment, default to the current working directory
logs_path = os.environ.get("LOGS_PATH", os.getcwd()) logs_path = os.environ.get("LOGS_PATH", os.getcwd())
@ -84,7 +86,8 @@ async def stream_code(websocket: WebSocket):
await websocket.send_json({"type": "error", "value": message}) await websocket.send_json({"type": "error", "value": message})
await websocket.close() await websocket.close()
params = await websocket.receive_json() # TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json()
print("Received params") print("Received params")
@ -154,7 +157,7 @@ async def stream_code(websocket: WebSocket):
print("generating code...") print("generating code...")
await websocket.send_json({"type": "status", "value": "Generating code..."}) await websocket.send_json({"type": "status", "value": "Generating code..."})
async def process_chunk(content): async def process_chunk(content: str):
await websocket.send_json({"type": "chunk", "value": content}) await websocket.send_json({"type": "chunk", "value": content})
# Assemble the prompt # Assemble the prompt
@ -176,15 +179,23 @@ async def stream_code(websocket: WebSocket):
return return
# Image cache for updates so that we don't have to regenerate images # Image cache for updates so that we don't have to regenerate images
image_cache = {} image_cache: Dict[str, str] = {}
if params["generationType"] == "update": if params["generationType"] == "update":
# Transform into message format # Transform into message format
# TODO: Move this to frontend # TODO: Move this to frontend
for index, text in enumerate(params["history"]): for index, text in enumerate(params["history"]):
prompt_messages += [ if index % 2 == 0:
{"role": "assistant" if index % 2 == 0 else "user", "content": text} message: ChatCompletionMessageParam = {
] "role": "assistant",
"content": text,
}
else:
message: ChatCompletionMessageParam = {
"role": "user",
"content": text,
}
prompt_messages.append(message)
image_cache = create_alt_url_mapping(params["history"][-2]) image_cache = create_alt_url_mapping(params["history"][-2])
if SHOULD_MOCK_AI_RESPONSE: if SHOULD_MOCK_AI_RESPONSE:

View File

@ -1,7 +1,8 @@
import asyncio import asyncio
from typing import Awaitable, Callable
async def mock_completion(process_chunk): async def mock_completion(process_chunk: Callable[[str], Awaitable[None]]) -> str:
code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE
for i in range(0, len(code_to_return), 10): for i in range(0, len(code_to_return), 10):

View File

@ -1,3 +1,8 @@
from typing import List, Union
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
TAILWIND_SYSTEM_PROMPT = """ TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer You are an expert Tailwind developer
You take screenshots of a reference web page from the user, and then build single page apps You take screenshots of a reference web page from the user, and then build single page apps
@ -117,8 +122,10 @@ Generate code for a web page that looks exactly like this.
def assemble_prompt( def assemble_prompt(
image_data_url, generated_code_config: str, result_image_data_url=None image_data_url: str,
): generated_code_config: str,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
# Set the system prompt based on the output settings # Set the system prompt based on the output settings
system_content = TAILWIND_SYSTEM_PROMPT system_content = TAILWIND_SYSTEM_PROMPT
if generated_code_config == "html_tailwind": if generated_code_config == "html_tailwind":
@ -132,7 +139,7 @@ def assemble_prompt(
else: else:
raise Exception("Code config is not one of available options") raise Exception("Code config is not one of available options")
user_content = [ user_content: List[ChatCompletionContentPartParam] = [
{ {
"type": "image_url", "type": "image_url",
"image_url": {"url": image_data_url, "detail": "high"}, "image_url": {"url": image_data_url, "detail": "high"},

View File

@ -11,7 +11,9 @@ def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str:
return f"data:{mime_type};base64,{base64_image}" return f"data:{mime_type};base64,{base64_image}"
async def capture_screenshot(target_url, api_key, device="desktop") -> bytes: async def capture_screenshot(
target_url: str, api_key: str, device: str = "desktop"
) -> bytes:
api_base_url = "https://api.screenshotone.com/take" api_base_url = "https://api.screenshotone.com/take"
params = { params = {

View File

@ -1,28 +1,30 @@
import copy import copy
import json import json
from typing import List
from openai.types.chat import ChatCompletionMessageParam
def pprint_prompt(prompt_messages): def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]):
print(json.dumps(truncate_data_strings(prompt_messages), indent=4)) print(json.dumps(truncate_data_strings(prompt_messages), indent=4))
def truncate_data_strings(data): def truncate_data_strings(data: List[ChatCompletionMessageParam]): # type: ignore
# Deep clone the data to avoid modifying the original object # Deep clone the data to avoid modifying the original object
cloned_data = copy.deepcopy(data) cloned_data = copy.deepcopy(data)
if isinstance(cloned_data, dict): if isinstance(cloned_data, dict):
for key, value in cloned_data.items(): for key, value in cloned_data.items(): # type: ignore
# Recursively call the function if the value is a dictionary or a list # Recursively call the function if the value is a dictionary or a list
if isinstance(value, (dict, list)): if isinstance(value, (dict, list)):
cloned_data[key] = truncate_data_strings(value) cloned_data[key] = truncate_data_strings(value) # type: ignore
# Truncate the string if it's long and add ellipsis and length # Truncate the string if it's long and add ellipsis and length
elif isinstance(value, str): elif isinstance(value, str):
cloned_data[key] = value[:40] cloned_data[key] = value[:40] # type: ignore
if len(value) > 40: if len(value) > 40:
cloned_data[key] += "..." + f" ({len(value)} chars)" cloned_data[key] += "..." + f" ({len(value)} chars)" # type: ignore
elif isinstance(cloned_data, list): elif isinstance(cloned_data, list): # type: ignore
# Process each item in the list # Process each item in the list
cloned_data = [truncate_data_strings(item) for item in cloned_data] cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore
return cloned_data return cloned_data # type: ignore