import asyncio
import re
import time
from typing import Dict, List, Literal, Union

from bs4 import BeautifulSoup
from openai import AsyncOpenAI

from image_generation.replicate import call_replicate


async def process_tasks(
    prompts: List[str],
    api_key: str,
    base_url: str | None,
    model: Literal["dalle3", "sdxl-lightning"],
) -> List[Union[str, None]]:
    """Generate one image per prompt, concurrently.

    Returns a list aligned with *prompts*: each entry is the generated image
    URL, or None for prompts whose generation raised an exception.
    """
    start_time = time.time()

    if model == "dalle3":
        tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
    else:
        tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]

    # return_exceptions=True so a single failed prompt doesn't abort the batch
    results = await asyncio.gather(*tasks, return_exceptions=True)

    print(f"Image generation time: {time.time() - start_time:.2f} seconds")

    processed_results: List[Union[str, None]] = []
    for result in results:
        if isinstance(result, Exception):
            print(f"An exception occurred: {result}")
            processed_results.append(None)
        else:
            processed_results.append(result)

    return processed_results


async def generate_image_dalle(
    prompt: str, api_key: str, base_url: str | None
) -> Union[str, None]:
    """Generate a single 1024x1024 image with DALL-E 3 and return its URL."""
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    try:
        res = await client.images.generate(
            model="dall-e-3",
            quality="standard",
            style="natural",
            n=1,
            size="1024x1024",
            prompt=prompt,
        )
    finally:
        # Fix: close the HTTP client even when generation raises
        # (previously leaked on error).
        await client.close()
    return res.data[0].url


async def generate_image_replicate(prompt: str, api_key: str) -> str:
    """Generate a single 1024x1024 image via Replicate and return its URL."""
    # We use SDXL Lightning
    return await call_replicate(
        "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
        {
            "width": 1024,
            "height": 1024,
            "prompt": prompt,
            "scheduler": "K_EULER",
            "num_outputs": 1,
            "guidance_scale": 0,
            "negative_prompt": "worst quality, low quality",
            "num_inference_steps": 4,
        },
        api_key,
    )


def extract_dimensions(url: str) -> tuple[int, int]:
    """Extract (width, height) from the first 'WxH' pattern in *url*.

    Falls back to (100, 100) when no dimensions are present.
    """
    # Regular expression to match numbers in the format '300x200'
    match = re.search(r"(\d+)x(\d+)", url)
    if match:
        return (int(match.group(1)), int(match.group(2)))
    return (100, 100)


def create_alt_url_mapping(code: str) -> Dict[str, str]:
    """Map alt text -> src URL for every non-placeholder <img> in *code*."""
    soup = BeautifulSoup(code, "html.parser")
    mapping: Dict[str, str] = {}
    for image in soup.find_all("img"):
        src = image.get("src", "")
        alt = image.get("alt")
        # Fix: skip tags missing src or alt (previously raised KeyError),
        # and skip placehold.co placeholders as before.
        if alt is not None and src and not src.startswith("https://placehold.co"):
            mapping[alt] = src
    return mapping


async def generate_images(
    code: str,
    api_key: str,
    base_url: Union[str, None],
    image_cache: Dict[str, str],
    model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
) -> str:
    """Replace placehold.co placeholder images in *code* with generated images.

    Alt texts already present in *image_cache* reuse the cached URL instead of
    regenerating. Returns the modified (prettified) HTML; returns *code*
    unchanged when there is nothing to generate.
    """
    # Find all images
    soup = BeautifulSoup(code, "html.parser")
    images = soup.find_all("img")

    # Collect alt texts to use as prompts: only placeholder images that
    # have alt text and are not already cached.
    alts: List[str] = []
    for img in images:
        src = img.get("src", "")
        alt = img.get("alt")
        if (
            src.startswith("https://placehold.co")
            and alt is not None
            and image_cache.get(alt) is None
        ):
            alts.append(alt)

    # Remove duplicates
    prompts = list(set(alts))

    # Return early if there are no images to replace
    if not prompts:
        return code

    # Generate images
    results = await process_tasks(prompts, api_key, base_url, model)

    # Alt text -> generated URL, with cached entries taking precedence
    mapped_image_urls = {**dict(zip(prompts, results)), **image_cache}

    # Replace old image URLs with the generated URLs
    for img in images:
        src = img.get("src", "")
        # Skip images that don't start with https://placehold.co (leave them alone)
        if not src.startswith("https://placehold.co"):
            continue

        alt = img.get("alt")
        # Fix: .get avoids a KeyError for placeholder images without alt text
        # (their alt was never added to prompts, so None is not a key).
        new_url = mapped_image_urls.get(alt)
        if new_url:
            # Carry the placeholder's declared dimensions over to the
            # generated image so layout is preserved.
            width, height = extract_dimensions(src)
            img["width"] = width
            img["height"] = height
            # Replace img['src'] with the mapped image URL
            img["src"] = new_url
        else:
            # Fix: f-string avoids TypeError when alt is None
            # (original concatenated str + None).
            print(f"Image generation failed for alt text:{alt}")

    # Return the modified HTML
    # (need to prettify it because BeautifulSoup messes up the formatting)
    return soup.prettify()