From 8717298defff267ad60a63f8581692bdf1d2f1e3 Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Tue, 17 Sep 2024 12:20:19 +0200 Subject: [PATCH] switch to flux schnell for replicate image gen --- backend/image_generation/core.py | 12 ++++-------- backend/image_generation/replicate.py | 21 ++++++++++++--------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/backend/image_generation/core.py b/backend/image_generation/core.py index 91a5535..7615127 100644 --- a/backend/image_generation/core.py +++ b/backend/image_generation/core.py @@ -54,18 +54,14 @@ async def generate_image_dalle( async def generate_image_replicate(prompt: str, api_key: str) -> str: - # We use SDXL Lightning + # We use Flux Schnell return await call_replicate( - "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f", { - "width": 1024, - "height": 1024, "prompt": prompt, - "scheduler": "K_EULER", "num_outputs": 1, - "guidance_scale": 0, - "negative_prompt": "worst quality, low quality", - "num_inference_steps": 4, + "aspect_ratio": "1:1", + "output_format": "png", + "output_quality": 100, }, api_key, ) diff --git a/backend/image_generation/replicate.py b/backend/image_generation/replicate.py index e8b8748..86dd6f1 100644 --- a/backend/image_generation/replicate.py +++ b/backend/image_generation/replicate.py @@ -2,20 +2,21 @@ import asyncio import httpx -async def call_replicate( - replicate_model_version: str, input: dict[str, str | int], api_token: str -) -> str: - url = "https://api.replicate.com/v1/predictions" +async def call_replicate(input: dict[str, str | int], api_token: str) -> str: headers = { "Authorization": f"Bearer {api_token}", "Content-Type": "application/json", } - data = {"version": replicate_model_version, "input": input} + data = {"input": input} async with httpx.AsyncClient() as client: try: - response = await client.post(url, headers=headers, json=data) + response = await client.post( +
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions", + headers=headers, + json=data, + ) response.raise_for_status() response_json = response.json() @@ -24,16 +25,18 @@ async def call_replicate( if not prediction_id: raise ValueError("Prediction ID not found in initial response.") - # Polling every 1 second until the status is succeeded or error + # Polling every 0.1 seconds until the status is succeeded or error (up to 10s) num_polls = 0 max_polls = 100 while num_polls < max_polls: num_polls += 1 - await asyncio.sleep(0.2) + await asyncio.sleep(0.1) # Check the status - status_check_url = f"{url}/{prediction_id}" + status_check_url = ( + f"https://api.replicate.com/v1/predictions/{prediction_id}" + ) status_response = await client.get(status_check_url, headers=headers) status_response.raise_for_status() status_response_json = status_response.json()