import asyncio
import os
import re


async def process_tasks(prompts):
    """Generate one image per prompt, concurrently.

    Returns a list aligned with `prompts`; a prompt whose generation
    raised is represented by None instead of propagating the exception.
    """
    tasks = [generate_image(prompt) for prompt in prompts]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    processed_results = []
    for result in results:
        if isinstance(result, Exception):
            print(f"An exception occurred: {result}")
            processed_results.append(None)
        else:
            processed_results.append(result)

    return processed_results


async def generate_image(prompt):
    """Generate a single DALL·E 3 image for `prompt` and return its URL."""
    # Imported lazily so the pure URL helpers in this module stay usable
    # (and testable) without the openai package installed.
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    image_params = {
        "model": "dall-e-3",
        "quality": "standard",
        "style": "natural",
        "n": 1,
        "size": "1024x1024",
        "prompt": prompt,
    }
    res = await client.images.generate(**image_params)
    return res.data[0].url


def extract_dimensions(url):
    """Return (width, height) parsed from a 'WxH' pattern in `url`.

    Placeholder services encode size as e.g. '300x200' in the URL.
    Falls back to (100, 100) when no such pattern is present.
    """
    matches = re.findall(r"(\d+)x(\d+)", url)

    if matches:
        width, height = matches[0]  # first match wins
        return (int(width), int(height))
    return (100, 100)


async def generate_images(code):
    """Replace placeholder <img> srcs in the HTML `code` with generated images.

    Each image's alt text is used as a DALL·E prompt; images with a missing
    or empty alt text, or whose generation failed, keep their original src.
    Returns the modified HTML as a string.
    """
    from bs4 import BeautifulSoup  # lazy: only needed on the HTML path

    soup = BeautifulSoup(code, "html.parser")
    images = soup.find_all("img")

    # Collect usable (non-empty) alt texts and de-duplicate them so each
    # distinct prompt is generated only once.
    alts = [img.get("alt") for img in images]
    prompts = list({alt for alt in alts if alt})

    results = await process_tasks(prompts)

    # Map each alt text to its generated image URL (or None on failure).
    mapped_image_urls = dict(zip(prompts, results))

    for img in images:
        # BUGFIX: use .get() — an <img> with no alt attribute was never
        # added to the mapping, so subscripting raised KeyError here.
        new_url = mapped_image_urls.get(img.get("alt"))

        if new_url:
            # Carry the placeholder's declared dimensions over to the
            # generated replacement so the layout is preserved.
            width, height = extract_dimensions(img.get("src", ""))
            img["width"] = width
            img["height"] = height
            img["src"] = new_url
        else:
            # BUGFIX: f-string — the old '+' concatenation raised
            # TypeError when the alt text was None.
            print(f"Image generation failed for alt text: {img.get('alt')}")

    # Return the modified HTML
    return str(soup)
async def mock_completion(process_chunk, *, chunk_size=10, delay=0.01):
    """Stream a canned HTML response through `process_chunk`.

    Imitates token streaming by feeding MOCK_HTML_2 to the callback in
    `chunk_size`-character slices, sleeping `delay` seconds between
    slices, then returns the full string (mirroring the real LLM path).

    Args:
        process_chunk: async callable invoked with each chunk of text.
        chunk_size: characters per emitted chunk (keyword-only; default
            matches the previous hard-coded 10).
        delay: seconds to sleep between chunks (keyword-only; default
            matches the previous hard-coded 0.01).
    """
    code_to_return = MOCK_HTML_2

    for i in range(0, len(code_to_return), chunk_size):
        await process_chunk(code_to_return[i : i + chunk_size])
        await asyncio.sleep(delay)

    return code_to_return
+
+
+ Brand Logo +

WATCH SERIES 9

+

Smarter. Brighter. Mightier.

+ +
+
+ Product image of a smartwatch with a pink band and a circular interface displaying various health metrics. + Product image of a smartwatch with a blue band and a square interface showing a classic analog clock face. +
+
+
+ +""" + +MOCK_HTML_2 = """ + + + + + The New York Times - News + + + + + + +
+
+
+
+ + +
Tuesday, November 14, 2023
Today's Paper
+
+
+ The New York Times Logo +
+
+ +
Account
+
+
+ +
+
+
+
+
+
+

Israeli Military Raids Gaza’s Largest Hospital

+

Israeli troops have entered the Al-Shifa Hospital complex, where conditions have grown dire and Israel says Hamas fighters are embedded.

+ See more updates +
+ +
+
+
+ Flares and plumes of smoke over the northern Gaza skyline on Tuesday. +

From Elvis to Elopements, the Evolution of the Las Vegas Wedding

+

The glittering city that attracts thousands of couples seeking unconventional nuptials has grown beyond the drive-through wedding.

+ 8 MIN READ +
+ +
+
+
+
+
+ + +""" diff --git a/backend/poetry.lock b/backend/poetry.lock index d233cef..7a623b9 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -22,6 +22,25 @@ doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd- test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (<0.22)"] +[[package]] +name = "beautifulsoup4" +version = "4.12.2" +description = "Screen-scraping library" +category = "main" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, + {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, +] + +[package.dependencies] +soupsieve = ">1.2" + +[package.extras] +html5lib = ["html5lib"] +lxml = ["lxml"] + [[package]] name = "certifi" version = "2023.7.22" @@ -284,6 +303,18 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +[[package]] +name = "soupsieve" +version = "2.5" +description = "A modern CSS selector implementation for Beautiful Soup." 
+category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, + {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, +] + [[package]] name = "starlette" version = "0.27.0" @@ -440,4 +471,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "37bf71ae4f77aaeda11cbb524e3999464fceb20706135680a4c77add8712a847" +content-hash = "5e4aa03dda279f66a9b3d30f7327109bcfd395795470d95f8c563897ce1bff84" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d80df48..9983a61 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -13,6 +13,7 @@ uvicorn = "^0.24.0.post1" websockets = "^12.0" openai = "^1.2.4" python-dotenv = "^1.0.0" +beautifulsoup4 = "^4.12.2" [build-system] requires = ["poetry-core"] diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 28d1bf7..19129af 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -12,7 +12,7 @@ function App() { ); const [generatedCode, setGeneratedCode] = useState(""); const [referenceImages, setReferenceImages] = useState([]); - const [console, setConsole] = useState([]); + const [executionConsole, setExecutionConsole] = useState([]); const [blobUrl, setBlobUrl] = useState(""); const createBlobUrl = () => { @@ -29,8 +29,11 @@ function App() { function (token) { setGeneratedCode((prev) => prev + token); }, + function (code) { + setGeneratedCode(code); + }, function (line) { - setConsole((prev) => [...prev, line]); + setExecutionConsole((prev) => [...prev, line]); }, function () { setAppState("CODE_READY"); @@ -67,7 +70,7 @@ function App() {

Console

- {console.map((line, index) => ( + {executionConsole.map((line, index) => (
- Generating... + {executionConsole.slice(-1)[0]}
diff --git a/frontend/src/generateCode.ts b/frontend/src/generateCode.ts index 8d4ef4a..787249b 100644 --- a/frontend/src/generateCode.ts +++ b/frontend/src/generateCode.ts @@ -7,6 +7,7 @@ const ERROR_MESSAGE = export function generateCode( imageUrl: string, onChange: (chunk: string) => void, + onSetCode: (code: string) => void, onStatusUpdate: (status: string) => void, onComplete: () => void ) { @@ -29,6 +30,8 @@ export function generateCode( onChange(response.value); } else if (response.type === "status") { onStatusUpdate(response.value); + } else if (response.type === "setCode") { + onSetCode(response.value); } });