Merge branch 'main' into hosted
commit f6079542f7
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Any, Awaitable, Callable, List, cast
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI
@@ -5,13 +6,18 @@ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
 
 from utils import pprint_prompt
 
-MODEL_GPT_4_VISION = "gpt-4-vision-preview"
-MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
-MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"
-MODEL_CLAUDE_HAIKU = "claude-3-haiku-20240307"
 
+# Actual model versions that are passed to the LLMs and stored in our logs
+class Llm(Enum):
+    GPT_4_VISION = "gpt-4-vision-preview"
+    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+    CLAUDE_3_OPUS = "claude-3-opus-20240229"
+    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
+
+
 # Keep in sync with frontend (lib/models.ts)
 # User-facing names for the models (for example, in the future, gpt_4_vision might
 # be backed by a different model version)
 CODE_GENERATION_MODELS = [
     "gpt_4_vision",
     "claude_3_sonnet",
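A quick illustration (not part of the diff) of how the new `Llm` enum behaves: the member name is the internal identifier, while its value is the exact model version string sent to the provider and stored in logs. This is a minimal sketch using only standard-library `Enum` behavior.

```python
from enum import Enum


class Llm(Enum):
    GPT_4_VISION = "gpt-4-vision-preview"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"


# Members compare by identity, so branching on the enum is unambiguous.
model = Llm.CLAUDE_3_OPUS
assert model == Llm.CLAUDE_3_OPUS
assert model != Llm.CLAUDE_3_SONNET

# .value is the exact version string passed to the provider / written to logs.
assert model.value == "claude-3-opus-20240229"

# A logged version string round-trips back to the enum member.
assert Llm("claude-3-opus-20240229") is Llm.CLAUDE_3_OPUS
```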
@@ -26,13 +32,13 @@ async def stream_openai_response(
 ) -> str:
     client = AsyncOpenAI(api_key=api_key, base_url=base_url)
 
-    model = MODEL_GPT_4_VISION
+    model = Llm.GPT_4_VISION
 
     # Base parameters
     params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
 
     # Add 'max_tokens' only if the model is a GPT4 vision model
-    if model == MODEL_GPT_4_VISION:
+    if model == Llm.GPT_4_VISION:
         params["max_tokens"] = 4096
         params["temperature"] = 0
 
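For reference, a minimal standalone sketch of the streaming call these parameters feed into, using the async OpenAI SDK. The `.value` unwrapping and the placeholder message are assumptions for the sketch, not something this hunk changes.

```python
from openai import AsyncOpenAI

from llm import Llm  # enum defined in the hunk above


async def stream_demo(api_key: str) -> str:
    client = AsyncOpenAI(api_key=api_key)

    # The provider expects the raw model id string, so unwrap the enum here.
    stream = await client.chat.completions.create(
        model=Llm.GPT_4_VISION.value,
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        max_tokens=4096,
        temperature=0,
    )

    full_response = ""
    async for chunk in stream:
        # Each chunk carries an incremental delta; skip empty ones.
        if chunk.choices and chunk.choices[0].delta.content:
            full_response += chunk.choices[0].delta.content
    return full_response
```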
@@ -59,12 +65,12 @@ async def stream_claude_response(
     client = AsyncAnthropic(api_key=api_key)
 
     # Base parameters
-    model = MODEL_CLAUDE_SONNET
+    model = Llm.CLAUDE_3_SONNET
     max_tokens = 4096
     temperature = 0.0
 
     # Translate OpenAI messages to Claude messages
-    system_prompt = cast(str, messages[0]["content"])
+    system_prompt = cast(str, messages[0].get("content"))
     claude_messages = [dict(message) for message in messages[1:]]
     for message in claude_messages:
         if not isinstance(message["content"], list):
@@ -91,7 +97,7 @@ async def stream_claude_response(
 
     # Stream Claude response
     async with client.messages.stream(
-        model=model,
+        model=model.value,
         max_tokens=max_tokens,
         temperature=temperature,
         system=system_prompt,
@@ -111,7 +117,7 @@ async def stream_claude_response_native(
     api_key: str,
     callback: Callable[[str], Awaitable[None]],
     include_thinking: bool = False,
-    model: str = MODEL_CLAUDE_OPUS,
+    model: Llm = Llm.CLAUDE_3_OPUS,
 ) -> str:
 
     client = AsyncAnthropic(api_key=api_key)
@@ -140,7 +146,7 @@ async def stream_claude_response_native(
     pprint_prompt(messages_to_send)
 
     async with client.messages.stream(
-        model=model,
+        model=model.value,
         max_tokens=max_tokens,
         temperature=temperature,
         system=system_prompt,
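The three `stream_claude_*` hunks above swap `model=model` for `model=model.value` because the Anthropic client expects the raw version string, not an enum member. A minimal sketch of that calling pattern, assuming a valid Anthropic API key and the async SDK; the prompt content is a placeholder.

```python
from anthropic import AsyncAnthropic

from llm import Llm  # enum defined earlier in this diff


async def claude_stream_demo(api_key: str) -> str:
    client = AsyncAnthropic(api_key=api_key)
    model = Llm.CLAUDE_3_SONNET

    # Pass the underlying string (model.value), not the enum member itself.
    async with client.messages.stream(
        model=model.value,
        max_tokens=1024,
        temperature=0.0,
        system="You are a terse assistant.",
        messages=[{"role": "user", "content": "Say hello."}],
    ) as stream:
        async for text in stream.text_stream:
            print(text, end="", flush=True)
        response = await stream.get_final_message()

    # The final message's first content block holds the full text.
    return response.content[0].text
```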
@@ -311,35 +311,35 @@ def test_prompts():
     tailwind_prompt = assemble_prompt(
         "image_data_url", "html_tailwind", "result_image_data_url"
     )
-    assert tailwind_prompt[0]["content"] == TAILWIND_SYSTEM_PROMPT
+    assert tailwind_prompt[0].get("content") == TAILWIND_SYSTEM_PROMPT
     assert tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     react_tailwind_prompt = assemble_prompt(
         "image_data_url", "react_tailwind", "result_image_data_url"
     )
-    assert react_tailwind_prompt[0]["content"] == REACT_TAILWIND_SYSTEM_PROMPT
+    assert react_tailwind_prompt[0].get("content") == REACT_TAILWIND_SYSTEM_PROMPT
     assert react_tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     bootstrap_prompt = assemble_prompt(
         "image_data_url", "bootstrap", "result_image_data_url"
     )
-    assert bootstrap_prompt[0]["content"] == BOOTSTRAP_SYSTEM_PROMPT
+    assert bootstrap_prompt[0].get("content") == BOOTSTRAP_SYSTEM_PROMPT
     assert bootstrap_prompt[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     ionic_tailwind = assemble_prompt(
         "image_data_url", "ionic_tailwind", "result_image_data_url"
     )
-    assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT
+    assert ionic_tailwind[0].get("content") == IONIC_TAILWIND_SYSTEM_PROMPT
     assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     vue_tailwind = assemble_prompt(
         "image_data_url", "vue_tailwind", "result_image_data_url"
     )
-    assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT
+    assert vue_tailwind[0].get("content") == VUE_TAILWIND_SYSTEM_PROMPT
     assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT  # type: ignore
 
     svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url")
-    assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT
+    assert svg_prompt[0].get("content") == SVG_SYSTEM_PROMPT
     assert svg_prompt[1]["content"][2]["text"] == SVG_USER_PROMPT  # type: ignore
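For context on the assertion changes above: indexing the first message with `["content"]` is swapped for `.get("content")`, presumably to keep the type checker happy about the message dict. The runtime behavior is identical when the key exists, as this small standalone sketch shows (the message literal is illustrative, not taken from the prompts module).

```python
system_message = {"role": "system", "content": "SYSTEM PROMPT"}

# Indexing and .get() return the same value when the key is present...
assert system_message["content"] == "SYSTEM PROMPT"
assert system_message.get("content") == "SYSTEM PROMPT"

# ...but .get() degrades to None instead of raising KeyError when it is not.
assert system_message.get("missing_key") is None
```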
@@ -6,7 +6,7 @@ from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
 from custom_types import InputMode
 from llm import (
     CODE_GENERATION_MODELS,
-    MODEL_CLAUDE_OPUS,
+    Llm,
     stream_claude_response,
     stream_claude_response_native,
     stream_openai_response,
@@ -88,6 +88,7 @@ async def stream_code(websocket: WebSocket):
     if code_generation_model not in CODE_GENERATION_MODELS:
         await throw_error(f"Invalid model: {code_generation_model}")
         raise Exception(f"Invalid model: {code_generation_model}")
+    exact_llm_version = None
 
     print(
         f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
@@ -271,9 +272,10 @@ async def stream_code(websocket: WebSocket):
                 messages=prompt_messages,  # type: ignore
                 api_key=ANTHROPIC_API_KEY,
                 callback=lambda x: process_chunk(x),
-                model=MODEL_CLAUDE_OPUS,
+                model=Llm.CLAUDE_3_OPUS,
                 include_thinking=True,
             )
+            exact_llm_version = Llm.CLAUDE_3_OPUS
         elif code_generation_model == "claude_3_sonnet":
             if not ANTHROPIC_API_KEY:
                 await throw_error(
@@ -286,6 +288,7 @@ async def stream_code(websocket: WebSocket):
                 api_key=ANTHROPIC_API_KEY,
                 callback=lambda x: process_chunk(x),
             )
+            exact_llm_version = Llm.CLAUDE_3_SONNET
         else:
             completion = await stream_openai_response(
                 prompt_messages,  # type: ignore
@@ -293,6 +296,7 @@ async def stream_code(websocket: WebSocket):
                 base_url=openai_base_url,
                 callback=lambda x: process_chunk(x),
             )
+            exact_llm_version = Llm.GPT_4_VISION
     except openai.AuthenticationError as e:
         print("[GENERATE_CODE] Authentication failed", e)
         error_message = (
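The three hunks above record which exact model served the request by assigning `exact_llm_version` in each branch. A hedged alternative sketch (not what the diff does): since the user-facing names in `CODE_GENERATION_MODELS` are just lower-cased member names, the enum member could be looked up directly. The `to_llm` helper below is hypothetical.

```python
from llm import Llm  # enum from the first file in this diff


def to_llm(code_generation_model: str) -> Llm:
    """Map a user-facing name like 'claude_3_opus' to its Llm member.

    Hypothetical helper; the route in this diff instead assigns
    exact_llm_version explicitly in each branch.
    """
    try:
        return Llm[code_generation_model.upper()]
    except KeyError:
        raise ValueError(f"Invalid model: {code_generation_model}")


assert to_llm("claude_3_opus") is Llm.CLAUDE_3_OPUS
assert to_llm("gpt_4_vision") is Llm.GPT_4_VISION
```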
@@ -331,6 +335,8 @@ async def stream_code(websocket: WebSocket):
     if validated_input_mode == "video":
         completion = extract_tag_content("html", completion)
 
+    print("Exact used model for generation: ", exact_llm_version)
+
     # Write the messages dict into a log so that we can debug later
     write_logs(prompt_messages, completion)  # type: ignore
 
@@ -5,7 +5,7 @@ import mimetypes
 import os
 import tempfile
 import uuid
-from typing import Union, cast
+from typing import Any, Union, cast
 from moviepy.editor import VideoFileClip  # type: ignore
 from PIL import Image
 import math
@@ -17,7 +17,7 @@ TARGET_NUM_SCREENSHOTS = (
 )
 
 
-async def assemble_claude_prompt_video(video_data_url: str):
+async def assemble_claude_prompt_video(video_data_url: str) -> list[Any]:
     images = split_video_into_screenshots(video_data_url)
 
     # Save images to tmp if we're debugging
@@ -28,7 +28,7 @@ async def assemble_claude_prompt_video(video_data_url: str):
     print(f"Number of frames extracted from video: {len(images)}")
     if len(images) > 20:
         print(f"Too many screenshots: {len(images)}")
-        return
+        raise ValueError("Too many screenshots extracted from video")
 
     # Convert images to the message format for Claude
     content_messages: list[dict[str, Union[dict[str, str], str]]] = []
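With the hunk above, `assemble_claude_prompt_video` now raises instead of silently returning `None` when a video yields more than 20 frames, so callers have to handle the failure explicitly. A minimal caller sketch, with the error-reporting helper stubbed out as an assumption (it is not part of this diff).

```python
from typing import Any

from video.utils import assemble_claude_prompt_video  # module changed above


async def report_error_to_client(message: str) -> None:
    # Placeholder for the app's real error channel (websocket, HTTP response, ...).
    print(f"error: {message}")


async def build_video_prompt(video_data_url: str) -> list[Any]:
    try:
        return await assemble_claude_prompt_video(video_data_url)
    except ValueError as exc:
        # e.g. "Too many screenshots extracted from video"
        await report_error_to_client(str(exc))
        raise
```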
@@ -17,8 +17,7 @@ from utils import pprint_prompt
 from config import ANTHROPIC_API_KEY
 from video.utils import extract_tag_content, assemble_claude_prompt_video
 from llm import (
-    MODEL_CLAUDE_OPUS,
-    # MODEL_CLAUDE_SONNET,
+    Llm,
     stream_claude_response_native,
 )
 
@@ -87,7 +86,7 @@ async def main():
         messages=prompt_messages,
         api_key=ANTHROPIC_API_KEY,
         callback=lambda x: process_chunk(x),
-        model=MODEL_CLAUDE_OPUS,
+        model=Llm.CLAUDE_3_OPUS,
         include_thinking=True,
     )
 