This commit is contained in:
Abi Raja 2024-07-31 10:14:23 -04:00
parent 0f731598dd
commit ff12790883
2 changed files with 15 additions and 13 deletions

View File

@ -23,6 +23,14 @@ class Llm(Enum):
CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620" CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620"
def is_openai_model(model: Llm) -> bool:
    """Return True when *model* is one of the OpenAI-hosted models."""
    openai_models = (
        Llm.GPT_4_VISION,
        Llm.GPT_4_TURBO_2024_04_09,
        Llm.GPT_4O_2024_05_13,
    )
    return model in openai_models
# Will throw errors if you send a garbage string # Will throw errors if you send a garbage string
def convert_frontend_str_to_llm(frontend_str: str) -> Llm: def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
if frontend_str == "gpt_4_vision": if frontend_str == "gpt_4_vision":

View File

@ -4,11 +4,12 @@ import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from codegen.utils import extract_html_content from codegen.utils import extract_html_content
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode from custom_types import InputMode
from llm import ( from llm import (
Llm, Llm,
convert_frontend_str_to_llm, convert_frontend_str_to_llm,
is_openai_model,
stream_claude_response, stream_claude_response,
stream_claude_response_native, stream_claude_response_native,
stream_openai_response, stream_openai_response,
@ -90,6 +91,7 @@ async def stream_code(websocket: WebSocket):
async def throw_error( async def throw_error(
message: str, message: str,
): ):
print(message)
await websocket.send_json({"type": "error", "value": message}) await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE) await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
@ -104,7 +106,6 @@ async def stream_code(websocket: WebSocket):
# TODO: Are the values always strings? # TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json() params: Dict[str, str] = await websocket.receive_json()
print("Received params") print("Received params")
# Read the code config settings (stack) from the request. # Read the code config settings (stack) from the request.
@ -135,28 +136,21 @@ async def stream_code(websocket: WebSocket):
# Auto-upgrade usage of older models # Auto-upgrade usage of older models
code_generation_model = auto_upgrade_model(code_generation_model) code_generation_model = auto_upgrade_model(code_generation_model)
print( print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
) )
# Get the OpenAI API key from the request. Fall back to environment variable if not provided. # Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error. # If neither is provided, we throw an error.
openai_api_key = None openai_api_key = params.get("openAiApiKey")
if params["openAiApiKey"]: if openai_api_key:
openai_api_key = params["openAiApiKey"]
print("Using OpenAI API key from client-side settings dialog") print("Using OpenAI API key from client-side settings dialog")
else: else:
openai_api_key = os.environ.get("OPENAI_API_KEY") openai_api_key = OPENAI_API_KEY
if openai_api_key: if openai_api_key:
print("Using OpenAI API key from environment variable") print("Using OpenAI API key from environment variable")
if not openai_api_key and ( if not openai_api_key and is_openai_model(code_generation_model):
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
or code_generation_model == Llm.GPT_4O_2024_05_13
):
print("OpenAI API key not found")
await throw_error( await throw_error(
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
) )