improve image cache so we don't re-generate the same images on update
This commit is contained in:
parent
62c5458e08
commit
c061c9b610
@ -47,15 +47,33 @@ def extract_dimensions(url):
|
|||||||
return (100, 100)
|
return (100, 100)
|
||||||
|
|
||||||
|
|
||||||
async def generate_images(code):
|
def create_alt_url_mapping(code):
|
||||||
# Find all images and extract their alt texts
|
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
images = soup.find_all("img")
|
images = soup.find_all("img")
|
||||||
|
|
||||||
|
mapping = {}
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
if not image["src"].startswith("https://placehold.co"):
|
||||||
|
mapping[image["alt"]] = image["src"]
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_images(code, image_cache):
|
||||||
|
# Find all images
|
||||||
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
|
images = soup.find_all("img")
|
||||||
|
|
||||||
|
# Extract alt texts as image prompts
|
||||||
alts = []
|
alts = []
|
||||||
for img in images:
|
for img in images:
|
||||||
# Only include URL if the image starts with https://placehold.co
|
# Only include URL if the image starts with https://placehold.co
|
||||||
if img["src"].startswith("https://placehold.co"):
|
# and it's not already in the image_cache
|
||||||
|
if (
|
||||||
|
img["src"].startswith("https://placehold.co")
|
||||||
|
and image_cache.get(img.get("alt")) is None
|
||||||
|
):
|
||||||
alts.append(img.get("alt", None))
|
alts.append(img.get("alt", None))
|
||||||
|
|
||||||
# Exclude images with no alt text
|
# Exclude images with no alt text
|
||||||
@ -70,6 +88,9 @@ async def generate_images(code):
|
|||||||
# Create a dict mapping alt text to image URL
|
# Create a dict mapping alt text to image URL
|
||||||
mapped_image_urls = dict(zip(prompts, results))
|
mapped_image_urls = dict(zip(prompts, results))
|
||||||
|
|
||||||
|
# Merge with image_cache
|
||||||
|
mapped_image_urls = {**mapped_image_urls, **image_cache}
|
||||||
|
|
||||||
# Replace old image URLs with the generated URLs
|
# Replace old image URLs with the generated URLs
|
||||||
for img in images:
|
for img in images:
|
||||||
# Skip images that don't start with https://placehold.co (leave them alone)
|
# Skip images that don't start with https://placehold.co (leave them alone)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from fastapi import FastAPI, WebSocket
|
|||||||
|
|
||||||
from llm import stream_openai_response
|
from llm import stream_openai_response
|
||||||
from mock import MOCK_HTML, mock_completion
|
from mock import MOCK_HTML, mock_completion
|
||||||
from image_generation import generate_images
|
from image_generation import create_alt_url_mapping, generate_images
|
||||||
from prompts import assemble_prompt
|
from prompts import assemble_prompt
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -48,13 +48,19 @@ async def stream_code_test(websocket: WebSocket):
|
|||||||
|
|
||||||
prompt_messages = assemble_prompt(params["image"])
|
prompt_messages = assemble_prompt(params["image"])
|
||||||
|
|
||||||
|
# Image cache for updates so that we don't have to regenerate images
|
||||||
|
image_cache = {}
|
||||||
|
|
||||||
if params["generationType"] == "update":
|
if params["generationType"] == "update":
|
||||||
# Transform into message format
|
# Transform into message format
|
||||||
|
# TODO: Move this to frontend
|
||||||
for index, text in enumerate(params["history"]):
|
for index, text in enumerate(params["history"]):
|
||||||
prompt_messages += [
|
prompt_messages += [
|
||||||
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
|
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
image_cache = create_alt_url_mapping(params["history"][-2])
|
||||||
|
|
||||||
if SHOULD_MOCK_AI_RESPONSE:
|
if SHOULD_MOCK_AI_RESPONSE:
|
||||||
completion = await mock_completion(process_chunk)
|
completion = await mock_completion(process_chunk)
|
||||||
else:
|
else:
|
||||||
@ -70,7 +76,7 @@ async def stream_code_test(websocket: WebSocket):
|
|||||||
await websocket.send_json({"type": "status", "value": "Generating images..."})
|
await websocket.send_json({"type": "status", "value": "Generating images..."})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated_html = await generate_images(completion)
|
updated_html = await generate_images(completion, image_cache=image_cache)
|
||||||
await websocket.send_json({"type": "setCode", "value": updated_html})
|
await websocket.send_json({"type": "setCode", "value": updated_html})
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{"type": "status", "value": "Code generation complete."}
|
{"type": "status", "value": "Code generation complete."}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user