screenshot-to-code/backend/image_generation/core.py
2024-09-08 16:50:18 +02:00

167 lines
4.8 KiB
Python

import asyncio
import re
import time
from typing import Dict, List, Literal, Union

import sentry_sdk
from bs4 import BeautifulSoup
from openai import AsyncOpenAI

from image_generation.replicate import call_replicate
async def process_tasks(
    prompts: List[str],
    api_key: str,
    base_url: str | None,
    model: Literal["dalle3", "sdxl-lightning"],
) -> List[Union[str, None]]:
    """Generate one image per prompt concurrently and return their URLs.

    Dispatches to DALL-E 3 or SDXL Lightning based on ``model``. Results are
    returned in the same order as ``prompts``; a prompt whose generation
    raised is reported to Sentry and yields ``None`` in its slot.
    """
    start_time = time.time()
    if model == "dalle3":
        tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
    else:
        tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]

    # return_exceptions=True so one failed prompt doesn't sink the whole batch
    results = await asyncio.gather(*tasks, return_exceptions=True)
    print(f"Image generation time: {time.time() - start_time:.2f} seconds")

    processed_results: List[Union[str, None]] = []
    for result in results:
        if isinstance(result, BaseException):
            print(f"An exception occurred: {result}")
            # capture_exception accepts the exception object directly —
            # no need to re-raise inside a try/except just to record it.
            sentry_sdk.capture_exception(result)
            processed_results.append(None)
        else:
            processed_results.append(result)
    return processed_results
async def generate_image_dalle(
    prompt: str, api_key: str, base_url: str | None
) -> Union[str, None]:
    """Generate a single 1024x1024 image with DALL-E 3 and return its URL.

    The original code leaked the HTTP client when ``images.generate`` raised;
    the try/finally guarantees the client is closed on every path.
    """
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    try:
        res = await client.images.generate(
            model="dall-e-3",
            quality="standard",
            style="natural",
            n=1,
            size="1024x1024",
            prompt=prompt,
        )
        return res.data[0].url
    finally:
        await client.close()
async def generate_image_replicate(prompt: str, api_key: str) -> str:
    """Generate one 1024x1024 image via Replicate (SDXL Lightning) and return its URL."""
    # Pinned SDXL Lightning model version on Replicate.
    model_version = "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f"
    generation_input = {
        "width": 1024,
        "height": 1024,
        "prompt": prompt,
        "scheduler": "K_EULER",
        "num_outputs": 1,
        "guidance_scale": 0,
        "negative_prompt": "worst quality, low quality",
        "num_inference_steps": 4,
    }
    return await call_replicate(model_version, generation_input, api_key)
def extract_dimensions(url: str):
    """Parse the first 'WIDTHxHEIGHT' token in *url* (e.g. '300x200').

    Returns (width, height) as ints, or (100, 100) when no such token exists.
    """
    match = re.search(r"(\d+)x(\d+)", url)
    if match is None:
        # Fallback dimensions when the URL carries no size information.
        return (100, 100)
    return (int(match.group(1)), int(match.group(2)))
def create_alt_url_mapping(code: str) -> Dict[str, str]:
    """Map alt text -> src URL for already-generated images in *code*.

    Placeholder images (src starting with https://placehold.co) are excluded,
    since those are the ones still awaiting generation. Images missing an
    ``alt`` or ``src`` attribute are skipped — the original indexed
    ``image["alt"]``/``image["src"]`` directly, which raises KeyError on
    such tags.
    """
    soup = BeautifulSoup(code, "html.parser")
    mapping: Dict[str, str] = {}
    for image in soup.find_all("img"):
        src = image.get("src")
        alt = image.get("alt")
        if src and alt and not src.startswith("https://placehold.co"):
            mapping[alt] = src
    return mapping
async def generate_images(
    code: str,
    api_key: str,
    base_url: Union[str, None],
    image_cache: Dict[str, str],
    model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
) -> str:
    """Replace placeholder <img> URLs in *code* with generated image URLs.

    Images whose src starts with https://placehold.co and whose alt text is
    not already in ``image_cache`` are generated (alt text used as the
    prompt). Other images are left untouched. Returns the prettified HTML.
    """
    soup = BeautifulSoup(code, "html.parser")
    images = soup.find_all("img")

    # Collect alt texts to use as prompts. Use .get() throughout: an <img>
    # may lack src or alt, and the original's img["src"] raised KeyError
    # on src-less tags.
    alts: List[str] = []
    for img in images:
        src = img.get("src", "")
        alt = img.get("alt")
        if (
            src.startswith("https://placehold.co")
            and alt is not None
            and image_cache.get(alt) is None
        ):
            alts.append(alt)

    # Remove duplicates
    prompts = list(set(alts))

    # Return early if there are no images to replace
    if len(prompts) == 0:
        return code

    # Generate images (failed generations come back as None)
    results = await process_tasks(prompts, api_key, base_url, model)

    # Create a dict mapping alt text to image URL, then merge with the
    # cache (cache entries win so we never regenerate/overwrite them).
    mapped_image_urls = dict(zip(prompts, results))
    mapped_image_urls = {**mapped_image_urls, **image_cache}

    # Replace placeholder image URLs with the generated URLs
    for img in images:
        src = img.get("src", "")
        # Skip images that don't start with https://placehold.co (leave them alone)
        if not src.startswith("https://placehold.co"):
            continue
        # .get() here: alt may be None or absent from the mapping, and the
        # original's mapped_image_urls[img.get("alt")] raised KeyError then.
        alt = img.get("alt")
        new_url = mapped_image_urls.get(alt)
        if new_url:
            # Carry the placeholder's dimensions over as width/height attributes
            width, height = extract_dimensions(src)
            img["width"] = width
            img["height"] = height
            img["src"] = new_url
        else:
            # f-string avoids the original's TypeError when alt is None
            print(f"Image generation failed for alt text: {alt}")

    # Return the modified HTML
    # (need to prettify it because BeautifulSoup messes up the formatting)
    return soup.prettify()