diff --git a/backend/llm.py b/backend/llm.py index 3c2c853..3d653b2 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -12,18 +12,20 @@ from utils import pprint_prompt # Actual model versions that are passed to the LLMs and stored in our logs class Llm(Enum): GPT_4_VISION = "gpt-4-vision-preview" + GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09" CLAUDE_3_SONNET = "claude-3-sonnet-20240229" CLAUDE_3_OPUS = "claude-3-opus-20240229" CLAUDE_3_HAIKU = "claude-3-haiku-20240307" -# Keep in sync with frontend (lib/models.ts) -# User-facing names for the models (for example, in the future, gpt_4_vision might -# be backed by a different model version) -CODE_GENERATION_MODELS = [ - "gpt_4_vision", - "claude_3_sonnet", -] +# Will throw errors if you send a garbage string +def convert_frontend_str_to_llm(frontend_str: str) -> Llm: + if frontend_str == "gpt_4_vision": + return Llm.GPT_4_VISION + elif frontend_str == "claude_3_sonnet": + return Llm.CLAUDE_3_SONNET + else: + return Llm(frontend_str) async def stream_openai_response( @@ -31,23 +33,22 @@ async def stream_openai_response( api_key: str, base_url: str | None, callback: Callable[[str], Awaitable[None]], + model: Llm, ) -> str: client = AsyncOpenAI(api_key=api_key, base_url=base_url) - model = Llm.GPT_4_VISION - # Base parameters params = { "model": model.value, "messages": messages, "stream": True, "timeout": 600, + "temperature": 0.0, } - # Add 'max_tokens' only if the model is a GPT4 vision model - if model == Llm.GPT_4_VISION: + # Add 'max_tokens' only if the model is a GPT4 vision or Turbo model + if model == Llm.GPT_4_VISION or model == Llm.GPT_4_TURBO_2024_04_09: params["max_tokens"] = 4096 - params["temperature"] = 0 stream = await client.chat.completions.create(**params) # type: ignore full_response = "" diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index a7edb9b..fa5c7a5 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -5,8 +5,8 @@ import 
openai from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE from custom_types import InputMode from llm import ( - CODE_GENERATION_MODELS, Llm, + convert_frontend_str_to_llm, stream_claude_response, stream_claude_response_native, stream_openai_response, @@ -84,10 +84,14 @@ async def stream_code(websocket: WebSocket): validated_input_mode = cast(InputMode, input_mode) # Read the model from the request. Fall back to default if not provided. - code_generation_model = params.get("codeGenerationModel", "gpt_4_vision") - if code_generation_model not in CODE_GENERATION_MODELS: - await throw_error(f"Invalid model: {code_generation_model}") - raise Exception(f"Invalid model: {code_generation_model}") + code_generation_model_str = params.get( + "codeGenerationModel", Llm.GPT_4_VISION.value + ) + try: + code_generation_model = convert_frontend_str_to_llm(code_generation_model_str) + except ValueError: + await throw_error(f"Invalid model: {code_generation_model_str}") + raise Exception(f"Invalid model: {code_generation_model_str}") exact_llm_version = None print( @@ -105,7 +109,10 @@ async def stream_code(websocket: WebSocket): if openai_api_key: print("Using OpenAI API key from environment variable") - if not openai_api_key and code_generation_model == "gpt_4_vision": + if not openai_api_key and ( + code_generation_model == Llm.GPT_4_VISION + or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 + ): print("OpenAI API key not found") await throw_error( "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." ) @@ -226,7 +233,7 @@ async def stream_code(websocket: WebSocket): include_thinking=True, ) exact_llm_version = Llm.CLAUDE_3_OPUS - elif code_generation_model == "claude_3_sonnet": + elif code_generation_model == Llm.CLAUDE_3_SONNET: if not ANTHROPIC_API_KEY: await throw_error( "No Anthropic API key found. 
Please add the environment variable ANTHROPIC_API_KEY to backend/.env" @@ -238,15 +245,16 @@ async def stream_code(websocket: WebSocket): api_key=ANTHROPIC_API_KEY, callback=lambda x: process_chunk(x), ) - exact_llm_version = Llm.CLAUDE_3_SONNET + exact_llm_version = code_generation_model else: completion = await stream_openai_response( prompt_messages, # type: ignore api_key=openai_api_key, base_url=openai_base_url, callback=lambda x: process_chunk(x), + model=code_generation_model, ) - exact_llm_version = Llm.GPT_4_VISION + exact_llm_version = code_generation_model except openai.AuthenticationError as e: print("[GENERATE_CODE] Authentication failed", e) error_message = ( diff --git a/backend/test_llm.py b/backend/test_llm.py new file mode 100644 index 0000000..ec005a3 --- /dev/null +++ b/backend/test_llm.py @@ -0,0 +1,36 @@ +import unittest +from llm import convert_frontend_str_to_llm, Llm + + +class TestConvertFrontendStrToLlm(unittest.TestCase): + def test_convert_valid_strings(self): + self.assertEqual( + convert_frontend_str_to_llm("gpt_4_vision"), + Llm.GPT_4_VISION, + "Should convert 'gpt_4_vision' to Llm.GPT_4_VISION", + ) + self.assertEqual( + convert_frontend_str_to_llm("claude_3_sonnet"), + Llm.CLAUDE_3_SONNET, + "Should convert 'claude_3_sonnet' to Llm.CLAUDE_3_SONNET", + ) + self.assertEqual( + convert_frontend_str_to_llm("claude-3-opus-20240229"), + Llm.CLAUDE_3_OPUS, + "Should convert 'claude-3-opus-20240229' to Llm.CLAUDE_3_OPUS", + ) + self.assertEqual( + convert_frontend_str_to_llm("gpt-4-turbo-2024-04-09"), + Llm.GPT_4_TURBO_2024_04_09, + "Should convert 'gpt-4-turbo-2024-04-09' to Llm.GPT_4_TURBO_2024_04_09", + ) + + def test_convert_invalid_string_raises_exception(self): + with self.assertRaises(ValueError): + convert_frontend_str_to_llm("invalid_string") + with self.assertRaises(ValueError): + convert_frontend_str_to_llm("another_invalid_string") + + +if __name__ == "__main__": + unittest.main() diff --git a/frontend/src/lib/models.ts 
b/frontend/src/lib/models.ts index a972f78..58b3e31 100644 --- a/frontend/src/lib/models.ts +++ b/frontend/src/lib/models.ts @@ -1,12 +1,15 @@ // Keep in sync with backend (llm.py) export enum CodeGenerationModel { GPT_4_VISION = "gpt_4_vision", + GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09", CLAUDE_3_SONNET = "claude_3_sonnet", } +// Will generate a static error if a model in the enum above is not in the descriptions export const CODE_GENERATION_MODEL_DESCRIPTIONS: { [key in CodeGenerationModel]: { name: string; inBeta: boolean }; } = { - gpt_4_vision: { name: "GPT-4 Vision", inBeta: false }, - claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: true }, + gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false }, + claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false }, + "gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false }, };