switch to flux schnell for replicate image gen
This commit is contained in:
parent
a4087e613f
commit
8717298def
@ -54,18 +54,14 @@ async def generate_image_dalle(
|
|||||||
|
|
||||||
async def generate_image_replicate(prompt: str, api_key: str) -> str:
|
async def generate_image_replicate(prompt: str, api_key: str) -> str:
|
||||||
|
|
||||||
# We use SDXL Lightning
|
# We use Flux Schnell
|
||||||
return await call_replicate(
|
return await call_replicate(
|
||||||
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
|
|
||||||
{
|
{
|
||||||
"width": 1024,
|
|
||||||
"height": 1024,
|
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"scheduler": "K_EULER",
|
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"guidance_scale": 0,
|
"aspect_ratio": "1:1",
|
||||||
"negative_prompt": "worst quality, low quality",
|
"output_format": "png",
|
||||||
"num_inference_steps": 4,
|
"output_quality": 100,
|
||||||
},
|
},
|
||||||
api_key,
|
api_key,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,20 +2,21 @@ import asyncio
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
async def call_replicate(
|
async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
|
||||||
replicate_model_version: str, input: dict[str, str | int], api_token: str
|
|
||||||
) -> str:
|
|
||||||
url = "https://api.replicate.com/v1/predictions"
|
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {api_token}",
|
"Authorization": f"Bearer {api_token}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
data = {"version": replicate_model_version, "input": input}
|
data = {"input": input}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
try:
|
try:
|
||||||
response = await client.post(url, headers=headers, json=data)
|
response = await client.post(
|
||||||
|
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
|
||||||
|
headers=headers,
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
@ -24,16 +25,18 @@ async def call_replicate(
|
|||||||
if not prediction_id:
|
if not prediction_id:
|
||||||
raise ValueError("Prediction ID not found in initial response.")
|
raise ValueError("Prediction ID not found in initial response.")
|
||||||
|
|
||||||
# Polling every 1 second until the status is succeeded or error
|
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
|
||||||
num_polls = 0
|
num_polls = 0
|
||||||
max_polls = 100
|
max_polls = 100
|
||||||
while num_polls < max_polls:
|
while num_polls < max_polls:
|
||||||
num_polls += 1
|
num_polls += 1
|
||||||
|
|
||||||
await asyncio.sleep(0.2)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
# Check the status
|
# Check the status
|
||||||
status_check_url = f"{url}/{prediction_id}"
|
status_check_url = (
|
||||||
|
f"https://api.replicate.com/v1/predictions/{prediction_id}"
|
||||||
|
)
|
||||||
status_response = await client.get(status_check_url, headers=headers)
|
status_response = await client.get(status_check_url, headers=headers)
|
||||||
status_response.raise_for_status()
|
status_response.raise_for_status()
|
||||||
status_response_json = status_response.json()
|
status_response_json = status_response.json()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user