Merge pull request #348 from naman1608/type-fixes
full typed in image generation file
This commit is contained in:
commit
561ac0b088
@ -5,7 +5,7 @@ from openai import AsyncOpenAI
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
||||
async def process_tasks(prompts: List[str], api_key: str, base_url: str | None):
|
||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@ -15,22 +15,23 @@ async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
||||
print(f"An exception occurred: {result}")
|
||||
processed_results.append(None)
|
||||
else:
|
||||
processed_results.append(result) # type: ignore
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
|
||||
async def generate_image(prompt: str, api_key: str, base_url: str):
|
||||
async def generate_image(
|
||||
prompt: str, api_key: str, base_url: str | None
|
||||
) -> Union[str, None]:
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
image_params: Dict[str, Union[str, int]] = {
|
||||
"model": "dall-e-3",
|
||||
"quality": "standard",
|
||||
"style": "natural",
|
||||
"n": 1,
|
||||
"size": "1024x1024",
|
||||
"prompt": prompt,
|
||||
}
|
||||
res = await client.images.generate(**image_params) # type: ignore
|
||||
res = await client.images.generate(
|
||||
model="dall-e-3",
|
||||
quality="standard",
|
||||
style="natural",
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
prompt=prompt,
|
||||
)
|
||||
await client.close()
|
||||
return res.data[0].url
|
||||
|
||||
@ -63,13 +64,13 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
|
||||
|
||||
async def generate_images(
|
||||
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
|
||||
):
|
||||
) -> str:
|
||||
# Find all images
|
||||
soup = BeautifulSoup(code, "html.parser")
|
||||
images = soup.find_all("img")
|
||||
|
||||
# Extract alt texts as image prompts
|
||||
alts = []
|
||||
alts: List[str | None] = []
|
||||
for img in images:
|
||||
# Only include URL if the image starts with https://placehold.co
|
||||
# and it's not already in the image_cache
|
||||
@ -77,26 +78,26 @@ async def generate_images(
|
||||
img["src"].startswith("https://placehold.co")
|
||||
and image_cache.get(img.get("alt")) is None
|
||||
):
|
||||
alts.append(img.get("alt", None)) # type: ignore
|
||||
alts.append(img.get("alt", None))
|
||||
|
||||
# Exclude images with no alt text
|
||||
alts = [alt for alt in alts if alt is not None] # type: ignore
|
||||
filtered_alts: List[str] = [alt for alt in alts if alt is not None]
|
||||
|
||||
# Remove duplicates
|
||||
prompts = list(set(alts)) # type: ignore
|
||||
prompts = list(set(filtered_alts))
|
||||
|
||||
# Return early if there are no images to replace
|
||||
if len(prompts) == 0: # type: ignore
|
||||
if len(prompts) == 0:
|
||||
return code
|
||||
|
||||
# Generate images
|
||||
results = await process_tasks(prompts, api_key, base_url) # type: ignore
|
||||
results = await process_tasks(prompts, api_key, base_url)
|
||||
|
||||
# Create a dict mapping alt text to image URL
|
||||
mapped_image_urls = dict(zip(prompts, results)) # type: ignore
|
||||
mapped_image_urls = dict(zip(prompts, results))
|
||||
|
||||
# Merge with image_cache
|
||||
mapped_image_urls = {**mapped_image_urls, **image_cache} # type: ignore
|
||||
mapped_image_urls = {**mapped_image_urls, **image_cache}
|
||||
|
||||
# Replace old image URLs with the generated URLs
|
||||
for img in images:
|
||||
|
||||
@ -13,7 +13,7 @@ from llm import (
|
||||
)
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from mock_llm import mock_completion
|
||||
from typing import Dict, List, cast, get_args
|
||||
from typing import Dict, List, Union, cast, get_args
|
||||
from image_generation import create_alt_url_mapping, generate_images
|
||||
from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||
from datetime import datetime
|
||||
@ -132,7 +132,7 @@ async def stream_code(websocket: WebSocket):
|
||||
print("Using Anthropic API key from environment variable")
|
||||
|
||||
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
|
||||
openai_base_url = None
|
||||
openai_base_url: Union[str, None] = None
|
||||
# Disable user-specified OpenAI Base URL in prod
|
||||
if not os.environ.get("IS_PROD"):
|
||||
if "openAiBaseURL" in params and params["openAiBaseURL"]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user