commit 5b6bd7f2be by shaoyun, 2024-06-08 23:15:13 +01:00 (committed by GitHub)
4 changed files with 90 additions and 1 deletion

View File

@@ -14,6 +14,8 @@ COPY poetry.lock pyproject.toml /app/
# Disable the creation of virtual environments
RUN poetry config virtualenvs.create false
RUN poetry add boto3=^1.34.76
# Install dependencies
RUN poetry install
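
As a quick sanity check, the snippet below (illustrative only, not part of this commit) can be run inside the built image to confirm that the boto3 pin and the Bedrock-capable Anthropic client are importable after poetry install:

# Illustrative sanity check, not part of the commit: run inside the built image
# to confirm the dependencies installed by the steps above are available.
import boto3                                  # added via the poetry add step above
from anthropic import AsyncAnthropicBedrock   # Bedrock client used by the backend below

print("boto3 version:", boto3.__version__)
print("Bedrock client importable:", AsyncAnthropicBedrock.__name__)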

View File

@@ -1,6 +1,6 @@
from enum import Enum
from typing import Any, Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from config import IS_DEBUG_ENABLED
@@ -17,6 +17,7 @@ class Llm(Enum):
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
# Will throw errors if you send a garbage string
@@ -25,6 +26,8 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
        return Llm.GPT_4_VISION
    elif frontend_str == "claude_3_sonnet":
        return Llm.CLAUDE_3_SONNET
    elif frontend_str == "aws_claude_3_sonnet":
        return Llm.AWS_CLAUDE_3_SONNET
    else:
        return Llm(frontend_str)
@@ -129,6 +132,72 @@ async def stream_claude_response(
    return response.content[0].text
# TODO: Have a separate function that translates OpenAI messages to Claude messages
async def stream_aws_claude_response(
    messages: List[ChatCompletionMessageParam],
    aws_access_key: str,
    aws_secret_key: str,
    aws_region: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    client = AsyncAnthropicBedrock(
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_region=aws_region,
    )

    # Base parameters
    model = Llm.AWS_CLAUDE_3_SONNET
    max_tokens = 4096
    temperature = 0.0

    # Translate OpenAI messages to Claude messages
    system_prompt = cast(str, messages[0].get("content"))
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                content["type"] = "image"

                # Extract base64 data and media type from data URL
                # Example base64 data URL: data:image/png;base64,iVBOR...
                image_data_url = cast(str, content["image_url"]["url"])
                media_type = image_data_url.split(";")[0].split(":")[1]
                base64_data = image_data_url.split(",")[1]

                # Remove OpenAI parameter
                del content["image_url"]

                content["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                }

    # Stream Claude response
    async with client.messages.stream(
        model=model.value,
        max_tokens=max_tokens,
        temperature=temperature,
        system=system_prompt,
        messages=claude_messages,  # type: ignore
    ) as stream:
        async for text in stream.text_stream:
            await callback(text)

    # Return final message
    response = await stream.get_final_message()

    # Close the Anthropic client
    await client.close()

    return response.content[0].text

async def stream_claude_response_native(
    system_prompt: str,
    messages: list[Any],
View File

@@ -7,6 +7,7 @@ from custom_types import InputMode
from llm import (
    Llm,
    convert_frontend_str_to_llm,
    stream_aws_claude_response,
    stream_claude_response,
    stream_claude_response_native,
    stream_openai_response,
@@ -258,6 +259,21 @@ async def stream_code(websocket: WebSocket):
            callback=lambda x: process_chunk(x),
        )
        exact_llm_version = code_generation_model
    elif code_generation_model == Llm.AWS_CLAUDE_3_SONNET:
        if not os.environ.get("AWS_AK", None):
            await throw_error(
                "No AWS Bedrock Anthropic API Access Key found. Please add the environment variable AWS_AK to backend/.env"
            )
            raise Exception("No AWS Bedrock Anthropic Access key")

        completion = await stream_aws_claude_response(
            prompt_messages,  # type: ignore
            aws_access_key=os.environ.get("AWS_AK"),
            aws_secret_key=os.environ.get("AWS_SK"),
            aws_region=os.environ.get("AWS_REGION"),
            callback=lambda x: process_chunk(x),
        )
        exact_llm_version = code_generation_model
    else:
        completion = await stream_openai_response(
            prompt_messages,  # type: ignore
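
Since this branch reads AWS_AK, AWS_SK, and AWS_REGION from the environment but only checks AWS_AK before calling Bedrock, a small startup check along these lines (hypothetical, not part of the commit) can fail fast when any of the three is missing from backend/.env:

# Hypothetical startup check, not part of the commit: verify all Bedrock
# settings read by the route above are present in backend/.env.
import os

REQUIRED_BEDROCK_VARS = ("AWS_AK", "AWS_SK", "AWS_REGION")
missing = [name for name in REQUIRED_BEDROCK_VARS if not os.environ.get(name)]
if missing:
    raise RuntimeError(
        "Missing AWS Bedrock settings in backend/.env: " + ", ".join(missing)
    )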

View File

@@ -5,6 +5,7 @@ export enum CodeGenerationModel {
  GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
  GPT_4_VISION = "gpt_4_vision",
  CLAUDE_3_SONNET = "claude_3_sonnet",
  AWS_CLAUDE_3_SONNET = "aws_claude_3_sonnet",
}
// Will generate a static error if a model in the enum above is not in the descriptions
@@ -15,4 +16,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
  "gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
  gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
  claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
  aws_claude_3_sonnet: { name: "Claude 3 Sonnet (AWS Bedrock)", inBeta: false },
};
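
The new frontend value is the same string the backend converts in convert_frontend_str_to_llm, so the round trip can be checked with a couple of asserts (illustrative, assuming the backend llm module from this commit is importable):

# Illustrative check of the frontend-to-backend mapping added in this commit.
from llm import Llm, convert_frontend_str_to_llm

assert convert_frontend_str_to_llm("aws_claude_3_sonnet") == Llm.AWS_CLAUDE_3_SONNET
assert Llm.AWS_CLAUDE_3_SONNET.value == "anthropic.claude-3-sonnet-20240229-v1:0"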