Compare commits

...

2 Commits
hosted ... main

Author SHA1 Message Date
Abi Raja
a53afd350d update sdxl-lightning references to flux 2024-09-20 16:45:16 +02:00
Abi Raja
6899c7792e when all generations fail, print the all the underlying exceptions for debugging 2024-09-20 13:56:20 +02:00
3 changed files with 13 additions and 6 deletions

View File

@ -11,7 +11,7 @@ async def process_tasks(
prompts: List[str],
api_key: str,
base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"],
model: Literal["dalle3", "flux"],
):
import time
@ -98,7 +98,7 @@ async def generate_images(
api_key: str,
base_url: Union[str, None],
image_cache: Dict[str, str],
model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
model: Literal["dalle3", "flux"] = "dalle3",
) -> str:
# Find all images
soup = BeautifulSoup(code, "html.parser")

View File

@ -1,5 +1,6 @@
import asyncio
from dataclasses import dataclass
import traceback
from fastapi import APIRouter, WebSocket
import openai
from codegen.utils import extract_html_content
@ -63,7 +64,7 @@ async def perform_image_generation(
return completion
if replicate_api_key:
image_generation_model = "sdxl-lightning"
image_generation_model = "flux"
api_key = replicate_api_key
else:
if not openai_api_key:
@ -321,6 +322,12 @@ async def stream_code(websocket: WebSocket):
)
if all_generations_failed:
await throw_error("Error generating code. Please contact support.")
# Print the all the underlying exceptions for debugging
for completion in completions:
traceback.print_exception(
type(completion), completion, completion.__traceback__
)
raise Exception("All generations failed")
# If some completions failed, replace them with empty strings

View File

@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
async def generate_and_save_images(
prompts: List[str],
model: Literal["dalle3", "sdxl-lightning"],
model: Literal["dalle3", "flux"],
api_key: Optional[str],
) -> None:
# Ensure the output directory exists
@ -64,7 +64,7 @@ async def generate_and_save_images(
image_data: bytes = await response.read()
# Save the image with a filename based on the input eval
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
prefix = "replicate_" if model == "flux" else "dalle3_"
filename: str = (
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
)
@ -78,7 +78,7 @@ async def generate_and_save_images(
async def main() -> None:
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
if __name__ == "__main__":