types fix in image generation file

This commit is contained in:
Naman 2024-05-29 07:11:03 +05:30
parent 23e631765e
commit d31ebcaa27
2 changed files with 16 additions and 14 deletions

View File

@ -11,16 +11,16 @@ async def process_tasks(prompts: List[str], api_key: str, base_url: str):
processed_results: List[Union[str, None]] = []
for result in results:
if isinstance(result, Exception):
if isinstance(result, BaseException):
print(f"An exception occurred: {result}")
processed_results.append(None)
else:
processed_results.append(result) # type: ignore
processed_results.append(result)
return processed_results
async def generate_image(prompt: str, api_key: str, base_url: str):
async def generate_image(prompt: str, api_key: str, base_url: str) -> Union[str, None]:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
image_params: Dict[str, Union[str, int]] = {
"model": "dall-e-3",
@ -63,13 +63,15 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
):
) -> Union[str, None]:
if base_url is None:
return code
# Find all images
soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img")
# Extract alt texts as image prompts
alts = []
alts: List[str | None] = []
for img in images:
# Only include URL if the image starts with https://placehold.co
# and it's not already in the image_cache
@ -77,26 +79,26 @@ async def generate_images(
img["src"].startswith("https://placehold.co")
and image_cache.get(img.get("alt")) is None
):
alts.append(img.get("alt", None)) # type: ignore
alts.append(img.get("alt", None))
# Exclude images with no alt text
alts = [alt for alt in alts if alt is not None] # type: ignore
filtered_alts: List[str] = [alt for alt in alts if alt is not None]
# Remove duplicates
prompts = list(set(alts)) # type: ignore
prompts = list(set(filtered_alts))
# Return early if there are no images to replace
if len(prompts) == 0: # type: ignore
if len(prompts) == 0:
return code
# Generate images
results = await process_tasks(prompts, api_key, base_url) # type: ignore
results = await process_tasks(prompts, api_key, base_url)
# Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results)) # type: ignore
mapped_image_urls = dict(zip(prompts, results))
# Merge with image_cache
mapped_image_urls = {**mapped_image_urls, **image_cache} # type: ignore
mapped_image_urls = {**mapped_image_urls, **image_cache}
# Replace old image URLs with the generated URLs
for img in images:

View File

@ -13,7 +13,7 @@ from llm import (
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from typing import Dict, List, cast, get_args
from typing import Dict, List, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from datetime import datetime
@ -121,7 +121,7 @@ async def stream_code(websocket: WebSocket):
return
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None
openai_base_url: Union[str, None] = None
# Disable user-specified OpenAI Base URL in prod
if not os.environ.get("IS_PROD"):
if "openAiBaseURL" in params and params["openAiBaseURL"]: