switch to flux schnell for replicate image gen
This commit is contained in:
parent
a4087e613f
commit
8717298def
@ -54,18 +54,14 @@ async def generate_image_dalle(
|
||||
|
||||
async def generate_image_replicate(prompt: str, api_key: str) -> str:
|
||||
|
||||
# We use SDXL Lightning
|
||||
# We use Flux Schnell
|
||||
return await call_replicate(
|
||||
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
|
||||
{
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"prompt": prompt,
|
||||
"scheduler": "K_EULER",
|
||||
"num_outputs": 1,
|
||||
"guidance_scale": 0,
|
||||
"negative_prompt": "worst quality, low quality",
|
||||
"num_inference_steps": 4,
|
||||
"aspect_ratio": "1:1",
|
||||
"output_format": "png",
|
||||
"output_quality": 100,
|
||||
},
|
||||
api_key,
|
||||
)
|
||||
|
||||
@ -2,20 +2,21 @@ import asyncio
|
||||
import httpx
|
||||
|
||||
|
||||
async def call_replicate(
|
||||
replicate_model_version: str, input: dict[str, str | int], api_token: str
|
||||
) -> str:
|
||||
url = "https://api.replicate.com/v1/predictions"
|
||||
async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {"version": replicate_model_version, "input": input}
|
||||
data = {"input": input}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response = await client.post(
|
||||
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
|
||||
headers=headers,
|
||||
json=data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
@ -24,16 +25,18 @@ async def call_replicate(
|
||||
if not prediction_id:
|
||||
raise ValueError("Prediction ID not found in initial response.")
|
||||
|
||||
# Polling every 1 second until the status is succeeded or error
|
||||
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
|
||||
num_polls = 0
|
||||
max_polls = 100
|
||||
while num_polls < max_polls:
|
||||
num_polls += 1
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check the status
|
||||
status_check_url = f"{url}/{prediction_id}"
|
||||
status_check_url = (
|
||||
f"https://api.replicate.com/v1/predictions/{prediction_id}"
|
||||
)
|
||||
status_response = await client.get(status_check_url, headers=headers)
|
||||
status_response.raise_for_status()
|
||||
status_response_json = status_response.json()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user