poll more frequently and print timing logs

This commit is contained in:
Abi Raja 2024-07-29 16:33:04 -04:00
parent 52099e0853
commit 4b0adc5769
2 changed files with 8 additions and 2 deletions

View File

@ -13,11 +13,17 @@ async def process_tasks(
base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"],
):
import time
start_time = time.time()
if model == "dalle3":
tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
else:
tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True)
end_time = time.time()
generation_time = end_time - start_time
print(f"Image generation time: {generation_time:.2f} seconds")
processed_results: List[Union[str, None]] = []
for result in results:

View File

@ -26,11 +26,11 @@ async def call_replicate(
# Polling every 1 second until the status is succeeded or error
num_polls = 0
max_polls = 120
max_polls = 100
while num_polls < max_polls:
num_polls += 1
await asyncio.sleep(1)
await asyncio.sleep(0.2)
# Check the status
status_check_url = f"{url}/{prediction_id}"