diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 489f6ad..d01c309 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -379,7 +379,23 @@ async def stream_code(websocket: WebSocket): ) ) - completions = await asyncio.gather(*tasks) + # Run the models in parallel and capture exceptions if any + completions = await asyncio.gather(*tasks, return_exceptions=True) + + # If all generations failed, throw an error + all_generations_failed = all( + isinstance(completion, Exception) for completion in completions + ) + if all_generations_failed: + await throw_error("Error generating code. Please contact support.") + raise Exception("All generations failed") + + # If some completions failed, replace them with empty strings + for index, completion in enumerate(completions): + if isinstance(completion, Exception): + completions[index] = "" + print("Generation failed for variant", index) + print("Models used for generation: ", variant_models) except openai.AuthenticationError as e: