From 26447ce15d867012f10088d3cab9d8fbb366028c Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Mon, 4 Dec 2023 16:32:17 -0500 Subject: [PATCH] handle openai.AuthenticationError --- backend/main.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/backend/main.py b/backend/main.py index e7cc082..72b6fd3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,6 +1,5 @@ # Load environment variables first from dotenv import load_dotenv -from fastapi.responses import HTMLResponse load_dotenv() @@ -11,6 +10,8 @@ import traceback from datetime import datetime from fastapi import FastAPI, WebSocket from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse +import openai from llm import stream_openai_response from mock import mock_completion from image_generation import create_alt_url_mapping, generate_images @@ -35,6 +36,10 @@ app.add_middleware( # TODO: Should only be set to true when value is 'True', not any abitrary truthy value SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False)) +# Set to True when running in production (on the hosted version) +# Used as a feature flag to enable or disable certain features +IS_PROD = os.environ.get("IS_PROD", False) + app.include_router(screenshot.router) @@ -71,6 +76,12 @@ async def stream_code(websocket: WebSocket): print("Incoming websocket connection...") + async def throw_error( + message: str, + ): + await websocket.send_json({"type": "error", "value": message}) + await websocket.close() + params = await websocket.receive_json() print("Received params") @@ -177,12 +188,24 @@ async def stream_code(websocket: WebSocket): if SHOULD_MOCK_AI_RESPONSE: completion = await mock_completion(process_chunk) else: - completion = await stream_openai_response( - prompt_messages, - api_key=openai_api_key, - base_url=openai_base_url, - callback=lambda x: process_chunk(x), - ) + try: + completion = await stream_openai_response( + prompt_messages, + api_key=openai_api_key, + base_url=openai_base_url, + callback=lambda x: process_chunk(x), + ) + except openai.AuthenticationError as e: + print("[GENERATE_CODE] Authentication failed", e) + error_message = ( + "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard." + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) + ) + return await throw_error(error_message) # Write the messages dict into a log so that we can debug later write_logs(prompt_messages, completion)