refactor
This commit is contained in:
parent
0f731598dd
commit
ff12790883
@ -23,6 +23,14 @@ class Llm(Enum):
|
|||||||
CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620"
|
CLAUDE_3_5_SONNET_2024_06_20 = "claude-3-5-sonnet-20240620"
|
||||||
|
|
||||||
|
|
||||||
|
def is_openai_model(model: Llm) -> bool:
|
||||||
|
return model in {
|
||||||
|
Llm.GPT_4_VISION,
|
||||||
|
Llm.GPT_4_TURBO_2024_04_09,
|
||||||
|
Llm.GPT_4O_2024_05_13,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Will throw errors if you send a garbage string
|
# Will throw errors if you send a garbage string
|
||||||
def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
|
def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
|
||||||
if frontend_str == "gpt_4_vision":
|
if frontend_str == "gpt_4_vision":
|
||||||
|
|||||||
@ -4,11 +4,12 @@ import traceback
|
|||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
from codegen.utils import extract_html_content
|
from codegen.utils import extract_html_content
|
||||||
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE
|
||||||
from custom_types import InputMode
|
from custom_types import InputMode
|
||||||
from llm import (
|
from llm import (
|
||||||
Llm,
|
Llm,
|
||||||
convert_frontend_str_to_llm,
|
convert_frontend_str_to_llm,
|
||||||
|
is_openai_model,
|
||||||
stream_claude_response,
|
stream_claude_response,
|
||||||
stream_claude_response_native,
|
stream_claude_response_native,
|
||||||
stream_openai_response,
|
stream_openai_response,
|
||||||
@ -90,6 +91,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
async def throw_error(
|
async def throw_error(
|
||||||
message: str,
|
message: str,
|
||||||
):
|
):
|
||||||
|
print(message)
|
||||||
await websocket.send_json({"type": "error", "value": message})
|
await websocket.send_json({"type": "error", "value": message})
|
||||||
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
|
||||||
|
|
||||||
@ -104,7 +106,6 @@ async def stream_code(websocket: WebSocket):
|
|||||||
|
|
||||||
# TODO: Are the values always strings?
|
# TODO: Are the values always strings?
|
||||||
params: Dict[str, str] = await websocket.receive_json()
|
params: Dict[str, str] = await websocket.receive_json()
|
||||||
|
|
||||||
print("Received params")
|
print("Received params")
|
||||||
|
|
||||||
# Read the code config settings (stack) from the request.
|
# Read the code config settings (stack) from the request.
|
||||||
@ -135,28 +136,21 @@ async def stream_code(websocket: WebSocket):
|
|||||||
|
|
||||||
# Auto-upgrade usage of older models
|
# Auto-upgrade usage of older models
|
||||||
code_generation_model = auto_upgrade_model(code_generation_model)
|
code_generation_model = auto_upgrade_model(code_generation_model)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||||
# If neither is provided, we throw an error.
|
# If neither is provided, we throw an error.
|
||||||
openai_api_key = None
|
openai_api_key = params.get("openAiApiKey")
|
||||||
if params["openAiApiKey"]:
|
if openai_api_key:
|
||||||
openai_api_key = params["openAiApiKey"]
|
|
||||||
print("Using OpenAI API key from client-side settings dialog")
|
print("Using OpenAI API key from client-side settings dialog")
|
||||||
else:
|
else:
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = OPENAI_API_KEY
|
||||||
if openai_api_key:
|
if openai_api_key:
|
||||||
print("Using OpenAI API key from environment variable")
|
print("Using OpenAI API key from environment variable")
|
||||||
|
|
||||||
if not openai_api_key and (
|
if not openai_api_key and is_openai_model(code_generation_model):
|
||||||
code_generation_model == Llm.GPT_4_VISION
|
|
||||||
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
|
||||||
or code_generation_model == Llm.GPT_4O_2024_05_13
|
|
||||||
):
|
|
||||||
print("OpenAI API key not found")
|
|
||||||
await throw_error(
|
await throw_error(
|
||||||
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
|
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user