Merge pull request #348 from naman1608/type-fixes

full typed in image generation file
This commit is contained in:
Abi Raja 2024-06-05 15:21:41 -04:00 committed by GitHub
commit 561ac0b088
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 23 deletions

View File

@ -5,7 +5,7 @@ from openai import AsyncOpenAI
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
async def process_tasks(prompts: List[str], api_key: str, base_url: str): async def process_tasks(prompts: List[str], api_key: str, base_url: str | None):
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
@ -15,22 +15,23 @@ async def process_tasks(prompts: List[str], api_key: str, base_url: str):
print(f"An exception occurred: {result}") print(f"An exception occurred: {result}")
processed_results.append(None) processed_results.append(None)
else: else:
processed_results.append(result) # type: ignore processed_results.append(result)
return processed_results return processed_results
async def generate_image(prompt: str, api_key: str, base_url: str): async def generate_image(
prompt: str, api_key: str, base_url: str | None
) -> Union[str, None]:
client = AsyncOpenAI(api_key=api_key, base_url=base_url) client = AsyncOpenAI(api_key=api_key, base_url=base_url)
image_params: Dict[str, Union[str, int]] = { res = await client.images.generate(
"model": "dall-e-3", model="dall-e-3",
"quality": "standard", quality="standard",
"style": "natural", style="natural",
"n": 1, n=1,
"size": "1024x1024", size="1024x1024",
"prompt": prompt, prompt=prompt,
} )
res = await client.images.generate(**image_params) # type: ignore
await client.close() await client.close()
return res.data[0].url return res.data[0].url
@ -63,13 +64,13 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
async def generate_images( async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str] code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
): ) -> str:
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img") images = soup.find_all("img")
# Extract alt texts as image prompts # Extract alt texts as image prompts
alts = [] alts: List[str | None] = []
for img in images: for img in images:
# Only include URL if the image starts with https://placehold.co # Only include URL if the image starts with https://placehold.co
# and it's not already in the image_cache # and it's not already in the image_cache
@ -77,26 +78,26 @@ async def generate_images(
img["src"].startswith("https://placehold.co") img["src"].startswith("https://placehold.co")
and image_cache.get(img.get("alt")) is None and image_cache.get(img.get("alt")) is None
): ):
alts.append(img.get("alt", None)) # type: ignore alts.append(img.get("alt", None))
# Exclude images with no alt text # Exclude images with no alt text
alts = [alt for alt in alts if alt is not None] # type: ignore filtered_alts: List[str] = [alt for alt in alts if alt is not None]
# Remove duplicates # Remove duplicates
prompts = list(set(alts)) # type: ignore prompts = list(set(filtered_alts))
# Return early if there are no images to replace # Return early if there are no images to replace
if len(prompts) == 0: # type: ignore if len(prompts) == 0:
return code return code
# Generate images # Generate images
results = await process_tasks(prompts, api_key, base_url) # type: ignore results = await process_tasks(prompts, api_key, base_url)
# Create a dict mapping alt text to image URL # Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results)) # type: ignore mapped_image_urls = dict(zip(prompts, results))
# Merge with image_cache # Merge with image_cache
mapped_image_urls = {**mapped_image_urls, **image_cache} # type: ignore mapped_image_urls = {**mapped_image_urls, **image_cache}
# Replace old image URLs with the generated URLs # Replace old image URLs with the generated URLs
for img in images: for img in images:

View File

@ -13,7 +13,7 @@ from llm import (
) )
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List, cast, get_args from typing import Dict, List, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime from datetime import datetime
@ -132,7 +132,7 @@ async def stream_code(websocket: WebSocket):
print("Using Anthropic API key from environment variable") print("Using Anthropic API key from environment variable")
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None openai_base_url: Union[str, None] = None
# Disable user-specified OpenAI Base URL in prod # Disable user-specified OpenAI Base URL in prod
if not os.environ.get("IS_PROD"): if not os.environ.get("IS_PROD"):
if "openAiBaseURL" in params and params["openAiBaseURL"]: if "openAiBaseURL" in params and params["openAiBaseURL"]: