diff --git a/backend/image_generation.py b/backend/image_generation.py
index 424c95b..5d4b81c 100644
--- a/backend/image_generation.py
+++ b/backend/image_generation.py
@@ -47,15 +47,49 @@ def extract_dimensions(url):
         return (100, 100)
 
 
-async def generate_images(code):
-    # Find all images and extract their alt texts
+def create_alt_url_mapping(code):
+    """Map the alt text of every already-generated image in `code` to its URL.
+
+    Images whose src still points at https://placehold.co are placeholders and
+    are left out, as are <img> tags missing a src or alt attribute.
+    """
     soup = BeautifulSoup(code, "html.parser")
     images = soup.find_all("img")
 
+    mapping = {}
+
+    for image in images:
+        src = image.get("src")
+        alt = image.get("alt")
+        # Tag["attr"] raises KeyError when the attribute is absent, so use .get()
+        # and skip tags that cannot be cached
+        if not src or not alt:
+            continue
+        if not src.startswith("https://placehold.co"):
+            mapping[alt] = src
+
+    return mapping
+
+
+async def generate_images(code, image_cache=None):
+    # image_cache maps alt text to an existing image URL; it is optional so that
+    # callers using the old generate_images(code) signature keep working
+    if image_cache is None:
+        image_cache = {}
+
+    # Find all images
+    soup = BeautifulSoup(code, "html.parser")
+    images = soup.find_all("img")
+
+    # Extract alt texts as image prompts
     alts = []
     for img in images:
         # Only include URL if the image starts with https://placehold.co
-        if img["src"].startswith("https://placehold.co"):
+        # and it's not already in the image_cache
+        if (
+            img["src"].startswith("https://placehold.co")
+            and image_cache.get(img.get("alt")) is None
+        ):
             alts.append(img.get("alt", None))
 
     # Exclude images with no alt text
@@ -70,6 +104,9 @@ async def generate_images(code):
     # Create a dict mapping alt text to image URL
     mapped_image_urls = dict(zip(prompts, results))
 
+    # Merge with image_cache
+    mapped_image_urls = {**mapped_image_urls, **image_cache}
+
     # Replace old image URLs with the generated URLs
     for img in images:
         # Skip images that don't start with https://placehold.co (leave them alone)
diff --git a/backend/main.py b/backend/main.py
index a80684e..5abd333 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -12,7 +12,7 @@ from fastapi import FastAPI, WebSocket
 
 from llm import stream_openai_response
 from mock import MOCK_HTML, mock_completion
-from image_generation import generate_images
+from image_generation import create_alt_url_mapping, generate_images
 from prompts import assemble_prompt
 
 app = FastAPI()
@@ -48,13 +48,22 @@ async def stream_code_test(websocket: WebSocket):
 
     prompt_messages = assemble_prompt(params["image"])
 
+    # Image cache for updates so that we don't have to regenerate images
+    image_cache = {}
+
     if params["generationType"] == "update":
         # Transform into message format
+        # TODO: Move this to frontend
         for index, text in enumerate(params["history"]):
             prompt_messages += [
                 {"role": "assistant" if index % 2 == 0 else "user", "content": text}
             ]
 
+        # history[-2] is presumably the latest generated code (the entry before the
+        # final user instruction); guard against a history too short to contain it
+        if len(params["history"]) >= 2:
+            image_cache = create_alt_url_mapping(params["history"][-2])
+
     if SHOULD_MOCK_AI_RESPONSE:
         completion = await mock_completion(process_chunk)
     else:
@@ -70,7 +79,7 @@ async def stream_code_test(websocket: WebSocket):
     await websocket.send_json({"type": "status", "value": "Generating images..."})
 
     try:
-        updated_html = await generate_images(completion)
+        updated_html = await generate_images(completion, image_cache=image_cache)
         await websocket.send_json({"type": "setCode", "value": updated_html})
         await websocket.send_json(
             {"type": "status", "value": "Code generation complete."}