diff --git a/backend/image_generation/core.py b/backend/image_generation/core.py index 3fcc2dc..dfd3375 100644 --- a/backend/image_generation/core.py +++ b/backend/image_generation/core.py @@ -13,11 +13,17 @@ async def process_tasks( base_url: str | None, model: Literal["dalle3", "sdxl-lightning"], ): + import time + + start_time = time.time() if model == "dalle3": tasks = [generate_image_dalle(prompt, api_key, base_url) for prompt in prompts] else: tasks = [generate_image_replicate(prompt, api_key) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) + end_time = time.time() + generation_time = end_time - start_time + print(f"Image generation time: {generation_time:.2f} seconds") processed_results: List[Union[str, None]] = [] for result in results: diff --git a/backend/image_generation/replicate.py b/backend/image_generation/replicate.py index ce07306..e8b8748 100644 --- a/backend/image_generation/replicate.py +++ b/backend/image_generation/replicate.py @@ -26,11 +26,11 @@ async def call_replicate( # Polling every 1 second until the status is succeeded or error num_polls = 0 - max_polls = 120 + max_polls = 100 while num_polls < max_polls: num_polls += 1 - await asyncio.sleep(1) + await asyncio.sleep(0.2) # Check the status status_check_url = f"{url}/{prediction_id}"