poll more frequently and print timing logs
This commit is contained in:
parent
52099e0853
commit
4b0adc5769
@ -13,11 +13,17 @@ async def process_tasks(
|
|||||||
base_url: str | None,
|
base_url: str | None,
|
||||||
model: Literal["dalle3", "sdxl-lightning"],
|
model: Literal["dalle3", "sdxl-lightning"],
|
||||||
):
|
):
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
if model == "dalle3":
|
if model == "dalle3":
|
||||||
tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
|
tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
|
||||||
else:
|
else:
|
||||||
tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]
|
tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
end_time = time.time()
|
||||||
|
generation_time = end_time - start_time
|
||||||
|
print(f"Image generation time: {generation_time:.2f} seconds")
|
||||||
|
|
||||||
processed_results: List[Union[str, None]] = []
|
processed_results: List[Union[str, None]] = []
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|||||||
@ -26,11 +26,11 @@ async def call_replicate(
|
|||||||
|
|
||||||
# Polling every 1 second until the status is succeeded or error
|
# Polling every 1 second until the status is succeeded or error
|
||||||
num_polls = 0
|
num_polls = 0
|
||||||
max_polls = 120
|
max_polls = 100
|
||||||
while num_polls < max_polls:
|
while num_polls < max_polls:
|
||||||
num_polls += 1
|
num_polls += 1
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
# Check the status
|
# Check the status
|
||||||
status_check_url = f"{url}/{prediction_id}"
|
status_check_url = f"{url}/{prediction_id}"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user