clean up model strings and add support for GPT-4 Turbo (Apr 2024)
This commit is contained in:
parent
9e1bcae545
commit
6587b626c5
@ -12,18 +12,20 @@ from utils import pprint_prompt
|
||||
# Actual model versions that are passed to the LLMs and stored in our logs
class Llm(Enum):
    """Exact model version identifiers.

    Each member's value is the version string that is sent to the model
    provider as the ``model`` parameter. Looking a member up by value
    (``Llm("...")``) raises ``ValueError`` for an unknown string.
    """

    GPT_4_VISION = "gpt-4-vision-preview"
    GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
# Keep in sync with frontend (lib/models.ts)
|
||||
# User-facing names for the models (for example, in the future, gpt_4_vision might
|
||||
# be backed by a different model version)
|
||||
CODE_GENERATION_MODELS = [
|
||||
"gpt_4_vision",
|
||||
"claude_3_sonnet",
|
||||
]
|
||||
# Will throw errors if you send a garbage string
def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
    """Translate a model string received from the frontend into an ``Llm``.

    Two user-facing aliases are recognised explicitly; any other input is
    treated as an exact model version string and looked up by enum value.

    Args:
        frontend_str: Either ``"gpt_4_vision"``, ``"claude_3_sonnet"``, or
            the value of an ``Llm`` member (e.g. ``"gpt-4-turbo-2024-04-09"``).

    Returns:
        The matching ``Llm`` member.

    Raises:
        ValueError: If ``frontend_str`` is neither a known alias nor the
            value of any ``Llm`` member.
    """
    # User-facing aliases whose spelling differs from the exact version string.
    if frontend_str == "gpt_4_vision":
        return Llm.GPT_4_VISION
    if frontend_str == "claude_3_sonnet":
        return Llm.CLAUDE_3_SONNET

    # Lookup by value; the enum raises ValueError for unknown strings.
    return Llm(frontend_str)
|
||||
|
||||
|
||||
async def stream_openai_response(
|
||||
@ -31,23 +33,22 @@ async def stream_openai_response(
|
||||
api_key: str,
|
||||
base_url: str | None,
|
||||
callback: Callable[[str], Awaitable[None]],
|
||||
model: Llm,
|
||||
) -> str:
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
model = Llm.GPT_4_VISION
|
||||
|
||||
# Base parameters
|
||||
params = {
|
||||
"model": model.value,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"timeout": 600,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
# Add 'max_tokens' only if the model is a GPT4 vision model
|
||||
if model == Llm.GPT_4_VISION:
|
||||
# Add 'max_tokens' only if the model is a GPT4 vision or Turbo model
|
||||
if model == Llm.GPT_4_VISION or model == Llm.GPT_4_TURBO_2024_04_09:
|
||||
params["max_tokens"] = 4096
|
||||
params["temperature"] = 0
|
||||
|
||||
stream = await client.chat.completions.create(**params) # type: ignore
|
||||
full_response = ""
|
||||
|
||||
@ -5,8 +5,8 @@ import openai
|
||||
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||
from custom_types import InputMode
|
||||
from llm import (
|
||||
CODE_GENERATION_MODELS,
|
||||
Llm,
|
||||
convert_frontend_str_to_llm,
|
||||
stream_claude_response,
|
||||
stream_claude_response_native,
|
||||
stream_openai_response,
|
||||
@ -84,10 +84,14 @@ async def stream_code(websocket: WebSocket):
|
||||
validated_input_mode = cast(InputMode, input_mode)
|
||||
|
||||
# Read the model from the request. Fall back to default if not provided.
|
||||
code_generation_model = params.get("codeGenerationModel", "gpt_4_vision")
|
||||
if code_generation_model not in CODE_GENERATION_MODELS:
|
||||
await throw_error(f"Invalid model: {code_generation_model}")
|
||||
raise Exception(f"Invalid model: {code_generation_model}")
|
||||
code_generation_model_str = params.get(
|
||||
"codeGenerationModel", Llm.GPT_4_VISION.value
|
||||
)
|
||||
try:
|
||||
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
|
||||
except:
|
||||
await throw_error(f"Invalid model: {code_generation_model_str}")
|
||||
raise Exception(f"Invalid model: {code_generation_model_str}")
|
||||
exact_llm_version = None
|
||||
|
||||
print(
|
||||
@ -105,7 +109,10 @@ async def stream_code(websocket: WebSocket):
|
||||
if openai_api_key:
|
||||
print("Using OpenAI API key from environment variable")
|
||||
|
||||
if not openai_api_key and code_generation_model == "gpt_4_vision":
|
||||
if not openai_api_key and (
|
||||
code_generation_model == Llm.GPT_4_VISION
|
||||
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
||||
):
|
||||
print("OpenAI API key not found")
|
||||
await throw_error(
|
||||
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
|
||||
@ -226,7 +233,7 @@ async def stream_code(websocket: WebSocket):
|
||||
include_thinking=True,
|
||||
)
|
||||
exact_llm_version = Llm.CLAUDE_3_OPUS
|
||||
elif code_generation_model == "claude_3_sonnet":
|
||||
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
||||
if not ANTHROPIC_API_KEY:
|
||||
await throw_error(
|
||||
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
|
||||
@ -238,15 +245,16 @@ async def stream_code(websocket: WebSocket):
|
||||
api_key=ANTHROPIC_API_KEY,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
exact_llm_version = Llm.CLAUDE_3_SONNET
|
||||
exact_llm_version = code_generation_model
|
||||
else:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages, # type: ignore
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=code_generation_model,
|
||||
)
|
||||
exact_llm_version = Llm.GPT_4_VISION
|
||||
exact_llm_version = code_generation_model
|
||||
except openai.AuthenticationError as e:
|
||||
print("[GENERATE_CODE] Authentication failed", e)
|
||||
error_message = (
|
||||
|
||||
36
backend/test_llm.py
Normal file
36
backend/test_llm.py
Normal file
@ -0,0 +1,36 @@
|
||||
import unittest
|
||||
from llm import convert_frontend_str_to_llm, Llm
|
||||
|
||||
|
||||
class TestConvertFrontendStrToLlm(unittest.TestCase):
    """Unit tests for ``convert_frontend_str_to_llm``."""

    def test_convert_valid_strings(self):
        """Aliases and exact version strings resolve to the expected member."""
        cases = [
            (
                "gpt_4_vision",
                Llm.GPT_4_VISION,
                "Should convert 'gpt_4_vision' to Llm.GPT_4_VISION",
            ),
            (
                "claude_3_sonnet",
                Llm.CLAUDE_3_SONNET,
                "Should convert 'claude_3_sonnet' to Llm.CLAUDE_3_SONNET",
            ),
            (
                "claude-3-opus-20240229",
                Llm.CLAUDE_3_OPUS,
                "Should convert 'claude-3-opus-20240229' to Llm.CLAUDE_3_OPUS",
            ),
            (
                "gpt-4-turbo-2024-04-09",
                Llm.GPT_4_TURBO_2024_04_09,
                "Should convert 'gpt-4-turbo-2024-04-09' to Llm.GPT_4_TURBO_2024_04_09",
            ),
        ]
        for frontend_str, expected, message in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(frontend_str=frontend_str):
                self.assertEqual(
                    convert_frontend_str_to_llm(frontend_str),
                    expected,
                    message,
                )

    def test_convert_invalid_string_raises_exception(self):
        """Strings that match no alias and no enum value raise ValueError."""
        for frontend_str in ("invalid_string", "another_invalid_string"):
            with self.subTest(frontend_str=frontend_str):
                with self.assertRaises(ValueError):
                    convert_frontend_str_to_llm(frontend_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -1,12 +1,15 @@
|
||||
// Keep in sync with backend (llm.py)
/**
 * String identifiers for code generation models.
 *
 * Each value is a string that the backend's `convert_frontend_str_to_llm`
 * accepts: either one of its two aliases (`gpt_4_vision`, `claude_3_sonnet`)
 * or an exact model version string.
 */
export enum CodeGenerationModel {
  GPT_4_VISION = "gpt_4_vision",
  GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
  CLAUDE_3_SONNET = "claude_3_sonnet",
}
|
||||
|
||||
// Will generate a static error if a model in the enum above is not in the descriptions
|
||||
export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
|
||||
[key in CodeGenerationModel]: { name: string; inBeta: boolean };
|
||||
} = {
|
||||
gpt_4_vision: { name: "GPT-4 Vision", inBeta: false },
|
||||
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: true },
|
||||
gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
|
||||
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
|
||||
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
|
||||
};
|
||||
|
||||
Loading…
Reference in New Issue
Block a user