poll more frequently and print timing logs

This commit is contained in:
Abi Raja 2024-07-29 16:33:04 -04:00
parent 52099e0853
commit 4b0adc5769
2 changed files with 8 additions and 2 deletions

View File

@ -13,11 +13,17 @@ async def process_tasks(
base_url: str | None, base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"], model: Literal["dalle3", "sdxl-lightning"],
): ):
import time
start_time = time.time()
if model == "dalle3": if model == "dalle3":
tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts] tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts]
else: else:
tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts] tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
end_time = time.time()
generation_time = end_time - start_time
print(f"Image generation time: {generation_time:.2f} seconds")
processed_results: List[Union[str, None]] = [] processed_results: List[Union[str, None]] = []
for result in results: for result in results:

View File

@ -26,11 +26,11 @@ async def call_replicate(
# Polling every 1 second until the status is succeeded or error # Polling every 1 second until the status is succeeded or error
num_polls = 0 num_polls = 0
max_polls = 120 max_polls = 100
while num_polls < max_polls: while num_polls < max_polls:
num_polls += 1 num_polls += 1
await asyncio.sleep(1) await asyncio.sleep(0.2)
# Check the status # Check the status
status_check_url = f"{url}/{prediction_id}" status_check_url = f"{url}/{prediction_id}"