This commit is contained in:
Abi Raja 2024-07-31 10:14:23 -04:00
parent 0f731598dd
commit ff12790883
2 changed files with 15 additions and 13 deletions

View File

@ -23,6 +23,14 @@ class Llm(Enum):
CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620"
def is_openai_model(model: Llm) -> bool:
return model in {
Llm.GPT_4_VISION,
Llm.GPT_4_TURBO_2024_04_09,
Llm.GPT_4O_2024_05_13,
}
# Will throw errors if you send a garbage string
def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
if frontend_str == "gpt_4_vision":

View File

@ -4,11 +4,12 @@ import traceback
from fastapi import APIRouter, WebSocket
import openai
from codegen.utils import extract_html_content
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode
from llm import (
Llm,
convert_frontend_str_to_llm,
is_openai_model,
stream_claude_response,
stream_claude_response_native,
stream_openai_response,
@ -90,6 +91,7 @@ async def stream_code(websocket: WebSocket):
async def throw_error(
message: str,
):
print(message)
await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
@ -104,7 +106,6 @@ async def stream_code(websocket: WebSocket):
# TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json()
print("Received params")
# Read the code config settings (stack) from the request.
@ -135,28 +136,21 @@ async def stream_code(websocket: WebSocket):
# Auto-upgrade usage of older models
code_generation_model = auto_upgrade_model(code_generation_model)
print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
)
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error.
openai_api_key = None
if params["openAiApiKey"]:
openai_api_key = params["openAiApiKey"]
openai_api_key = params.get("openAiApiKey")
if openai_api_key:
print("Using OpenAI API key from client-side settings dialog")
else:
openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_api_key = OPENAI_API_KEY
if openai_api_key:
print("Using OpenAI API key from environment variable")
if not openai_api_key and (
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
or code_generation_model == Llm.GPT_4O_2024_05_13
):
print("OpenAI API key not found")
if not openai_api_key and is_openai_model(code_generation_model):
await throw_error(
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
)