improve image cache so we don't re-generate the same images on update
This commit is contained in:
parent
62c5458e08
commit
c061c9b610
@ -47,15 +47,33 @@ def extract_dimensions(url):
|
||||
return (100, 100)
|
||||
|
||||
|
||||
async def generate_images(code):
|
||||
# Find all images and extract their alt texts
|
||||
def create_alt_url_mapping(code):
|
||||
soup = BeautifulSoup(code, "html.parser")
|
||||
images = soup.find_all("img")
|
||||
|
||||
mapping = {}
|
||||
|
||||
for image in images:
|
||||
if not image["src"].startswith("https://placehold.co"):
|
||||
mapping[image["alt"]] = image["src"]
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
async def generate_images(code, image_cache):
|
||||
# Find all images
|
||||
soup = BeautifulSoup(code, "html.parser")
|
||||
images = soup.find_all("img")
|
||||
|
||||
# Extract alt texts as image prompts
|
||||
alts = []
|
||||
for img in images:
|
||||
# Only include URL if the image starts with https://placehold.co
|
||||
if img["src"].startswith("https://placehold.co"):
|
||||
# and it's not already in the image_cache
|
||||
if (
|
||||
img["src"].startswith("https://placehold.co")
|
||||
and image_cache.get(img.get("alt")) is None
|
||||
):
|
||||
alts.append(img.get("alt", None))
|
||||
|
||||
# Exclude images with no alt text
|
||||
@ -70,6 +88,9 @@ async def generate_images(code):
|
||||
# Create a dict mapping alt text to image URL
|
||||
mapped_image_urls = dict(zip(prompts, results))
|
||||
|
||||
# Merge with image_cache
|
||||
mapped_image_urls = {**mapped_image_urls, **image_cache}
|
||||
|
||||
# Replace old image URLs with the generated URLs
|
||||
for img in images:
|
||||
# Skip images that don't start with https://placehold.co (leave them alone)
|
||||
|
||||
@ -12,7 +12,7 @@ from fastapi import FastAPI, WebSocket
|
||||
|
||||
from llm import stream_openai_response
|
||||
from mock import MOCK_HTML, mock_completion
|
||||
from image_generation import generate_images
|
||||
from image_generation import create_alt_url_mapping, generate_images
|
||||
from prompts import assemble_prompt
|
||||
|
||||
app = FastAPI()
|
||||
@ -48,13 +48,19 @@ async def stream_code_test(websocket: WebSocket):
|
||||
|
||||
prompt_messages = assemble_prompt(params["image"])
|
||||
|
||||
# Image cache for updates so that we don't have to regenerate images
|
||||
image_cache = {}
|
||||
|
||||
if params["generationType"] == "update":
|
||||
# Transform into message format
|
||||
# TODO: Move this to frontend
|
||||
for index, text in enumerate(params["history"]):
|
||||
prompt_messages += [
|
||||
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
|
||||
]
|
||||
|
||||
image_cache = create_alt_url_mapping(params["history"][-2])
|
||||
|
||||
if SHOULD_MOCK_AI_RESPONSE:
|
||||
completion = await mock_completion(process_chunk)
|
||||
else:
|
||||
@ -70,7 +76,7 @@ async def stream_code_test(websocket: WebSocket):
|
||||
await websocket.send_json({"type": "status", "value": "Generating images..."})
|
||||
|
||||
try:
|
||||
updated_html = await generate_images(completion)
|
||||
updated_html = await generate_images(completion, image_cache=image_cache)
|
||||
await websocket.send_json({"type": "setCode", "value": updated_html})
|
||||
await websocket.send_json(
|
||||
{"type": "status", "value": "Code generation complete."}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user