diff --git a/backend/main.py b/backend/main.py
index 8852b35..bef3b94 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -11,6 +11,7 @@ from datetime import datetime
 from fastapi import FastAPI, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import HTMLResponse
+import openai
 from llm import stream_openai_response
 from mock import mock_completion
 from image_generation import create_alt_url_mapping, generate_images
@@ -50,6 +51,10 @@ app.add_middleware(
 # TODO: Should only be set to true when value is 'True', not any abitrary truthy value
 SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False))
 
+# Set to True when running in production (on the hosted version)
+# Used as a feature flag to enable or disable certain features
+IS_PROD = os.environ.get("IS_PROD", "False") == "True"
+
 app.include_router(screenshot.router)
 
 
@@ -86,6 +91,12 @@ async def stream_code(websocket: WebSocket):
 
     print("Incoming websocket connection...")
 
+    async def throw_error(
+        message: str,
+    ):
+        await websocket.send_json({"type": "error", "value": message})
+        await websocket.close()
+
     params = await websocket.receive_json()
     print("Received params")
 
@@ -192,12 +203,47 @@
     if SHOULD_MOCK_AI_RESPONSE:
         completion = await mock_completion(process_chunk)
     else:
-        completion = await stream_openai_response(
-            prompt_messages,
-            api_key=openai_api_key,
-            base_url=openai_base_url,
-            callback=lambda x: process_chunk(x),
-        )
+        try:
+            completion = await stream_openai_response(
+                prompt_messages,
+                api_key=openai_api_key,
+                base_url=openai_base_url,
+                callback=lambda x: process_chunk(x),
+            )
+        except openai.AuthenticationError as e:
+            print("[GENERATE_CODE] Authentication failed", e)
+            error_message = (
+                "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard."
+                + (
+                    " Alternatively, you can purchase code generation credits directly on this website."
+                    if IS_PROD
+                    else ""
+                )
+            )
+            return await throw_error(error_message)
+        except openai.NotFoundError as e:
+            print("[GENERATE_CODE] Model not found", e)
+            error_message = (
+                e.message
+                + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md"
+                + (
+                    " Alternatively, you can purchase code generation credits directly on this website."
+                    if IS_PROD
+                    else ""
+                )
+            )
+            return await throw_error(error_message)
+        except openai.RateLimitError as e:
+            print("[GENERATE_CODE] Rate limit exceeded", e)
+            error_message = (
+                "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'"
+                + (
+                    " Alternatively, you can purchase code generation credits directly on this website."
+                    if IS_PROD
+                    else ""
+                )
+            )
+            return await throw_error(error_message)
 
     # Write the messages dict into a log so that we can debug later
     write_logs(prompt_messages, completion)