diff --git a/backend/image_generation.py b/backend/image_generation.py index b93792c..e3f609f 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -5,7 +5,7 @@ from openai import AsyncOpenAI from bs4 import BeautifulSoup -async def process_tasks(prompts: List[str], api_key: str, base_url: str): +async def process_tasks(prompts: List[str], api_key: str, base_url: str | None): tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) @@ -15,22 +15,23 @@ async def process_tasks(prompts: List[str], api_key: str, base_url: str): print(f"An exception occurred: {result}") processed_results.append(None) else: - processed_results.append(result) # type: ignore + processed_results.append(result) return processed_results -async def generate_image(prompt: str, api_key: str, base_url: str): +async def generate_image( + prompt: str, api_key: str, base_url: str | None +) -> Union[str, None]: client = AsyncOpenAI(api_key=api_key, base_url=base_url) - image_params: Dict[str, Union[str, int]] = { - "model": "dall-e-3", - "quality": "standard", - "style": "natural", - "n": 1, - "size": "1024x1024", - "prompt": prompt, - } - res = await client.images.generate(**image_params) # type: ignore + res = await client.images.generate( + model="dall-e-3", + quality="standard", + style="natural", + n=1, + size="1024x1024", + prompt=prompt, + ) await client.close() return res.data[0].url @@ -63,13 +64,13 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]: async def generate_images( code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str] -): +) -> str: # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") # Extract alt texts as image prompts - alts = [] + alts: List[str | None] = [] for img in images: # Only include URL if the image starts with https://placehold.co # and it's not already in the image_cache @@ -77,26 +78,26 @@ async def generate_images( img["src"].startswith("https://placehold.co") and image_cache.get(img.get("alt")) is None ): - alts.append(img.get("alt", None)) # type: ignore + alts.append(img.get("alt", None)) # Exclude images with no alt text - alts = [alt for alt in alts if alt is not None] # type: ignore + filtered_alts: List[str] = [alt for alt in alts if alt is not None] # Remove duplicates - prompts = list(set(alts)) # type: ignore + prompts = list(set(filtered_alts)) # Return early if there are no images to replace - if len(prompts) == 0: # type: ignore + if len(prompts) == 0: return code # Generate images - results = await process_tasks(prompts, api_key, base_url) # type: ignore + results = await process_tasks(prompts, api_key, base_url) # Create a dict mapping alt text to image URL - mapped_image_urls = dict(zip(prompts, results)) # type: ignore + mapped_image_urls = dict(zip(prompts, results)) # Merge with image_cache - mapped_image_urls = {**mapped_image_urls, **image_cache} # type: ignore + mapped_image_urls = {**mapped_image_urls, **image_cache} # Replace old image URLs with the generated URLs for img in images: diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 3fb92f1..379042e 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -13,7 +13,7 @@ from llm import ( ) from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion -from typing import Dict, List, cast, get_args +from typing import Dict, List, Union, cast, get_args from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_imported_code_prompt, assemble_prompt from datetime import datetime @@ -132,7 +132,7 @@ async def stream_code(websocket: WebSocket): print("Using Anthropic API key from environment variable") # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. - openai_base_url = None + openai_base_url: Union[str, None] = None # Disable user-specified OpenAI Base URL in prod if not os.environ.get("IS_PROD"): if "openAiBaseURL" in params and params["openAiBaseURL"]: