identify exact llm being used during generation

Abi Raja 2024-03-18 17:44:05 -04:00
parent 4e30b207c1
commit 81c4fbe28d
5 changed files with 36 additions and 25 deletions

View File

@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Any, Awaitable, Callable, List, cast
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI
@@ -5,13 +6,18 @@ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
 from utils import pprint_prompt
 
-MODEL_GPT_4_VISION = "gpt-4-vision-preview"
-MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
-MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"
-MODEL_CLAUDE_HAIKU = "claude-3-haiku-20240307"
+# Actual model versions that are passed to the LLMs and stored in our logs
+class Llm(Enum):
+    GPT_4_VISION = "gpt-4-vision-preview"
+    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+    CLAUDE_3_OPUS = "claude-3-opus-20240229"
+    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
 
 # Keep in sync with frontend (lib/models.ts)
+# User-facing names for the models (for example, in the future, gpt_4_vision might
+# be backed by a different model version)
 CODE_GENERATION_MODELS = [
     "gpt_4_vision",
     "claude_3_sonnet",
@@ -26,13 +32,13 @@ async def stream_openai_response(
 ) -> str:
     client = AsyncOpenAI(api_key=api_key, base_url=base_url)
 
-    model = MODEL_GPT_4_VISION
+    model = Llm.GPT_4_VISION
 
     # Base parameters
     params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
 
     # Add 'max_tokens' only if the model is a GPT4 vision model
-    if model == MODEL_GPT_4_VISION:
+    if model == Llm.GPT_4_VISION:
         params["max_tokens"] = 4096
         params["temperature"] = 0
@@ -59,12 +65,12 @@ async def stream_claude_response(
     client = AsyncAnthropic(api_key=api_key)
 
     # Base parameters
-    model = MODEL_CLAUDE_SONNET
+    model = Llm.CLAUDE_3_SONNET
     max_tokens = 4096
     temperature = 0.0
 
     # Translate OpenAI messages to Claude messages
-    system_prompt = cast(str, messages[0]["content"])
+    system_prompt = cast(str, messages[0].get("content"))
     claude_messages = [dict(message) for message in messages[1:]]
     for message in claude_messages:
         if not isinstance(message["content"], list):
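
The switch from messages[0]["content"] to messages[0].get("content") matters because ChatCompletionMessageParam is a union of TypedDicts and not every variant declares a "content" key; .get type-checks where indexing is flagged. The surrounding convention, first message is the system prompt and the rest are turns, reduces to this sketch (split_system_prompt is a hypothetical helper, named here only for illustration):

from typing import Any, cast

def split_system_prompt(
    messages: list[dict[str, Any]]
) -> tuple[str, list[dict[str, Any]]]:
    # Convention: messages[0] carries the system prompt; the rest are turns.
    system_prompt = cast(str, messages[0].get("content"))
    return system_prompt, [dict(m) for m in messages[1:]]

system, turns = split_system_prompt(
    [{"role": "system", "content": "Reply tersely."},
     {"role": "user", "content": "Hi"}]
)
assert system == "Reply tersely." and turns[0]["role"] == "user"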
@@ -91,7 +97,7 @@ async def stream_claude_response(
     # Stream Claude response
     async with client.messages.stream(
-        model=model,
+        model=model.value,
         max_tokens=max_tokens,
         temperature=temperature,
         system=system_prompt,
@@ -111,7 +117,7 @@ async def stream_claude_response_native(
     api_key: str,
     callback: Callable[[str], Awaitable[None]],
     include_thinking: bool = False,
-    model: str = MODEL_CLAUDE_OPUS,
+    model: Llm = Llm.CLAUDE_3_OPUS,
 ) -> str:
     client = AsyncAnthropic(api_key=api_key)
@@ -140,7 +146,7 @@ async def stream_claude_response_native(
     pprint_prompt(messages_to_send)
 
     async with client.messages.stream(
-        model=model,
+        model=model.value,
         max_tokens=max_tokens,
         temperature=temperature,
         system=system_prompt,

View File

@@ -311,35 +311,35 @@ def test_prompts():
     tailwind_prompt = assemble_prompt(
         "image_data_url", "html_tailwind", "result_image_data_url"
     )
-    assert tailwind_prompt[0]["content"] == TAILWIND_SYSTEM_PROMPT
+    assert tailwind_prompt[0].get("content") == TAILWIND_SYSTEM_PROMPT
     assert tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     react_tailwind_prompt = assemble_prompt(
         "image_data_url", "react_tailwind", "result_image_data_url"
     )
-    assert react_tailwind_prompt[0]["content"] == REACT_TAILWIND_SYSTEM_PROMPT
+    assert react_tailwind_prompt[0].get("content") == REACT_TAILWIND_SYSTEM_PROMPT
     assert react_tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     bootstrap_prompt = assemble_prompt(
         "image_data_url", "bootstrap", "result_image_data_url"
     )
-    assert bootstrap_prompt[0]["content"] == BOOTSTRAP_SYSTEM_PROMPT
+    assert bootstrap_prompt[0].get("content") == BOOTSTRAP_SYSTEM_PROMPT
     assert bootstrap_prompt[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     ionic_tailwind = assemble_prompt(
         "image_data_url", "ionic_tailwind", "result_image_data_url"
     )
-    assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT
+    assert ionic_tailwind[0].get("content") == IONIC_TAILWIND_SYSTEM_PROMPT
     assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     vue_tailwind = assemble_prompt(
         "image_data_url", "vue_tailwind", "result_image_data_url"
     )
-    assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT
+    assert vue_tailwind[0].get("content") == VUE_TAILWIND_SYSTEM_PROMPT
     assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url")
-    assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT
+    assert svg_prompt[0].get("content") == SVG_SYSTEM_PROMPT
     assert svg_prompt[1]["content"][2]["text"] == SVG_USER_PROMPT  # type: ignore
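
The same indexing-to-.get swap as in llm.py, presumably to satisfy the type checker: on a union of TypedDicts, ["content"] is an error when some variant lacks the key, while .get(...) type-checks and simply returns None at runtime if the key is absent. A tiny illustration with hypothetical TypedDicts (not the SDK's actual types):

from typing import TypedDict, Union

class SystemMessage(TypedDict):
    role: str
    content: str

class ToolCallMessage(TypedDict):  # no "content" key declared
    role: str

Message = Union[SystemMessage, ToolCallMessage]

def content_of(message: Message) -> object:
    # message["content"] is rejected by the checker: ToolCallMessage lacks the key.
    return message.get("content")  # accepted; None when the key is absent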

View File

@@ -6,7 +6,7 @@ from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
 from custom_types import InputMode
 from llm import (
     CODE_GENERATION_MODELS,
-    MODEL_CLAUDE_OPUS,
+    Llm,
     stream_claude_response,
     stream_claude_response_native,
     stream_openai_response,
@@ -88,6 +88,7 @@ async def stream_code(websocket: WebSocket):
     if code_generation_model not in CODE_GENERATION_MODELS:
         await throw_error(f"Invalid model: {code_generation_model}")
         raise Exception(f"Invalid model: {code_generation_model}")
+    exact_llm_version = None
 
     print(
         f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
@@ -238,9 +239,10 @@ async def stream_code(websocket: WebSocket):
                     messages=prompt_messages,  # type: ignore
                     api_key=ANTHROPIC_API_KEY,
                     callback=lambda x: process_chunk(x),
-                    model=MODEL_CLAUDE_OPUS,
+                    model=Llm.CLAUDE_3_OPUS,
                     include_thinking=True,
                 )
+                exact_llm_version = Llm.CLAUDE_3_OPUS
             elif code_generation_model == "claude_3_sonnet":
                 if not ANTHROPIC_API_KEY:
                     await throw_error(
@@ -253,6 +255,7 @@ async def stream_code(websocket: WebSocket):
                     api_key=ANTHROPIC_API_KEY,
                     callback=lambda x: process_chunk(x),
                 )
+                exact_llm_version = Llm.CLAUDE_3_SONNET
             else:
                 completion = await stream_openai_response(
                     prompt_messages,  # type: ignore
@@ -260,6 +263,7 @@ async def stream_code(websocket: WebSocket):
                     base_url=openai_base_url,
                     callback=lambda x: process_chunk(x),
                 )
+                exact_llm_version = Llm.GPT_4_VISION
     except openai.AuthenticationError as e:
         print("[GENERATE_CODE] Authentication failed", e)
         error_message = (
@@ -298,6 +302,8 @@ async def stream_code(websocket: WebSocket):
     if validated_input_mode == "video":
         completion = extract_tag_content("html", completion)
 
+    print("Exact used model for generation: ", exact_llm_version)
+
     # Write the messages dict into a log so that we can debug later
     write_logs(prompt_messages, completion)  # type: ignore
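
Taken together, the route follows a record-then-log pattern: initialize a sentinel before the model branches, assign the concrete Llm member in whichever branch actually runs, and print it once at the end. Condensed to a runnable skeleton (branch conditions simplified; names otherwise match the diff):

from enum import Enum
from typing import Optional

class Llm(Enum):
    GPT_4_VISION = "gpt-4-vision-preview"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"

def generate(code_generation_model: str, video_mode: bool) -> Optional[Llm]:
    # Sentinel first: the final log line is safe even if no branch assigns.
    exact_llm_version: Optional[Llm] = None
    if video_mode:
        exact_llm_version = Llm.CLAUDE_3_OPUS
    elif code_generation_model == "claude_3_sonnet":
        exact_llm_version = Llm.CLAUDE_3_SONNET
    else:
        exact_llm_version = Llm.GPT_4_VISION
    print("Exact used model for generation: ", exact_llm_version)
    return exact_llm_version

One design note: since exact_llm_version can remain None on an early error path, logging its .value would crash; printing the member itself, as the diff does, stays safe.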

View File

@@ -5,7 +5,7 @@ import mimetypes
 import os
 import tempfile
 import uuid
-from typing import Union, cast
+from typing import Any, Union, cast
 from moviepy.editor import VideoFileClip  # type: ignore
 from PIL import Image
 import math
@@ -17,7 +17,7 @@ TARGET_NUM_SCREENSHOTS = (
 )
 
-async def assemble_claude_prompt_video(video_data_url: str):
+async def assemble_claude_prompt_video(video_data_url: str) -> list[Any]:
     images = split_video_into_screenshots(video_data_url)
 
     # Save images to tmp if we're debugging
@@ -28,7 +28,7 @@ async def assemble_claude_prompt_video(video_data_url: str):
     print(f"Number of frames extracted from video: {len(images)}")
     if len(images) > 20:
         print(f"Too many screenshots: {len(images)}")
-        return
+        raise ValueError("Too many screenshots extracted from video")
 
     # Convert images to the message format for Claude
     content_messages: list[dict[str, Union[dict[str, str], str]]] = []
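
The bare return was a silent failure: the function implicitly returned None, which no longer fits the new -> list[Any] annotation, and the caller would crash later, far from the cause. Raising at the source is the fail-fast fix. The pattern in isolation (a hypothetical helper, not code from the repo):

def checked_frames(images: list[bytes], limit: int = 20) -> list[bytes]:
    # Raise here rather than returning None and letting the caller hit a
    # confusing TypeError several stack frames later.
    if len(images) > limit:
        raise ValueError(f"Too many screenshots: {len(images)} > {limit}")
    return images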

View File

@@ -17,8 +17,7 @@ from utils import pprint_prompt
 from config import ANTHROPIC_API_KEY
 from video.utils import extract_tag_content, assemble_claude_prompt_video
 from llm import (
-    MODEL_CLAUDE_OPUS,
-    # MODEL_CLAUDE_SONNET,
+    Llm,
     stream_claude_response_native,
 )
@@ -87,7 +86,7 @@ async def main():
         messages=prompt_messages,
         api_key=ANTHROPIC_API_KEY,
         callback=lambda x: process_chunk(x),
-        model=MODEL_CLAUDE_OPUS,
+        model=Llm.CLAUDE_3_OPUS,
         include_thinking=True,
     )