Added support for AWS Bedrock Claude3

This commit is contained in:
shaoyun.zhang 2024-05-11 15:15:03 +08:00
parent f9c4dd9c7c
commit 26065b36df
4 changed files with 90 additions and 1 deletions

View File

@ -14,6 +14,8 @@ COPY poetry.lock pyproject.toml /app/
# Disable the creation of virtual environments # Disable the creation of virtual environments
RUN poetry config virtualenvs.create false RUN poetry config virtualenvs.create false
RUN poetry add boto3=^1.34.76
# Install dependencies # Install dependencies
RUN poetry install RUN poetry install

View File

@ -1,6 +1,6 @@
from enum import Enum from enum import Enum
from typing import Any, Awaitable, Callable, List, cast from typing import Any, Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from config import IS_DEBUG_ENABLED from config import IS_DEBUG_ENABLED
@ -16,6 +16,7 @@ class Llm(Enum):
CLAUDE_3_SONNET = "claude-3-sonnet-20240229" CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_OPUS = "claude-3-opus-20240229" CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307" CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
# Will throw errors if you send a garbage string # Will throw errors if you send a garbage string
@ -24,6 +25,8 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
return Llm.GPT_4_VISION return Llm.GPT_4_VISION
elif frontend_str == "claude_3_sonnet": elif frontend_str == "claude_3_sonnet":
return Llm.CLAUDE_3_SONNET return Llm.CLAUDE_3_SONNET
elif frontend_str == "aws_claude_3_sonnet":
return Llm.AWS_CLAUDE_3_SONNET
else: else:
return Llm(frontend_str) return Llm(frontend_str)
@ -123,6 +126,72 @@ async def stream_claude_response(
return response.content[0].text return response.content[0].text
# TODO: Have a separate function that translates OpenAI messages to Claude messages
async def stream_aws_claude_response(
    messages: List[ChatCompletionMessageParam],
    aws_access_key: str,
    aws_secret_key: str,
    aws_region: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a Claude 3 Sonnet completion from AWS Bedrock and return the final text.

    Args:
        messages: OpenAI-style chat messages. messages[0] is assumed to carry the
            system prompt; the remaining entries are the conversation turns.
        aws_access_key: AWS access key id for Bedrock.
        aws_secret_key: AWS secret access key for Bedrock.
        aws_region: AWS region hosting the Bedrock model.
        callback: Awaited with each streamed text chunk as it arrives.

    Returns:
        The full text content of the completed response.
    """
    client = AsyncAnthropicBedrock(
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_region=aws_region,
    )

    # Base parameters
    model = Llm.AWS_CLAUDE_3_SONNET
    max_tokens = 4096
    temperature = 0.0

    # Translate OpenAI messages to Claude messages. Build fresh dicts rather
    # than mutating the caller's message/content objects in place: the shallow
    # dict(message) copy shares the nested content dicts, so rewriting them
    # would corrupt `messages` for any later use by the caller.
    system_prompt = cast(str, messages[0].get("content"))
    claude_messages: list[dict[str, Any]] = []
    for message in messages[1:]:
        new_message = dict(message)
        content = new_message.get("content")
        if isinstance(content, list):
            new_parts: list[dict[str, Any]] = []
            for part in content:
                part = dict(part)
                if part.get("type") == "image_url":
                    # Extract base64 data and media type from data URL
                    # Example base64 data URL: data:image/png;base64,iVBOR...
                    image_data_url = cast(str, part["image_url"]["url"])
                    media_type = image_data_url.split(";")[0].split(":")[1]
                    base64_data = image_data_url.split(",")[1]
                    # Replace the OpenAI image part with Claude's equivalent
                    part = {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": base64_data,
                        },
                    }
                new_parts.append(part)
            new_message["content"] = new_parts
        claude_messages.append(new_message)

    # Stream the Claude response; close the client even if streaming raises
    # (previously it was only closed on the success path, leaking on error).
    try:
        async with client.messages.stream(
            model=model.value,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=claude_messages,  # type: ignore
        ) as stream:
            async for text in stream.text_stream:
                await callback(text)
            # Return final message
            response = await stream.get_final_message()
    finally:
        await client.close()

    return response.content[0].text
async def stream_claude_response_native( async def stream_claude_response_native(
system_prompt: str, system_prompt: str,
messages: list[Any], messages: list[Any],

View File

@ -7,6 +7,7 @@ from custom_types import InputMode
from llm import ( from llm import (
Llm, Llm,
convert_frontend_str_to_llm, convert_frontend_str_to_llm,
stream_aws_claude_response,
stream_claude_response, stream_claude_response,
stream_claude_response_native, stream_claude_response_native,
stream_openai_response, stream_openai_response,
@ -246,6 +247,21 @@ async def stream_code(websocket: WebSocket):
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
) )
exact_llm_version = code_generation_model exact_llm_version = code_generation_model
elif code_generation_model == Llm.AWS_CLAUDE_3_SONNET:
if not os.environ.get("AWS_AK", None):
await throw_error(
"No AWS Bedrock Anthropic API Access Key found. Please add the environment variable AWS_AK to backend/.env"
)
raise Exception("No AWS Bedrock Anthropic Access key")
completion = await stream_aws_claude_response(
prompt_messages, # type: ignore
aws_access_key=os.environ.get("AWS_AK"),
aws_secret_key=os.environ.get("AWS_SK"),
aws_region=os.environ.get("AWS_REGION"),
callback=lambda x: process_chunk(x),
)
exact_llm_version = code_generation_model
else: else:
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, # type: ignore prompt_messages, # type: ignore

View File

@ -3,6 +3,7 @@ export enum CodeGenerationModel {
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09", GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
GPT_4_VISION = "gpt_4_vision", GPT_4_VISION = "gpt_4_vision",
CLAUDE_3_SONNET = "claude_3_sonnet", CLAUDE_3_SONNET = "claude_3_sonnet",
AWS_CLAUDE_3_SONNET = "aws_claude_3_sonnet",
} }
// Will generate a static error if a model in the enum above is not in the descriptions // Will generate a static error if a model in the enum above is not in the descriptions
@ -12,4 +13,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false }, "gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false }, gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false }, claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
aws_claude_3_sonnet: { name: "Claude 3 Sonnet (AWS Bedrock)", inBeta: false },
}; };