switch to flux schnell for replicate image gen

This commit is contained in:
Abi Raja 2024-09-17 12:20:19 +02:00
parent a4087e613f
commit 8717298def
2 changed files with 16 additions and 17 deletions

View File

@ -54,18 +54,14 @@ async def generate_image_dalle(
async def generate_image_replicate(prompt: str, api_key: str) -> str: async def generate_image_replicate(prompt: str, api_key: str) -> str:
# We use SDXL Lightning # We use Flux Schnell
return await call_replicate( return await call_replicate(
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
{ {
"width": 1024,
"height": 1024,
"prompt": prompt, "prompt": prompt,
"scheduler": "K_EULER",
"num_outputs": 1, "num_outputs": 1,
"guidance_scale": 0, "aspect_ratio": "1:1",
"negative_prompt": "worst quality, low quality", "output_format": "png",
"num_inference_steps": 4, "output_quality": 100,
}, },
api_key, api_key,
) )

View File

@ -2,20 +2,21 @@ import asyncio
import httpx import httpx
async def call_replicate( async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
replicate_model_version: str, input: dict[str, str | int], api_token: str
) -> str:
url = "https://api.replicate.com/v1/predictions"
headers = { headers = {
"Authorization": f"Bearer {api_token}", "Authorization": f"Bearer {api_token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
data = {"version": replicate_model_version, "input": input} data = {"input": input}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
response = await client.post(url, headers=headers, json=data) response = await client.post(
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
headers=headers,
json=data,
)
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = response.json()
@ -24,16 +25,18 @@ async def call_replicate(
if not prediction_id: if not prediction_id:
raise ValueError("Prediction ID not found in initial response.") raise ValueError("Prediction ID not found in initial response.")
# Polling every 1 second until the status is succeeded or error # Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
num_polls = 0 num_polls = 0
max_polls = 100 max_polls = 100
while num_polls < max_polls: while num_polls < max_polls:
num_polls += 1 num_polls += 1
await asyncio.sleep(0.2) await asyncio.sleep(0.1)
# Check the status # Check the status
status_check_url = f"{url}/{prediction_id}" status_check_url = (
f"https://api.replicate.com/v1/predictions/{prediction_id}"
)
status_response = await client.get(status_check_url, headers=headers) status_response = await client.get(status_check_url, headers=headers)
status_response.raise_for_status() status_response.raise_for_status()
status_response_json = status_response.json() status_response_json = status_response.json()