strictly type python backend
This commit is contained in:
parent
68a8d2788d
commit
6a28ee2d3c
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"python.analysis.typeCheckingMode": "strict"
|
||||||
|
}
|
||||||
@ -1,15 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
|
from typing import Dict, List, Union
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
||||||
async def process_tasks(prompts, api_key, base_url):
|
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
||||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
processed_results = []
|
processed_results: List[Union[str, None]] = []
|
||||||
for result in results:
|
for result in results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
print(f"An exception occurred: {result}")
|
print(f"An exception occurred: {result}")
|
||||||
@ -20,9 +20,9 @@ async def process_tasks(prompts, api_key, base_url):
|
|||||||
return processed_results
|
return processed_results
|
||||||
|
|
||||||
|
|
||||||
async def generate_image(prompt, api_key, base_url):
|
async def generate_image(prompt: str, api_key: str, base_url: str):
|
||||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
image_params = {
|
image_params: Dict[str, Union[str, int]] = {
|
||||||
"model": "dall-e-3",
|
"model": "dall-e-3",
|
||||||
"quality": "standard",
|
"quality": "standard",
|
||||||
"style": "natural",
|
"style": "natural",
|
||||||
@ -35,7 +35,7 @@ async def generate_image(prompt, api_key, base_url):
|
|||||||
return res.data[0].url
|
return res.data[0].url
|
||||||
|
|
||||||
|
|
||||||
def extract_dimensions(url):
|
def extract_dimensions(url: str):
|
||||||
# Regular expression to match numbers in the format '300x200'
|
# Regular expression to match numbers in the format '300x200'
|
||||||
matches = re.findall(r"(\d+)x(\d+)", url)
|
matches = re.findall(r"(\d+)x(\d+)", url)
|
||||||
|
|
||||||
@ -48,11 +48,11 @@ def extract_dimensions(url):
|
|||||||
return (100, 100)
|
return (100, 100)
|
||||||
|
|
||||||
|
|
||||||
def create_alt_url_mapping(code):
|
def create_alt_url_mapping(code: str) -> Dict[str, str]:
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
images = soup.find_all("img")
|
images = soup.find_all("img")
|
||||||
|
|
||||||
mapping = {}
|
mapping: Dict[str, str] = {}
|
||||||
|
|
||||||
for image in images:
|
for image in images:
|
||||||
if not image["src"].startswith("https://placehold.co"):
|
if not image["src"].startswith("https://placehold.co"):
|
||||||
@ -61,7 +61,9 @@ def create_alt_url_mapping(code):
|
|||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
async def generate_images(code, api_key, base_url, image_cache):
|
async def generate_images(
|
||||||
|
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
|
||||||
|
):
|
||||||
# Find all images
|
# Find all images
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
images = soup.find_all("img")
|
images = soup.find_all("img")
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
import os
|
from typing import Awaitable, Callable, List
|
||||||
from typing import Awaitable, Callable
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||||
|
|
||||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||||
|
|
||||||
|
|
||||||
async def stream_openai_response(
|
async def stream_openai_response(
|
||||||
messages,
|
messages: List[ChatCompletionMessageParam],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str | None,
|
base_url: str | None,
|
||||||
callback: Callable[[str], Awaitable[None]],
|
callback: Callable[[str], Awaitable[None]],
|
||||||
):
|
) -> str:
|
||||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
model = MODEL_GPT_4_VISION
|
model = MODEL_GPT_4_VISION
|
||||||
@ -23,9 +23,10 @@ async def stream_openai_response(
|
|||||||
params["max_tokens"] = 4096
|
params["max_tokens"] = 4096
|
||||||
params["temperature"] = 0
|
params["temperature"] = 0
|
||||||
|
|
||||||
completion = await client.chat.completions.create(**params)
|
stream = await client.chat.completions.create(**params) # type: ignore
|
||||||
full_response = ""
|
full_response = ""
|
||||||
async for chunk in completion:
|
async for chunk in stream: # type: ignore
|
||||||
|
assert isinstance(chunk, ChatCompletionChunk)
|
||||||
content = chunk.choices[0].delta.content or ""
|
content = chunk.choices[0].delta.content or ""
|
||||||
full_response += content
|
full_response += content
|
||||||
await callback(content)
|
await callback(content)
|
||||||
|
|||||||
@ -14,8 +14,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
import openai
|
import openai
|
||||||
from llm import stream_openai_response
|
from llm import stream_openai_response
|
||||||
from mock import mock_completion
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
from mock_llm import mock_completion
|
||||||
from utils import pprint_prompt
|
from utils import pprint_prompt
|
||||||
|
from typing import Dict, List
|
||||||
from image_generation import create_alt_url_mapping, generate_images
|
from image_generation import create_alt_url_mapping, generate_images
|
||||||
from prompts import assemble_prompt
|
from prompts import assemble_prompt
|
||||||
from routes import screenshot
|
from routes import screenshot
|
||||||
@ -53,7 +55,7 @@ async def get_status():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def write_logs(prompt_messages, completion):
|
def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
|
||||||
# Get the logs path from environment, default to the current working directory
|
# Get the logs path from environment, default to the current working directory
|
||||||
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
|
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
|
||||||
|
|
||||||
@ -84,7 +86,8 @@ async def stream_code(websocket: WebSocket):
|
|||||||
await websocket.send_json({"type": "error", "value": message})
|
await websocket.send_json({"type": "error", "value": message})
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
params = await websocket.receive_json()
|
# TODO: Are the values always strings?
|
||||||
|
params: Dict[str, str] = await websocket.receive_json()
|
||||||
|
|
||||||
print("Received params")
|
print("Received params")
|
||||||
|
|
||||||
@ -154,7 +157,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
print("generating code...")
|
print("generating code...")
|
||||||
await websocket.send_json({"type": "status", "value": "Generating code..."})
|
await websocket.send_json({"type": "status", "value": "Generating code..."})
|
||||||
|
|
||||||
async def process_chunk(content):
|
async def process_chunk(content: str):
|
||||||
await websocket.send_json({"type": "chunk", "value": content})
|
await websocket.send_json({"type": "chunk", "value": content})
|
||||||
|
|
||||||
# Assemble the prompt
|
# Assemble the prompt
|
||||||
@ -176,15 +179,23 @@ async def stream_code(websocket: WebSocket):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Image cache for updates so that we don't have to regenerate images
|
# Image cache for updates so that we don't have to regenerate images
|
||||||
image_cache = {}
|
image_cache: Dict[str, str] = {}
|
||||||
|
|
||||||
if params["generationType"] == "update":
|
if params["generationType"] == "update":
|
||||||
# Transform into message format
|
# Transform into message format
|
||||||
# TODO: Move this to frontend
|
# TODO: Move this to frontend
|
||||||
for index, text in enumerate(params["history"]):
|
for index, text in enumerate(params["history"]):
|
||||||
prompt_messages += [
|
if index % 2 == 0:
|
||||||
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
|
message: ChatCompletionMessageParam = {
|
||||||
]
|
"role": "assistant",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
message: ChatCompletionMessageParam = {
|
||||||
|
"role": "user",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
prompt_messages.append(message)
|
||||||
image_cache = create_alt_url_mapping(params["history"][-2])
|
image_cache = create_alt_url_mapping(params["history"][-2])
|
||||||
|
|
||||||
if SHOULD_MOCK_AI_RESPONSE:
|
if SHOULD_MOCK_AI_RESPONSE:
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
|
|
||||||
async def mock_completion(process_chunk):
|
async def mock_completion(process_chunk: Callable[[str], Awaitable[None]]) -> str:
|
||||||
code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE
|
code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE
|
||||||
|
|
||||||
for i in range(0, len(code_to_return), 10):
|
for i in range(0, len(code_to_return), 10):
|
||||||
@ -1,3 +1,8 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
|
||||||
|
|
||||||
|
|
||||||
TAILWIND_SYSTEM_PROMPT = """
|
TAILWIND_SYSTEM_PROMPT = """
|
||||||
You are an expert Tailwind developer
|
You are an expert Tailwind developer
|
||||||
You take screenshots of a reference web page from the user, and then build single page apps
|
You take screenshots of a reference web page from the user, and then build single page apps
|
||||||
@ -117,8 +122,10 @@ Generate code for a web page that looks exactly like this.
|
|||||||
|
|
||||||
|
|
||||||
def assemble_prompt(
|
def assemble_prompt(
|
||||||
image_data_url, generated_code_config: str, result_image_data_url=None
|
image_data_url: str,
|
||||||
):
|
generated_code_config: str,
|
||||||
|
result_image_data_url: Union[str, None] = None,
|
||||||
|
) -> List[ChatCompletionMessageParam]:
|
||||||
# Set the system prompt based on the output settings
|
# Set the system prompt based on the output settings
|
||||||
system_content = TAILWIND_SYSTEM_PROMPT
|
system_content = TAILWIND_SYSTEM_PROMPT
|
||||||
if generated_code_config == "html_tailwind":
|
if generated_code_config == "html_tailwind":
|
||||||
@ -132,7 +139,7 @@ def assemble_prompt(
|
|||||||
else:
|
else:
|
||||||
raise Exception("Code config is not one of available options")
|
raise Exception("Code config is not one of available options")
|
||||||
|
|
||||||
user_content = [
|
user_content: List[ChatCompletionContentPartParam] = [
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": image_data_url, "detail": "high"},
|
"image_url": {"url": image_data_url, "detail": "high"},
|
||||||
|
|||||||
@ -11,7 +11,9 @@ def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str:
|
|||||||
return f"data:{mime_type};base64,{base64_image}"
|
return f"data:{mime_type};base64,{base64_image}"
|
||||||
|
|
||||||
|
|
||||||
async def capture_screenshot(target_url, api_key, device="desktop") -> bytes:
|
async def capture_screenshot(
|
||||||
|
target_url: str, api_key: str, device: str = "desktop"
|
||||||
|
) -> bytes:
|
||||||
api_base_url = "https://api.screenshotone.com/take"
|
api_base_url = "https://api.screenshotone.com/take"
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
|||||||
@ -1,28 +1,30 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
from typing import List
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
|
||||||
def pprint_prompt(prompt_messages):
|
def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]):
|
||||||
print(json.dumps(truncate_data_strings(prompt_messages), indent=4))
|
print(json.dumps(truncate_data_strings(prompt_messages), indent=4))
|
||||||
|
|
||||||
|
|
||||||
def truncate_data_strings(data):
|
def truncate_data_strings(data: List[ChatCompletionMessageParam]): # type: ignore
|
||||||
# Deep clone the data to avoid modifying the original object
|
# Deep clone the data to avoid modifying the original object
|
||||||
cloned_data = copy.deepcopy(data)
|
cloned_data = copy.deepcopy(data)
|
||||||
|
|
||||||
if isinstance(cloned_data, dict):
|
if isinstance(cloned_data, dict):
|
||||||
for key, value in cloned_data.items():
|
for key, value in cloned_data.items(): # type: ignore
|
||||||
# Recursively call the function if the value is a dictionary or a list
|
# Recursively call the function if the value is a dictionary or a list
|
||||||
if isinstance(value, (dict, list)):
|
if isinstance(value, (dict, list)):
|
||||||
cloned_data[key] = truncate_data_strings(value)
|
cloned_data[key] = truncate_data_strings(value) # type: ignore
|
||||||
# Truncate the string if it it's long and add ellipsis and length
|
# Truncate the string if it it's long and add ellipsis and length
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
cloned_data[key] = value[:40]
|
cloned_data[key] = value[:40] # type: ignore
|
||||||
if len(value) > 40:
|
if len(value) > 40:
|
||||||
cloned_data[key] += "..." + f" ({len(value)} chars)"
|
cloned_data[key] += "..." + f" ({len(value)} chars)" # type: ignore
|
||||||
|
|
||||||
elif isinstance(cloned_data, list):
|
elif isinstance(cloned_data, list): # type: ignore
|
||||||
# Process each item in the list
|
# Process each item in the list
|
||||||
cloned_data = [truncate_data_strings(item) for item in cloned_data]
|
cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore
|
||||||
|
|
||||||
return cloned_data
|
return cloned_data # type: ignore
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user