From 81c4fbe28d3f08320b56fc2defbad9587d43b48c Mon Sep 17 00:00:00 2001
From: Abi Raja
Date: Mon, 18 Mar 2024 17:44:05 -0400
Subject: [PATCH] identify exact llm being used during generation

---
 backend/llm.py                  | 30 ++++++++++++++++++------------
 backend/prompts/test_prompts.py | 12 ++++++------
 backend/routes/generate_code.py | 10 ++++++++--
 backend/video/utils.py          |  6 +++---
 backend/video_to_app.py         |  5 ++---
 5 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/backend/llm.py b/backend/llm.py
index c7032c5..dba5720 100644
--- a/backend/llm.py
+++ b/backend/llm.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Any, Awaitable, Callable, List, cast
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI
@@ -5,13 +6,18 @@
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
 
 from utils import pprint_prompt
 
-MODEL_GPT_4_VISION = "gpt-4-vision-preview"
-MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
-MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"
-MODEL_CLAUDE_HAIKU = "claude-3-haiku-20240307"
+
+# Actual model versions that are passed to the LLMs and stored in our logs
+class Llm(Enum):
+    GPT_4_VISION = "gpt-4-vision-preview"
+    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+    CLAUDE_3_OPUS = "claude-3-opus-20240229"
+    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
 
 # Keep in sync with frontend (lib/models.ts)
+# User-facing names for the models (for example, in the future, gpt_4_vision might
+# be backed by a different model version)
 CODE_GENERATION_MODELS = [
     "gpt_4_vision",
     "claude_3_sonnet",
@@ -26,13 +32,13 @@ async def stream_openai_response(
 ) -> str:
     client = AsyncOpenAI(api_key=api_key, base_url=base_url)
 
-    model = MODEL_GPT_4_VISION
+    model = Llm.GPT_4_VISION
 
     # Base parameters
-    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
+    params = {"model": model.value, "messages": messages, "stream": True, "timeout": 600}
 
     # Add 'max_tokens' only if the model is a GPT4 vision model
-    if model == MODEL_GPT_4_VISION:
+    if model == Llm.GPT_4_VISION:
         params["max_tokens"] = 4096
         params["temperature"] = 0
 
@@ -59,12 +65,12 @@ async def stream_claude_response(
     client = AsyncAnthropic(api_key=api_key)
 
     # Base parameters
-    model = MODEL_CLAUDE_SONNET
+    model = Llm.CLAUDE_3_SONNET
     max_tokens = 4096
     temperature = 0.0
 
     # Translate OpenAI messages to Claude messages
-    system_prompt = cast(str, messages[0]["content"])
+    system_prompt = cast(str, messages[0].get("content"))
     claude_messages = [dict(message) for message in messages[1:]]
     for message in claude_messages:
         if not isinstance(message["content"], list):
@@ -91,7 +97,7 @@ async def stream_claude_response(
 
     # Stream Claude response
     async with client.messages.stream(
-        model=model,
+        model=model.value,
         max_tokens=max_tokens,
         temperature=temperature,
         system=system_prompt,
@@ -111,7 +117,7 @@ async def stream_claude_response_native(
     api_key: str,
     callback: Callable[[str], Awaitable[None]],
     include_thinking: bool = False,
-    model: str = MODEL_CLAUDE_OPUS,
+    model: Llm = Llm.CLAUDE_3_OPUS,
 ) -> str:
     client = AsyncAnthropic(api_key=api_key)
 
@@ -140,7 +146,7 @@ async def stream_claude_response_native(
     pprint_prompt(messages_to_send)
 
     async with client.messages.stream(
-        model=model,
+        model=model.value,
         max_tokens=max_tokens,
         temperature=temperature,
         system=system_prompt,
diff --git a/backend/prompts/test_prompts.py b/backend/prompts/test_prompts.py
index 9e60410..42e89c3 100644
--- a/backend/prompts/test_prompts.py
+++ b/backend/prompts/test_prompts.py
@@ -311,35 +311,35 @@ def test_prompts():
     tailwind_prompt = assemble_prompt(
"result_image_data_url" ) - assert tailwind_prompt[0]["content"] == TAILWIND_SYSTEM_PROMPT + assert tailwind_prompt[0].get("content") == TAILWIND_SYSTEM_PROMPT assert tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore react_tailwind_prompt = assemble_prompt( "image_data_url", "react_tailwind", "result_image_data_url" ) - assert react_tailwind_prompt[0]["content"] == REACT_TAILWIND_SYSTEM_PROMPT + assert react_tailwind_prompt[0].get("content") == REACT_TAILWIND_SYSTEM_PROMPT assert react_tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore bootstrap_prompt = assemble_prompt( "image_data_url", "bootstrap", "result_image_data_url" ) - assert bootstrap_prompt[0]["content"] == BOOTSTRAP_SYSTEM_PROMPT + assert bootstrap_prompt[0].get("content") == BOOTSTRAP_SYSTEM_PROMPT assert bootstrap_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore ionic_tailwind = assemble_prompt( "image_data_url", "ionic_tailwind", "result_image_data_url" ) - assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT + assert ionic_tailwind[0].get("content") == IONIC_TAILWIND_SYSTEM_PROMPT assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore vue_tailwind = assemble_prompt( "image_data_url", "vue_tailwind", "result_image_data_url" ) - assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT + assert vue_tailwind[0].get("content") == VUE_TAILWIND_SYSTEM_PROMPT assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url") - assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT + assert svg_prompt[0].get("content") == SVG_SYSTEM_PROMPT assert svg_prompt[1]["content"][2]["text"] == SVG_USER_PROMPT # type: ignore diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 2efc7a0..2d82aca 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -6,7 +6,7 @@ from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE from custom_types import InputMode from llm import ( CODE_GENERATION_MODELS, - MODEL_CLAUDE_OPUS, + Llm, stream_claude_response, stream_claude_response_native, stream_openai_response, @@ -88,6 +88,7 @@ async def stream_code(websocket: WebSocket): if code_generation_model not in CODE_GENERATION_MODELS: await throw_error(f"Invalid model: {code_generation_model}") raise Exception(f"Invalid model: {code_generation_model}") + exact_llm_version = None print( f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." 
@@ -238,9 +239,10 @@ async def stream_code(websocket: WebSocket):
                 messages=prompt_messages,  # type: ignore
                 api_key=ANTHROPIC_API_KEY,
                 callback=lambda x: process_chunk(x),
-                model=MODEL_CLAUDE_OPUS,
+                model=Llm.CLAUDE_3_OPUS,
                 include_thinking=True,
             )
+            exact_llm_version = Llm.CLAUDE_3_OPUS
         elif code_generation_model == "claude_3_sonnet":
             if not ANTHROPIC_API_KEY:
                 await throw_error(
@@ -253,6 +255,7 @@ async def stream_code(websocket: WebSocket):
                 api_key=ANTHROPIC_API_KEY,
                 callback=lambda x: process_chunk(x),
             )
+            exact_llm_version = Llm.CLAUDE_3_SONNET
         else:
             completion = await stream_openai_response(
                 prompt_messages,  # type: ignore
@@ -260,6 +263,7 @@ async def stream_code(websocket: WebSocket):
                 api_key=openai_api_key,
                 base_url=openai_base_url,
                 callback=lambda x: process_chunk(x),
             )
+            exact_llm_version = Llm.GPT_4_VISION
     except openai.AuthenticationError as e:
         print("[GENERATE_CODE] Authentication failed", e)
         error_message = (
@@ -298,6 +302,8 @@ async def stream_code(websocket: WebSocket):
     if validated_input_mode == "video":
         completion = extract_tag_content("html", completion)
 
+    print("Exact model used for generation:", exact_llm_version)
+
     # Write the messages dict into a log so that we can debug later
     write_logs(prompt_messages, completion)  # type: ignore
 
diff --git a/backend/video/utils.py b/backend/video/utils.py
index 94501a2..dc772fc 100644
--- a/backend/video/utils.py
+++ b/backend/video/utils.py
@@ -5,7 +5,7 @@ import mimetypes
 import os
 import tempfile
 import uuid
-from typing import Union, cast
+from typing import Any, Union, cast
 from moviepy.editor import VideoFileClip  # type: ignore
 from PIL import Image
 import math
@@ -17,7 +17,7 @@ TARGET_NUM_SCREENSHOTS = (
 )
 
 
-async def assemble_claude_prompt_video(video_data_url: str):
+async def assemble_claude_prompt_video(video_data_url: str) -> list[Any]:
     images = split_video_into_screenshots(video_data_url)
 
     # Save images to tmp if we're debugging
@@ -28,7 +28,7 @@ async def assemble_claude_prompt_video(video_data_url: str):
     print(f"Number of frames extracted from video: {len(images)}")
     if len(images) > 20:
         print(f"Too many screenshots: {len(images)}")
-        return
+        raise ValueError("Too many screenshots extracted from video")
 
     # Convert images to the message format for Claude
     content_messages: list[dict[str, Union[dict[str, str], str]]] = []
diff --git a/backend/video_to_app.py b/backend/video_to_app.py
index 597cbbf..c876804 100644
--- a/backend/video_to_app.py
+++ b/backend/video_to_app.py
@@ -17,8 +17,7 @@ from utils import pprint_prompt
 from config import ANTHROPIC_API_KEY
 from video.utils import extract_tag_content, assemble_claude_prompt_video
 from llm import (
-    MODEL_CLAUDE_OPUS,
-    # MODEL_CLAUDE_SONNET,
+    Llm,
     stream_claude_response_native,
 )
 
@@ -87,7 +86,7 @@ async def main():
         messages=prompt_messages,
         api_key=ANTHROPIC_API_KEY,
         callback=lambda x: process_chunk(x),
-        model=MODEL_CLAUDE_OPUS,
+        model=Llm.CLAUDE_3_OPUS,
         include_thinking=True,
     )
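
A minimal sketch (not part of this patch) of how a user-facing model name from CODE_GENERATION_MODELS, kept in sync with the frontend's lib/models.ts, can be resolved to the exact Llm version that is sent to the provider and printed for the logs. The USER_FACING_TO_LLM mapping and the resolve_exact_llm helper are hypothetical; the patch itself assigns exact_llm_version directly in each branch of stream_code.

    from llm import Llm

    # Hypothetical mapping from user-facing names to exact model versions (illustrative only)
    USER_FACING_TO_LLM: dict[str, Llm] = {
        "gpt_4_vision": Llm.GPT_4_VISION,
        "claude_3_sonnet": Llm.CLAUDE_3_SONNET,
    }


    def resolve_exact_llm(code_generation_model: str) -> Llm:
        # Mirrors the "Invalid model" guard at the top of stream_code
        if code_generation_model not in USER_FACING_TO_LLM:
            raise ValueError(f"Invalid model: {code_generation_model}")
        return USER_FACING_TO_LLM[code_generation_model]


    exact_llm_version = resolve_exact_llm("claude_3_sonnet")
    print("Exact model used for generation:", exact_llm_version.value)  # claude-3-sonnet-20240229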
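
Because assemble_claude_prompt_video now raises ValueError instead of silently returning None when a video yields more than 20 screenshots, callers should handle that failure explicitly. A rough sketch of caller-side handling, assuming the video.utils import path used elsewhere in the backend; the build_video_prompt wrapper name is hypothetical.

    from typing import Any

    from video.utils import assemble_claude_prompt_video


    async def build_video_prompt(video_data_url: str) -> list[Any]:
        # Hypothetical wrapper: surface the new error to the caller instead of
        # continuing with an empty prompt
        try:
            return await assemble_claude_prompt_video(video_data_url)
        except ValueError as e:
            print(f"[VIDEO] Prompt assembly failed: {e}")
            raise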