poll more frequently and print timing logs
This commit is contained in:
parent
52099e0853
commit
4b0adc5769
@ -13,11 +13,17 @@ async def process_tasks(
|
||||
base_url: str | None,
|
||||
model: Literal["dalle3", "sdxl-lightning"],
|
||||
):
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
if model == "dalle3":
|
||||
tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
|
||||
else:
|
||||
tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
end_time = time.time()
|
||||
generation_time = end_time - start_time
|
||||
print(f"Image generation time: {generation_time:.2f} seconds")
|
||||
|
||||
processed_results: List[Union[str, None]] = []
|
||||
for result in results:
|
||||
|
||||
@ -26,11 +26,11 @@ async def call_replicate(
|
||||
|
||||
# Polling every 1 second until the status is succeeded or error
|
||||
num_polls = 0
|
||||
max_polls = 120
|
||||
max_polls = 100
|
||||
while num_polls < max_polls:
|
||||
num_polls += 1
|
||||
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Check the status
|
||||
status_check_url = f"{url}/{prediction_id}"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user