refactor
This commit is contained in:
parent
0f731598dd
commit
ff12790883
@ -23,6 +23,14 @@ class Llm(Enum):
|
||||
CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620"
|
||||
|
||||
|
||||
def is_openai_model(model: Llm) -> bool:
|
||||
return model in {
|
||||
Llm.GPT_4_VISION,
|
||||
Llm.GPT_4_TURBO_2024_04_09,
|
||||
Llm.GPT_4O_2024_05_13,
|
||||
}
|
||||
|
||||
|
||||
# Will throw errors if you send a garbage string
|
||||
def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
|
||||
if frontend_str == "gpt_4_vision":
|
||||
|
||||
@ -4,11 +4,12 @@ import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
from codegen.utils import extract_html_content
|
||||
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||
from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE
|
||||
from custom_types import InputMode
|
||||
from llm import (
|
||||
Llm,
|
||||
convert_frontend_str_to_llm,
|
||||
is_openai_model,
|
||||
stream_claude_response,
|
||||
stream_claude_response_native,
|
||||
stream_openai_response,
|
||||
@ -90,6 +91,7 @@ async def stream_code(websocket: WebSocket):
|
||||
async def throw_error(
|
||||
message: str,
|
||||
):
|
||||
print(message)
|
||||
await websocket.send_json({"type": "error", "value": message})
|
||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||
|
||||
@ -104,7 +106,6 @@ async def stream_code(websocket: WebSocket):
|
||||
|
||||
# TODO: Are the values always strings?
|
||||
params: Dict[str, str] = await websocket.receive_json()
|
||||
|
||||
print("Received params")
|
||||
|
||||
# Read the code config settings (stack) from the request.
|
||||
@ -135,28 +136,21 @@ async def stream_code(websocket: WebSocket):
|
||||
|
||||
# Auto-upgrade usage of older models
|
||||
code_generation_model = auto_upgrade_model(code_generation_model)
|
||||
|
||||
print(
|
||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
||||
)
|
||||
|
||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||
# If neither is provided, we throw an error.
|
||||
openai_api_key = None
|
||||
if params["openAiApiKey"]:
|
||||
openai_api_key = params["openAiApiKey"]
|
||||
openai_api_key = params.get("openAiApiKey")
|
||||
if openai_api_key:
|
||||
print("Using OpenAI API key from client-side settings dialog")
|
||||
else:
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
openai_api_key = OPENAI_API_KEY
|
||||
if openai_api_key:
|
||||
print("Using OpenAI API key from environment variable")
|
||||
|
||||
if not openai_api_key and (
|
||||
code_generation_model == Llm.GPT_4_VISION
|
||||
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
||||
or code_generation_model == Llm.GPT_4O_2024_05_13
|
||||
):
|
||||
print("OpenAI API key not found")
|
||||
if not openai_api_key and is_openai_model(code_generation_model):
|
||||
await throw_error(
|
||||
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user