clean up model strings and add support for GPT-4 Turbo (Apr 2024)

This commit is contained in:
Abi Raja 2024-04-11 09:55:55 -04:00
parent 9e1bcae545
commit 6587b626c5
4 changed files with 71 additions and 23 deletions

View File

@ -12,18 +12,20 @@ from utils import pprint_prompt
# Actual model versions that are passed to the LLMs and stored in our logs # Actual model versions that are passed to the LLMs and stored in our logs
class Llm(Enum): class Llm(Enum):
GPT_4_VISION = "gpt-4-vision-preview" GPT_4_VISION = "gpt-4-vision-preview"
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229" CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_OPUS = "claude-3-opus-20240229" CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307" CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# Keep in sync with frontend (lib/models.ts) # Will throw errors if you send a garbage string
# User-facing names for the models (for example, in the future, gpt_4_vision might def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
# be backed by a different model version) if frontend_str == "gpt_4_vision":
CODE_GENERATION_MODELS = [ return Llm.GPT_4_VISION
"gpt_4_vision", elif frontend_str == "claude_3_sonnet":
"claude_3_sonnet", return Llm.CLAUDE_3_SONNET
] else:
return Llm(frontend_str)
async def stream_openai_response( async def stream_openai_response(
@ -31,23 +33,22 @@ async def stream_openai_response(
api_key: str, api_key: str,
base_url: str | None, base_url: str | None,
callback: Callable[[str], Awaitable[None]], callback: Callable[[str], Awaitable[None]],
model: Llm,
) -> str: ) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url) client = AsyncOpenAI(api_key=api_key, base_url=base_url)
model = Llm.GPT_4_VISION
# Base parameters # Base parameters
params = { params = {
"model": model.value, "model": model.value,
"messages": messages, "messages": messages,
"stream": True, "stream": True,
"timeout": 600, "timeout": 600,
"temperature": 0.0,
} }
# Add 'max_tokens' only if the model is a GPT4 vision model # Add 'max_tokens' only if the model is a GPT4 vision or Turbo model
if model == Llm.GPT_4_VISION: if model == Llm.GPT_4_VISION or model == Llm.GPT_4_TURBO_2024_04_09:
params["max_tokens"] = 4096 params["max_tokens"] = 4096
params["temperature"] = 0
stream = await client.chat.completions.create(**params) # type: ignore stream = await client.chat.completions.create(**params) # type: ignore
full_response = "" full_response = ""

View File

@ -5,8 +5,8 @@ import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode from custom_types import InputMode
from llm import ( from llm import (
CODE_GENERATION_MODELS,
Llm, Llm,
convert_frontend_str_to_llm,
stream_claude_response, stream_claude_response,
stream_claude_response_native, stream_claude_response_native,
stream_openai_response, stream_openai_response,
@ -84,10 +84,14 @@ async def stream_code(websocket: WebSocket):
validated_input_mode = cast(InputMode, input_mode) validated_input_mode = cast(InputMode, input_mode)
# Read the model from the request. Fall back to default if not provided. # Read the model from the request. Fall back to default if not provided.
code_generation_model = params.get("codeGenerationModel", "gpt_4_vision") code_generation_model_str = params.get(
if code_generation_model not in CODE_GENERATION_MODELS: "codeGenerationModel", Llm.GPT_4_VISION.value
await throw_error(f"Invalid model: {code_generation_model}") )
raise Exception(f"Invalid model: {code_generation_model}") try:
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
except:
await throw_error(f"Invalid model: {code_generation_model_str}")
raise Exception(f"Invalid model: {code_generation_model_str}")
exact_llm_version = None exact_llm_version = None
print( print(
@ -105,7 +109,10 @@ async def stream_code(websocket: WebSocket):
if openai_api_key: if openai_api_key:
print("Using OpenAI API key from environment variable") print("Using OpenAI API key from environment variable")
if not openai_api_key and code_generation_model == "gpt_4_vision": if not openai_api_key and (
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
):
print("OpenAI API key not found") print("OpenAI API key not found")
await throw_error( await throw_error(
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
@ -226,7 +233,7 @@ async def stream_code(websocket: WebSocket):
include_thinking=True, include_thinking=True,
) )
exact_llm_version = Llm.CLAUDE_3_OPUS exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == "claude_3_sonnet": elif code_generation_model == Llm.CLAUDE_3_SONNET:
if not ANTHROPIC_API_KEY: if not ANTHROPIC_API_KEY:
await throw_error( await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env" "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
@ -238,15 +245,16 @@ async def stream_code(websocket: WebSocket):
api_key=ANTHROPIC_API_KEY, api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
) )
exact_llm_version = Llm.CLAUDE_3_SONNET exact_llm_version = code_generation_model
else: else:
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, # type: ignore prompt_messages, # type: ignore
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=openai_base_url,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
model=code_generation_model,
) )
exact_llm_version = Llm.GPT_4_VISION exact_llm_version = code_generation_model
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e) print("[GENERATE_CODE] Authentication failed", e)
error_message = ( error_message = (

36
backend/test_llm.py Normal file
View File

@ -0,0 +1,36 @@
import unittest
from llm import convert_frontend_str_to_llm, Llm
class TestConvertFrontendStrToLlm(unittest.TestCase):
    """Tests for convert_frontend_str_to_llm.

    Covers the two legacy frontend aliases ("gpt_4_vision",
    "claude_3_sonnet") plus raw Llm enum values, and verifies that
    unrecognized strings raise ValueError (from the Llm(...) fallback).
    """

    def test_convert_valid_strings(self):
        # Table-driven: (frontend string, expected enum member, failure message).
        cases = [
            (
                "gpt_4_vision",
                Llm.GPT_4_VISION,
                "Should convert 'gpt_4_vision' to Llm.GPT_4_VISION",
            ),
            (
                "claude_3_sonnet",
                Llm.CLAUDE_3_SONNET,
                "Should convert 'claude_3_sonnet' to Llm.CLAUDE_3_SONNET",
            ),
            (
                "claude-3-opus-20240229",
                Llm.CLAUDE_3_OPUS,
                "Should convert 'claude-3-opus-20240229' to Llm.CLAUDE_3_OPUS",
            ),
            (
                "gpt-4-turbo-2024-04-09",
                Llm.GPT_4_TURBO_2024_04_09,
                "Should convert 'gpt-4-turbo-2024-04-09' to Llm.GPT_4_TURBO_2024_04_09",
            ),
        ]
        for frontend_str, expected, message in cases:
            with self.subTest(frontend_str=frontend_str):
                self.assertEqual(
                    convert_frontend_str_to_llm(frontend_str), expected, message
                )

    def test_convert_invalid_string_raises_exception(self):
        # Strings that match neither a legacy alias nor an Llm value.
        for bad_input in ("invalid_string", "another_invalid_string"):
            with self.subTest(bad_input=bad_input):
                with self.assertRaises(ValueError):
                    convert_frontend_str_to_llm(bad_input)
# Allow running this test module directly: `python test_llm.py`.
if __name__ == "__main__":
    unittest.main()

View File

@ -1,12 +1,15 @@
// Keep in sync with backend (llm.py) // Keep in sync with backend (llm.py)
export enum CodeGenerationModel { export enum CodeGenerationModel {
GPT_4_VISION = "gpt_4_vision", GPT_4_VISION = "gpt_4_vision",
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
CLAUDE_3_SONNET = "claude_3_sonnet", CLAUDE_3_SONNET = "claude_3_sonnet",
} }
// Will generate a static error if a model in the enum above is not in the descriptions
export const CODE_GENERATION_MODEL_DESCRIPTIONS: { export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
[key in CodeGenerationModel]: { name: string; inBeta: boolean }; [key in CodeGenerationModel]: { name: string; inBeta: boolean };
} = { } = {
gpt_4_vision: { name: "GPT-4 Vision", inBeta: false }, gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: true }, claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
}; };