diff --git a/backend/image_generation/core.py b/backend/image_generation/core.py index 7615127..cff890c 100644 --- a/backend/image_generation/core.py +++ b/backend/image_generation/core.py @@ -11,7 +11,7 @@ async def process_tasks( prompts: List[str], api_key: str, base_url: str | None, - model: Literal["dalle3", "sdxl-lightning"], + model: Literal["dalle3", "flux"], ): import time @@ -98,7 +98,7 @@ async def generate_images( api_key: str, base_url: Union[str, None], image_cache: Dict[str, str], - model: Literal["dalle3", "sdxl-lightning"] = "dalle3", + model: Literal["dalle3", "flux"] = "dalle3", ) -> str: # Find all images soup = BeautifulSoup(code, "html.parser") diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 866c7c0..71b764d 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -64,7 +64,7 @@ async def perform_image_generation( return completion if replicate_api_key: - image_generation_model = "sdxl-lightning" + image_generation_model = "flux" api_key = replicate_api_key else: if not openai_api_key: diff --git a/backend/run_image_generation_evals.py b/backend/run_image_generation_evals.py index 83a1bd8..49e1c43 100644 --- a/backend/run_image_generation_evals.py +++ b/backend/run_image_generation_evals.py @@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images" async def generate_and_save_images( prompts: List[str], - model: Literal["dalle3", "sdxl-lightning"], + model: Literal["dalle3", "flux"], api_key: Optional[str], ) -> None: # Ensure the output directory exists @@ -64,7 +64,7 @@ async def generate_and_save_images( image_data: bytes = await response.read() # Save the image with a filename based on the input eval - prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_" + prefix = "replicate_" if model == "flux" else "dalle3_" filename: str = ( f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png" ) @@ -78,7 +78,7 @@ async def generate_and_save_images( async def main() -> None: # await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY) - await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN) + await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN) if __name__ == "__main__":