Compare commits

...

2 Commits
hosted ... main

Author SHA1 Message Date
Abi Raja
a53afd350d update sdxl-lightning references to flux 2024-09-20 16:45:16 +02:00
Abi Raja
6899c7792e when all generations fail, print the all the underlying exceptions for debugging 2024-09-20 13:56:20 +02:00
3 changed files with 13 additions and 6 deletions

View File

@ -11,7 +11,7 @@ async def process_tasks(
prompts: List[str], prompts: List[str],
api_key: str, api_key: str,
base_url: str | None, base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"], model: Literal["dalle3", "flux"],
): ):
import time import time
@ -98,7 +98,7 @@ async def generate_images(
api_key: str, api_key: str,
base_url: Union[str, None], base_url: Union[str, None],
image_cache: Dict[str, str], image_cache: Dict[str, str],
model: Literal["dalle3", "sdxl-lightning"] = "dalle3", model: Literal["dalle3", "flux"] = "dalle3",
) -> str: ) -> str:
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from codegen.utils import extract_html_content from codegen.utils import extract_html_content
@ -63,7 +64,7 @@ async def perform_image_generation(
return completion return completion
if replicate_api_key: if replicate_api_key:
image_generation_model = "sdxl-lightning" image_generation_model = "flux"
api_key = replicate_api_key api_key = replicate_api_key
else: else:
if not openai_api_key: if not openai_api_key:
@ -321,6 +322,12 @@ async def stream_code(websocket: WebSocket):
) )
if all_generations_failed: if all_generations_failed:
await throw_error("Error generating code. Please contact support.") await throw_error("Error generating code. Please contact support.")
# Print the all the underlying exceptions for debugging
for completion in completions:
traceback.print_exception(
type(completion), completion, completion.__traceback__
)
raise Exception("All generations failed") raise Exception("All generations failed")
# If some completions failed, replace them with empty strings # If some completions failed, replace them with empty strings

View File

@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
async def generate_and_save_images( async def generate_and_save_images(
prompts: List[str], prompts: List[str],
model: Literal["dalle3", "sdxl-lightning"], model: Literal["dalle3", "flux"],
api_key: Optional[str], api_key: Optional[str],
) -> None: ) -> None:
# Ensure the output directory exists # Ensure the output directory exists
@ -64,7 +64,7 @@ async def generate_and_save_images(
image_data: bytes = await response.read() image_data: bytes = await response.read()
# Save the image with a filename based on the input eval # Save the image with a filename based on the input eval
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_" prefix = "replicate_" if model == "flux" else "dalle3_"
filename: str = ( filename: str = (
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png" f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
) )
@ -78,7 +78,7 @@ async def generate_and_save_images(
async def main() -> None: async def main() -> None:
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY) # await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN) await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
if __name__ == "__main__": if __name__ == "__main__":