Identify the exact LLM being used during generation
This commit is contained in:
parent
4e30b207c1
commit
81c4fbe28d
@ -1,3 +1,4 @@
|
|||||||
|
from enum import Enum
|
||||||
from typing import Any, Awaitable, Callable, List, cast
|
from typing import Any, Awaitable, Callable, List, cast
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
@ -5,13 +6,18 @@ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
|||||||
|
|
||||||
from utils import pprint_prompt
|
from utils import pprint_prompt
|
||||||
|
|
||||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
|
||||||
MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
|
# Actual model versions that are passed to the LLMs and stored in our logs
|
||||||
MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"
|
class Llm(Enum):
|
||||||
MODEL_CLAUDE_HAIKU = "claude-3-haiku-20240307"
|
GPT_4_VISION = "gpt-4-vision-preview"
|
||||||
|
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
|
||||||
|
CLAUDE_3_OPUS = "claude-3-opus-20240229"
|
||||||
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
|
|
||||||
|
|
||||||
# Keep in sync with frontend (lib/models.ts)
|
# Keep in sync with frontend (lib/models.ts)
|
||||||
|
# User-facing names for the models (for example, in the future, gpt_4_vision might
|
||||||
|
# be backed by a different model version)
|
||||||
CODE_GENERATION_MODELS = [
|
CODE_GENERATION_MODELS = [
|
||||||
"gpt_4_vision",
|
"gpt_4_vision",
|
||||||
"claude_3_sonnet",
|
"claude_3_sonnet",
|
||||||
@ -26,13 +32,13 @@ async def stream_openai_response(
|
|||||||
) -> str:
|
) -> str:
|
||||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
model = MODEL_GPT_4_VISION
|
model = Llm.GPT_4_VISION
|
||||||
|
|
||||||
# Base parameters
|
# Base parameters
|
||||||
params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
|
params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
|
||||||
|
|
||||||
# Add 'max_tokens' only if the model is a GPT4 vision model
|
# Add 'max_tokens' only if the model is a GPT4 vision model
|
||||||
if model == MODEL_GPT_4_VISION:
|
if model == Llm.GPT_4_VISION:
|
||||||
params["max_tokens"] = 4096
|
params["max_tokens"] = 4096
|
||||||
params["temperature"] = 0
|
params["temperature"] = 0
|
||||||
|
|
||||||
@ -59,12 +65,12 @@ async def stream_claude_response(
|
|||||||
client = AsyncAnthropic(api_key=api_key)
|
client = AsyncAnthropic(api_key=api_key)
|
||||||
|
|
||||||
# Base parameters
|
# Base parameters
|
||||||
model = MODEL_CLAUDE_SONNET
|
model = Llm.CLAUDE_3_SONNET
|
||||||
max_tokens = 4096
|
max_tokens = 4096
|
||||||
temperature = 0.0
|
temperature = 0.0
|
||||||
|
|
||||||
# Translate OpenAI messages to Claude messages
|
# Translate OpenAI messages to Claude messages
|
||||||
system_prompt = cast(str, messages[0]["content"])
|
system_prompt = cast(str, messages[0].get("content"))
|
||||||
claude_messages = [dict(message) for message in messages[1:]]
|
claude_messages = [dict(message) for message in messages[1:]]
|
||||||
for message in claude_messages:
|
for message in claude_messages:
|
||||||
if not isinstance(message["content"], list):
|
if not isinstance(message["content"], list):
|
||||||
@ -91,7 +97,7 @@ async def stream_claude_response(
|
|||||||
|
|
||||||
# Stream Claude response
|
# Stream Claude response
|
||||||
async with client.messages.stream(
|
async with client.messages.stream(
|
||||||
model=model,
|
model=model.value,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
system=system_prompt,
|
system=system_prompt,
|
||||||
@ -111,7 +117,7 @@ async def stream_claude_response_native(
|
|||||||
api_key: str,
|
api_key: str,
|
||||||
callback: Callable[[str], Awaitable[None]],
|
callback: Callable[[str], Awaitable[None]],
|
||||||
include_thinking: bool = False,
|
include_thinking: bool = False,
|
||||||
model: str = MODEL_CLAUDE_OPUS,
|
model: Llm = Llm.CLAUDE_3_OPUS,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
client = AsyncAnthropic(api_key=api_key)
|
client = AsyncAnthropic(api_key=api_key)
|
||||||
@ -140,7 +146,7 @@ async def stream_claude_response_native(
|
|||||||
pprint_prompt(messages_to_send)
|
pprint_prompt(messages_to_send)
|
||||||
|
|
||||||
async with client.messages.stream(
|
async with client.messages.stream(
|
||||||
model=model,
|
model=model.value,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
system=system_prompt,
|
system=system_prompt,
|
||||||
|
|||||||
@ -311,35 +311,35 @@ def test_prompts():
|
|||||||
tailwind_prompt = assemble_prompt(
|
tailwind_prompt = assemble_prompt(
|
||||||
"image_data_url", "html_tailwind", "result_image_data_url"
|
"image_data_url", "html_tailwind", "result_image_data_url"
|
||||||
)
|
)
|
||||||
assert tailwind_prompt[0]["content"] == TAILWIND_SYSTEM_PROMPT
|
assert tailwind_prompt[0].get("content") == TAILWIND_SYSTEM_PROMPT
|
||||||
assert tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
assert tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
||||||
|
|
||||||
react_tailwind_prompt = assemble_prompt(
|
react_tailwind_prompt = assemble_prompt(
|
||||||
"image_data_url", "react_tailwind", "result_image_data_url"
|
"image_data_url", "react_tailwind", "result_image_data_url"
|
||||||
)
|
)
|
||||||
assert react_tailwind_prompt[0]["content"] == REACT_TAILWIND_SYSTEM_PROMPT
|
assert react_tailwind_prompt[0].get("content") == REACT_TAILWIND_SYSTEM_PROMPT
|
||||||
assert react_tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
assert react_tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
||||||
|
|
||||||
bootstrap_prompt = assemble_prompt(
|
bootstrap_prompt = assemble_prompt(
|
||||||
"image_data_url", "bootstrap", "result_image_data_url"
|
"image_data_url", "bootstrap", "result_image_data_url"
|
||||||
)
|
)
|
||||||
assert bootstrap_prompt[0]["content"] == BOOTSTRAP_SYSTEM_PROMPT
|
assert bootstrap_prompt[0].get("content") == BOOTSTRAP_SYSTEM_PROMPT
|
||||||
assert bootstrap_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
assert bootstrap_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
||||||
|
|
||||||
ionic_tailwind = assemble_prompt(
|
ionic_tailwind = assemble_prompt(
|
||||||
"image_data_url", "ionic_tailwind", "result_image_data_url"
|
"image_data_url", "ionic_tailwind", "result_image_data_url"
|
||||||
)
|
)
|
||||||
assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT
|
assert ionic_tailwind[0].get("content") == IONIC_TAILWIND_SYSTEM_PROMPT
|
||||||
assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
||||||
|
|
||||||
vue_tailwind = assemble_prompt(
|
vue_tailwind = assemble_prompt(
|
||||||
"image_data_url", "vue_tailwind", "result_image_data_url"
|
"image_data_url", "vue_tailwind", "result_image_data_url"
|
||||||
)
|
)
|
||||||
assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT
|
assert vue_tailwind[0].get("content") == VUE_TAILWIND_SYSTEM_PROMPT
|
||||||
assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore
|
||||||
|
|
||||||
svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url")
|
svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url")
|
||||||
assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT
|
assert svg_prompt[0].get("content") == SVG_SYSTEM_PROMPT
|
||||||
assert svg_prompt[1]["content"][2]["text"] == SVG_USER_PROMPT # type: ignore
|
assert svg_prompt[1]["content"][2]["text"] == SVG_USER_PROMPT # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
|||||||
from custom_types import InputMode
|
from custom_types import InputMode
|
||||||
from llm import (
|
from llm import (
|
||||||
CODE_GENERATION_MODELS,
|
CODE_GENERATION_MODELS,
|
||||||
MODEL_CLAUDE_OPUS,
|
Llm,
|
||||||
stream_claude_response,
|
stream_claude_response,
|
||||||
stream_claude_response_native,
|
stream_claude_response_native,
|
||||||
stream_openai_response,
|
stream_openai_response,
|
||||||
@ -88,6 +88,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
if code_generation_model not in CODE_GENERATION_MODELS:
|
if code_generation_model not in CODE_GENERATION_MODELS:
|
||||||
await throw_error(f"Invalid model: {code_generation_model}")
|
await throw_error(f"Invalid model: {code_generation_model}")
|
||||||
raise Exception(f"Invalid model: {code_generation_model}")
|
raise Exception(f"Invalid model: {code_generation_model}")
|
||||||
|
exact_llm_version = None
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
||||||
@ -238,9 +239,10 @@ async def stream_code(websocket: WebSocket):
|
|||||||
messages=prompt_messages, # type: ignore
|
messages=prompt_messages, # type: ignore
|
||||||
api_key=ANTHROPIC_API_KEY,
|
api_key=ANTHROPIC_API_KEY,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
model=MODEL_CLAUDE_OPUS,
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
include_thinking=True,
|
include_thinking=True,
|
||||||
)
|
)
|
||||||
|
exact_llm_version = Llm.CLAUDE_3_OPUS
|
||||||
elif code_generation_model == "claude_3_sonnet":
|
elif code_generation_model == "claude_3_sonnet":
|
||||||
if not ANTHROPIC_API_KEY:
|
if not ANTHROPIC_API_KEY:
|
||||||
await throw_error(
|
await throw_error(
|
||||||
@ -253,6 +255,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
api_key=ANTHROPIC_API_KEY,
|
api_key=ANTHROPIC_API_KEY,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
exact_llm_version = Llm.CLAUDE_3_SONNET
|
||||||
else:
|
else:
|
||||||
completion = await stream_openai_response(
|
completion = await stream_openai_response(
|
||||||
prompt_messages, # type: ignore
|
prompt_messages, # type: ignore
|
||||||
@ -260,6 +263,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
exact_llm_version = Llm.GPT_4_VISION
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
print("[GENERATE_CODE] Authentication failed", e)
|
print("[GENERATE_CODE] Authentication failed", e)
|
||||||
error_message = (
|
error_message = (
|
||||||
@ -298,6 +302,8 @@ async def stream_code(websocket: WebSocket):
|
|||||||
if validated_input_mode == "video":
|
if validated_input_mode == "video":
|
||||||
completion = extract_tag_content("html", completion)
|
completion = extract_tag_content("html", completion)
|
||||||
|
|
||||||
|
print("Exact used model for generation: ", exact_llm_version)
|
||||||
|
|
||||||
# Write the messages dict into a log so that we can debug later
|
# Write the messages dict into a log so that we can debug later
|
||||||
write_logs(prompt_messages, completion) # type: ignore
|
write_logs(prompt_messages, completion) # type: ignore
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import mimetypes
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Union, cast
|
from typing import Any, Union, cast
|
||||||
from moviepy.editor import VideoFileClip # type: ignore
|
from moviepy.editor import VideoFileClip # type: ignore
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import math
|
import math
|
||||||
@ -17,7 +17,7 @@ TARGET_NUM_SCREENSHOTS = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def assemble_claude_prompt_video(video_data_url: str):
|
async def assemble_claude_prompt_video(video_data_url: str) -> list[Any]:
|
||||||
images = split_video_into_screenshots(video_data_url)
|
images = split_video_into_screenshots(video_data_url)
|
||||||
|
|
||||||
# Save images to tmp if we're debugging
|
# Save images to tmp if we're debugging
|
||||||
@ -28,7 +28,7 @@ async def assemble_claude_prompt_video(video_data_url: str):
|
|||||||
print(f"Number of frames extracted from video: {len(images)}")
|
print(f"Number of frames extracted from video: {len(images)}")
|
||||||
if len(images) > 20:
|
if len(images) > 20:
|
||||||
print(f"Too many screenshots: {len(images)}")
|
print(f"Too many screenshots: {len(images)}")
|
||||||
return
|
raise ValueError("Too many screenshots extracted from video")
|
||||||
|
|
||||||
# Convert images to the message format for Claude
|
# Convert images to the message format for Claude
|
||||||
content_messages: list[dict[str, Union[dict[str, str], str]]] = []
|
content_messages: list[dict[str, Union[dict[str, str], str]]] = []
|
||||||
|
|||||||
@ -17,8 +17,7 @@ from utils import pprint_prompt
|
|||||||
from config import ANTHROPIC_API_KEY
|
from config import ANTHROPIC_API_KEY
|
||||||
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
from video.utils import extract_tag_content, assemble_claude_prompt_video
|
||||||
from llm import (
|
from llm import (
|
||||||
MODEL_CLAUDE_OPUS,
|
Llm,
|
||||||
# MODEL_CLAUDE_SONNET,
|
|
||||||
stream_claude_response_native,
|
stream_claude_response_native,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -87,7 +86,7 @@ async def main():
|
|||||||
messages=prompt_messages,
|
messages=prompt_messages,
|
||||||
api_key=ANTHROPIC_API_KEY,
|
api_key=ANTHROPIC_API_KEY,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
model=MODEL_CLAUDE_OPUS,
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
include_thinking=True,
|
include_thinking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user