improve image cache so we don't re-generate the same images on update

This commit is contained in:
Abi Raja 2023-11-15 17:31:01 -05:00
parent 62c5458e08
commit c061c9b610
2 changed files with 32 additions and 5 deletions

View File

@ -47,15 +47,33 @@ def extract_dimensions(url):
return (100, 100)
async def generate_images(code):
# Find all images and extract their alt texts
def create_alt_url_mapping(code):
soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img")
mapping = {}
for image in images:
if not image["src"].startswith("https://placehold.co"):
mapping[image["alt"]] = image["src"]
return mapping
async def generate_images(code, image_cache):
# Find all images
soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img")
# Extract alt texts as image prompts
alts = []
for img in images:
# Only include URL if the image starts with https://placehold.co
if img["src"].startswith("https://placehold.co"):
# and it's not already in the image_cache
if (
img["src"].startswith("https://placehold.co")
and image_cache.get(img.get("alt")) is None
):
alts.append(img.get("alt", None))
# Exclude images with no alt text
@ -70,6 +88,9 @@ async def generate_images(code):
# Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results))
# Merge with image_cache
mapped_image_urls = {**mapped_image_urls, **image_cache}
# Replace old image URLs with the generated URLs
for img in images:
# Skip images that don't start with https://placehold.co (leave them alone)

View File

@ -12,7 +12,7 @@ from fastapi import FastAPI, WebSocket
from llm import stream_openai_response
from mock import MOCK_HTML, mock_completion
from image_generation import generate_images
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_prompt
app = FastAPI()
@ -48,13 +48,19 @@ async def stream_code_test(websocket: WebSocket):
prompt_messages = assemble_prompt(params["image"])
# Image cache for updates so that we don't have to regenerate images
image_cache = {}
if params["generationType"] == "update":
# Transform into message format
# TODO: Move this to frontend
for index, text in enumerate(params["history"]):
prompt_messages += [
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
]
image_cache = create_alt_url_mapping(params["history"][-2])
if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(process_chunk)
else:
@ -70,7 +76,7 @@ async def stream_code_test(websocket: WebSocket):
await websocket.send_json({"type": "status", "value": "Generating images..."})
try:
updated_html = await generate_images(completion)
updated_html = await generate_images(completion, image_cache=image_cache)
await websocket.send_json({"type": "setCode", "value": updated_html})
await websocket.send_json(
{"type": "status", "value": "Code generation complete."}