switch to flux schnell for replicate image gen

This commit is contained in:
Abi Raja 2024-09-17 12:20:19 +02:00
parent a4087e613f
commit 8717298def
2 changed files with 16 additions and 17 deletions

View File

@ -54,18 +54,14 @@ async def generate_image_dalle(
async def generate_image_replicate(prompt: str, api_key: str) -> str:
# We use SDXL Lightning
# We use Flux Schnell
return await call_replicate(
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
{
"width": 1024,
"height": 1024,
"prompt": prompt,
"scheduler": "K_EULER",
"num_outputs": 1,
"guidance_scale": 0,
"negative_prompt": "worst quality, low quality",
"num_inference_steps": 4,
"aspect_ratio": "1:1",
"output_format": "png",
"output_quality": 100,
},
api_key,
)

View File

@ -2,20 +2,21 @@ import asyncio
import httpx
async def call_replicate(
replicate_model_version: str, input: dict[str, str | int], api_token: str
) -> str:
url = "https://api.replicate.com/v1/predictions"
async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
}
data = {"version": replicate_model_version, "input": input}
data = {"input": input}
async with httpx.AsyncClient() as client:
try:
response = await client.post(url, headers=headers, json=data)
response = await client.post(
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
headers=headers,
json=data,
)
response.raise_for_status()
response_json = response.json()
@ -24,16 +25,18 @@ async def call_replicate(
if not prediction_id:
raise ValueError("Prediction ID not found in initial response.")
# Polling every 1 second until the status is succeeded or error
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
num_polls = 0
max_polls = 100
while num_polls < max_polls:
num_polls += 1
await asyncio.sleep(0.2)
await asyncio.sleep(0.1)
# Check the status
status_check_url = f"{url}/{prediction_id}"
status_check_url = (
f"https://api.replicate.com/v1/predictions/{prediction_id}"
)
status_response = await client.get(status_check_url, headers=headers)
status_response.raise_for_status()
status_response_json = status_response.json()