diff --git a/backend/llm.py b/backend/llm.py index 450ec2f..d828b6d 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -23,6 +23,14 @@ class Llm(Enum): CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620" +def is_openai_model(model: Llm) -> bool: + return model in { + Llm.GPT_4_VISION, + Llm.GPT_4_TURBO_2024_04_09, + Llm.GPT_4O_2024_05_13, + } + + # Will throw errors if you send a garbage string def convert_frontend_str_to_llm(frontend_str: str) -> Llm: if frontend_str == "gpt_4_vision": diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 055853f..04f5fa8 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -4,11 +4,12 @@ import traceback from fastapi import APIRouter, WebSocket import openai from codegen.utils import extract_html_content -from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE +from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE from custom_types import InputMode from llm import ( Llm, convert_frontend_str_to_llm, + is_openai_model, stream_claude_response, stream_claude_response_native, stream_openai_response, @@ -90,6 +91,7 @@ async def stream_code(websocket: WebSocket): async def throw_error( message: str, ): + print(message) await websocket.send_json({"type": "error", "value": message}) await websocket.close(APP_ERROR_WEB_SOCKET_CODE) @@ -104,7 +106,6 @@ async def stream_code(websocket: WebSocket): # TODO: Are the values always strings? params: Dict[str, str] = await websocket.receive_json() - print("Received params") # Read the code config settings (stack) from the request. @@ -135,28 +136,21 @@ async def stream_code(websocket: WebSocket): # Auto-upgrade usage of older models code_generation_model = auto_upgrade_model(code_generation_model) - print( f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." ) # Get the OpenAI API key from the request. Fall back to environment variable if not provided. # If neither is provided, we throw an error. - openai_api_key = None - if params["openAiApiKey"]: - openai_api_key = params["openAiApiKey"] + openai_api_key = params.get("openAiApiKey") + if openai_api_key: print("Using OpenAI API key from client-side settings dialog") else: - openai_api_key = os.environ.get("OPENAI_API_KEY") + openai_api_key = OPENAI_API_KEY if openai_api_key: print("Using OpenAI API key from environment variable") - if not openai_api_key and ( - code_generation_model == Llm.GPT_4_VISION - or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 - or code_generation_model == Llm.GPT_4O_2024_05_13 - ): - print("OpenAI API key not found") + if not openai_api_key and is_openai_model(code_generation_model): await throw_error( "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." )