Added support for AWS Bedrock Claude3
This commit is contained in:
parent
f9c4dd9c7c
commit
26065b36df
@ -14,6 +14,8 @@ COPY poetry.lock pyproject.toml /app/
|
||||
# Disable the creation of virtual environments
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
RUN poetry add boto3=^1.34.76
|
||||
|
||||
# Install dependencies
|
||||
RUN poetry install
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, List, cast
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||
from config import IS_DEBUG_ENABLED
|
||||
@ -16,6 +16,7 @@ class Llm(Enum):
|
||||
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
|
||||
CLAUDE_3_OPUS = "claude-3-opus-20240229"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
|
||||
|
||||
# Will throw errors if you send a garbage string
|
||||
@ -24,6 +25,8 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
|
||||
return Llm.GPT_4_VISION
|
||||
elif frontend_str == "claude_3_sonnet":
|
||||
return Llm.CLAUDE_3_SONNET
|
||||
elif frontend_str == "aws_claude_3_sonnet":
|
||||
return Llm.AWS_CLAUDE_3_SONNET
|
||||
else:
|
||||
return Llm(frontend_str)
|
||||
|
||||
@ -123,6 +126,72 @@ async def stream_claude_response(
|
||||
return response.content[0].text
|
||||
|
||||
|
||||
# TODO: Have a separate function that translates OpenAI messages to Claude messages
|
||||
async def stream_aws_claude_response(
    messages: List[ChatCompletionMessageParam],
    aws_access_key: str,
    aws_secret_key: str,
    aws_region: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a Claude 3 completion from AWS Bedrock and return the full text.

    Translates OpenAI-style chat messages into the Anthropic Messages format
    (the first message is treated as the system prompt; ``image_url`` content
    parts are converted to base64 image sources), streams the response while
    awaiting ``callback`` with each text chunk, and returns the final
    concatenated completion text.

    Args:
        messages: OpenAI-format chat messages; ``messages[0]`` is assumed to
            be the system message.
        aws_access_key: AWS access key id for Bedrock.
        aws_secret_key: AWS secret access key for Bedrock.
        aws_region: AWS region hosting the Bedrock model.
        callback: Awaited with each streamed text chunk as it arrives.

    Returns:
        The complete response text from the model.
    """
    client = AsyncAnthropicBedrock(
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_region=aws_region,
    )

    # Base parameters
    model = Llm.AWS_CLAUDE_3_SONNET
    max_tokens = 4096
    temperature = 0.0

    # Translate OpenAI messages to Claude messages.
    # Anthropic takes the system prompt as a separate `system` parameter
    # rather than as the first message in the list.
    system_prompt = cast(str, messages[0].get("content"))
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                content["type"] = "image"

                # Extract base64 data and media type from the data URL
                # Example base64 data URL: data:image/png;base64,iVBOR...
                image_data_url = cast(str, content["image_url"]["url"])
                media_type = image_data_url.split(";")[0].split(":")[1]
                base64_data = image_data_url.split(",")[1]

                # Remove the OpenAI-specific parameter
                del content["image_url"]

                content["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                }

    # Stream the Claude response. The try/finally guarantees the client is
    # closed even when the stream, the iteration, or the callback raises —
    # previously the connection leaked on any error.
    try:
        async with client.messages.stream(
            model=model.value,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=claude_messages,  # type: ignore
        ) as stream:
            async for text in stream.text_stream:
                await callback(text)

            # Final aggregated message once the stream has completed
            response = await stream.get_final_message()
    finally:
        # Close the Anthropic client
        await client.close()

    return response.content[0].text
|
||||
|
||||
|
||||
async def stream_claude_response_native(
|
||||
system_prompt: str,
|
||||
messages: list[Any],
|
||||
|
||||
@ -7,6 +7,7 @@ from custom_types import InputMode
|
||||
from llm import (
|
||||
Llm,
|
||||
convert_frontend_str_to_llm,
|
||||
stream_aws_claude_response,
|
||||
stream_claude_response,
|
||||
stream_claude_response_native,
|
||||
stream_openai_response,
|
||||
@ -246,6 +247,21 @@ async def stream_code(websocket: WebSocket):
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
exact_llm_version = code_generation_model
|
||||
elif code_generation_model == Llm.AWS_CLAUDE_3_SONNET:
|
||||
if not os.environ.get("AWS_AK", None):
|
||||
await throw_error(
|
||||
"No AWS Bedrock Anthropic API Access Key found. Please add the environment variable AWS_AK to backend/.env"
|
||||
)
|
||||
raise Exception("No AWS Bedrock Anthropic Access key")
|
||||
|
||||
completion = await stream_aws_claude_response(
|
||||
prompt_messages, # type: ignore
|
||||
aws_access_key=os.environ.get("AWS_AK"),
|
||||
aws_secret_key=os.environ.get("AWS_SK"),
|
||||
aws_region=os.environ.get("AWS_REGION"),
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
exact_llm_version = code_generation_model
|
||||
else:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages, # type: ignore
|
||||
|
||||
@ -3,6 +3,7 @@ export enum CodeGenerationModel {
|
||||
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
|
||||
GPT_4_VISION = "gpt_4_vision",
|
||||
CLAUDE_3_SONNET = "claude_3_sonnet",
|
||||
AWS_CLAUDE_3_SONNET = "aws_claude_3_sonnet",
|
||||
}
|
||||
|
||||
// Will generate a static error if a model in the enum above is not in the descriptions
|
||||
@ -12,4 +13,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
|
||||
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
|
||||
gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
|
||||
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
|
||||
aws_claude_3_sonnet: { name: "Claude 3 Sonnet (AWS Bedrock)", inBeta: false },
|
||||
};
|
||||
|
||||
Loading…
Reference in New Issue
Block a user