clean up model strings and add support for GPT-4 Turbo (Apr 2024)

This commit is contained in:
Abi Raja 2024-04-11 09:55:55 -04:00
parent 9e1bcae545
commit 6587b626c5
4 changed files with 71 additions and 23 deletions

View File

@ -12,18 +12,20 @@ from utils import pprint_prompt
# Actual model versions that are passed to the LLMs and stored in our logs
class Llm(Enum):
GPT_4_VISION = "gpt-4-vision-preview"
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# Keep in sync with frontend (lib/models.ts)
# User-facing names for the models (for example, in the future, gpt_4_vision might
# be backed by a different model version)
CODE_GENERATION_MODELS = [
"gpt_4_vision",
"claude_3_sonnet",
]
# Will throw errors if you send a garbage string
def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
if frontend_str == "gpt_4_vision":
return Llm.GPT_4_VISION
elif frontend_str == "claude_3_sonnet":
return Llm.CLAUDE_3_SONNET
else:
return Llm(frontend_str)
async def stream_openai_response(
@ -31,23 +33,22 @@ async def stream_openai_response(
api_key: str,
base_url: str | None,
callback: Callable[[str], Awaitable[None]],
model: Llm,
) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
model = Llm.GPT_4_VISION
# Base parameters
params = {
"model": model.value,
"messages": messages,
"stream": True,
"timeout": 600,
"temperature": 0.0,
}
# Add 'max_tokens' only if the model is a GPT4 vision model
if model == Llm.GPT_4_VISION:
# Add 'max_tokens' only if the model is a GPT4 vision or Turbo model
if model == Llm.GPT_4_VISION or model == Llm.GPT_4_TURBO_2024_04_09:
params["max_tokens"] = 4096
params["temperature"] = 0
stream = await client.chat.completions.create(**params) # type: ignore
full_response = ""

View File

@ -5,8 +5,8 @@ import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode
from llm import (
CODE_GENERATION_MODELS,
Llm,
convert_frontend_str_to_llm,
stream_claude_response,
stream_claude_response_native,
stream_openai_response,
@ -84,10 +84,14 @@ async def stream_code(websocket: WebSocket):
validated_input_mode = cast(InputMode, input_mode)
# Read the model from the request. Fall back to default if not provided.
code_generation_model = params.get("codeGenerationModel", "gpt_4_vision")
if code_generation_model not in CODE_GENERATION_MODELS:
await throw_error(f"Invalid model: {code_generation_model}")
raise Exception(f"Invalid model: {code_generation_model}")
code_generation_model_str = params.get(
"codeGenerationModel", Llm.GPT_4_VISION.value
)
try:
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
except:
await throw_error(f"Invalid model: {code_generation_model_str}")
raise Exception(f"Invalid model: {code_generation_model_str}")
exact_llm_version = None
print(
@ -105,7 +109,10 @@ async def stream_code(websocket: WebSocket):
if openai_api_key:
print("Using OpenAI API key from environment variable")
if not openai_api_key and code_generation_model == "gpt_4_vision":
if not openai_api_key and (
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
):
print("OpenAI API key not found")
await throw_error(
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
@ -226,7 +233,7 @@ async def stream_code(websocket: WebSocket):
include_thinking=True,
)
exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == "claude_3_sonnet":
elif code_generation_model == Llm.CLAUDE_3_SONNET:
if not ANTHROPIC_API_KEY:
await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
@ -238,15 +245,16 @@ async def stream_code(websocket: WebSocket):
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
)
exact_llm_version = Llm.CLAUDE_3_SONNET
exact_llm_version = code_generation_model
else:
completion = await stream_openai_response(
prompt_messages, # type: ignore
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
model=code_generation_model,
)
exact_llm_version = Llm.GPT_4_VISION
exact_llm_version = code_generation_model
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (

36
backend/test_llm.py Normal file
View File

@ -0,0 +1,36 @@
import unittest
from llm import convert_frontend_str_to_llm, Llm
class TestConvertFrontendStrToLlm(unittest.TestCase):
def test_convert_valid_strings(self):
self.assertEqual(
convert_frontend_str_to_llm("gpt_4_vision"),
Llm.GPT_4_VISION,
"Should convert 'gpt_4_vision' to Llm.GPT_4_VISION",
)
self.assertEqual(
convert_frontend_str_to_llm("claude_3_sonnet"),
Llm.CLAUDE_3_SONNET,
"Should convert 'claude_3_sonnet' to Llm.CLAUDE_3_SONNET",
)
self.assertEqual(
convert_frontend_str_to_llm("claude-3-opus-20240229"),
Llm.CLAUDE_3_OPUS,
"Should convert 'claude-3-opus-20240229' to Llm.CLAUDE_3_OPUS",
)
self.assertEqual(
convert_frontend_str_to_llm("gpt-4-turbo-2024-04-09"),
Llm.GPT_4_TURBO_2024_04_09,
"Should convert 'gpt-4-turbo-2024-04-09' to Llm.GPT_4_TURBO_2024_04_09",
)
def test_convert_invalid_string_raises_exception(self):
with self.assertRaises(ValueError):
convert_frontend_str_to_llm("invalid_string")
with self.assertRaises(ValueError):
convert_frontend_str_to_llm("another_invalid_string")
if __name__ == "__main__":
unittest.main()

View File

@ -1,12 +1,15 @@
// Keep in sync with backend (llm.py)
export enum CodeGenerationModel {
GPT_4_VISION = "gpt_4_vision",
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
CLAUDE_3_SONNET = "claude_3_sonnet",
}
// Will generate a static error if a model in the enum above is not in the descriptions
export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
[key in CodeGenerationModel]: { name: string; inBeta: boolean };
} = {
gpt_4_vision: { name: "GPT-4 Vision", inBeta: false },
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: true },
gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
};