Merge 8ad36a6e95 into 392b9849a2
This commit is contained in:
commit 5b6bd7f2be
@@ -14,6 +14,8 @@ COPY poetry.lock pyproject.toml /app/
 # Disable the creation of virtual environments
 RUN poetry config virtualenvs.create false
 
+RUN poetry add boto3=^1.34.76
+
 # Install dependencies
 RUN poetry install
 
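Note (not part of the diff): the pinned boto3 package is presumably what the anthropic SDK's Bedrock transport relies on for AWS request signing. A minimal sketch of constructing the same client the new backend code uses, pulling credentials straight from the environment (the AWS_AK/AWS_SK/AWS_REGION names follow the rest of this change):

import os

from anthropic import AsyncAnthropicBedrock

# Same constructor arguments that stream_aws_claude_response passes below;
# here the values come directly from environment variables.
client = AsyncAnthropicBedrock(
    aws_access_key=os.environ.get("AWS_AK"),
    aws_secret_key=os.environ.get("AWS_SK"),
    aws_region=os.environ.get("AWS_REGION"),
)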
@@ -1,6 +1,6 @@
 from enum import Enum
 from typing import Any, Awaitable, Callable, List, cast
-from anthropic import AsyncAnthropic
+from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
 from config import IS_DEBUG_ENABLED
@@ -17,6 +17,7 @@ class Llm(Enum):
     CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
     CLAUDE_3_OPUS = "claude-3-opus-20240229"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
+    AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
 
 
 # Will throw errors if you send a garbage string
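Note (illustrative, not part of the diff): Bedrock addresses Claude through a provider-prefixed model ID rather than the native API model name, which is why the new enum member carries the "anthropic.claude-3-sonnet-20240229-v1:0" value. A quick sketch of how the value round-trips through the enum, assuming the Llm class above is in scope:

# Bedrock model IDs are looked up by value just like the native ones.
assert Llm("anthropic.claude-3-sonnet-20240229-v1:0") is Llm.AWS_CLAUDE_3_SONNET
assert Llm.AWS_CLAUDE_3_SONNET.value.startswith("anthropic.")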
@@ -25,6 +26,8 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
         return Llm.GPT_4_VISION
     elif frontend_str == "claude_3_sonnet":
         return Llm.CLAUDE_3_SONNET
+    elif frontend_str == "aws_claude_3_sonnet":
+        return Llm.AWS_CLAUDE_3_SONNET
     else:
         return Llm(frontend_str)
 
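Note (illustrative, not part of the diff): with the new branch in place, the frontend key "aws_claude_3_sonnet" maps to the Bedrock enum member, while strings that are already valid enum values still fall through to Llm(frontend_str). A minimal sanity check, assuming the function and enum above are in scope:

assert convert_frontend_str_to_llm("aws_claude_3_sonnet") is Llm.AWS_CLAUDE_3_SONNET
assert convert_frontend_str_to_llm("claude-3-opus-20240229") is Llm.CLAUDE_3_OPUS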
@@ -129,6 +132,72 @@ async def stream_claude_response(
     return response.content[0].text
 
 
+# TODO: Have a seperate function that translates OpenAI messages to Claude messages
+async def stream_aws_claude_response(
+    messages: List[ChatCompletionMessageParam],
+    aws_access_key: str,
+    aws_secret_key: str,
+    aws_region: str,
+    callback: Callable[[str], Awaitable[None]],
+) -> str:
+
+    client = AsyncAnthropicBedrock(
+        aws_access_key=aws_access_key,
+        aws_secret_key=aws_secret_key,
+        aws_region=aws_region,
+    )
+
+    # Base parameters
+    model = Llm.AWS_CLAUDE_3_SONNET
+    max_tokens = 4096
+    temperature = 0.0
+
+    # Translate OpenAI messages to Claude messages
+    system_prompt = cast(str, messages[0].get("content"))
+    claude_messages = [dict(message) for message in messages[1:]]
+    for message in claude_messages:
+        if not isinstance(message["content"], list):
+            continue
+
+        for content in message["content"]:  # type: ignore
+            if content["type"] == "image_url":
+                content["type"] = "image"
+
+                # Extract base64 data and media type from data URL
+                # Example base64 data URL: data:image/png;base64,iVBOR...
+                image_data_url = cast(str, content["image_url"]["url"])
+                media_type = image_data_url.split(";")[0].split(":")[1]
+                base64_data = image_data_url.split(",")[1]
+
+                # Remove OpenAI parameter
+                del content["image_url"]
+
+                content["source"] = {
+                    "type": "base64",
+                    "media_type": media_type,
+                    "data": base64_data,
+                }
+
+    # Stream Claude response
+    async with client.messages.stream(
+        model=model.value,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        system=system_prompt,
+        messages=claude_messages,  # type: ignore
+    ) as stream:
+        async for text in stream.text_stream:
+            await callback(text)
+
+    # Return final message
+    response = await stream.get_final_message()
+
+    # Close the Anthropic client
+    await client.close()
+
+    return response.content[0].text
+
+
 async def stream_claude_response_native(
     system_prompt: str,
     messages: list[Any],
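Note (not part of the diff): the translation loop above rewrites OpenAI-style image_url parts into Anthropic base64 source blocks by splitting the data URL. A worked example of that split with a made-up payload; note it only handles data: URLs, not remote image URLs:

# Example data URL (value is made up for illustration).
image_data_url = "data:image/png;base64,iVBORw0KGgo="

media_type = image_data_url.split(";")[0].split(":")[1]  # "image/png"
base64_data = image_data_url.split(",")[1]  # "iVBORw0KGgo="

claude_image_content = {
    "type": "image",
    "source": {"type": "base64", "media_type": media_type, "data": base64_data},
}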
@ -7,6 +7,7 @@ from custom_types import InputMode
|
|||||||
from llm import (
|
from llm import (
|
||||||
Llm,
|
Llm,
|
||||||
convert_frontend_str_to_llm,
|
convert_frontend_str_to_llm,
|
||||||
|
stream_aws_claude_response,
|
||||||
stream_claude_response,
|
stream_claude_response,
|
||||||
stream_claude_response_native,
|
stream_claude_response_native,
|
||||||
stream_openai_response,
|
stream_openai_response,
|
||||||
@@ -258,6 +259,21 @@ async def stream_code(websocket: WebSocket):
                 callback=lambda x: process_chunk(x),
             )
             exact_llm_version = code_generation_model
+        elif code_generation_model == Llm.AWS_CLAUDE_3_SONNET:
+            if not os.environ.get("AWS_AK", None):
+                await throw_error(
+                    "No AWS Bedrock Anthropic API Access Key found. Please add the environment variable AWS_AK to backend/.env"
+                )
+                raise Exception("No AWS Bedrock Anthropic Access key")
+
+            completion = await stream_aws_claude_response(
+                prompt_messages,  # type: ignore
+                aws_access_key=os.environ.get("AWS_AK"),
+                aws_secret_key=os.environ.get("AWS_SK"),
+                aws_region=os.environ.get("AWS_REGION"),
+                callback=lambda x: process_chunk(x),
+            )
+            exact_llm_version = code_generation_model
         else:
             completion = await stream_openai_response(
                 prompt_messages,  # type: ignore
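Note (sketch only, not part of the diff): the route checks only AWS_AK before streaming; AWS_SK and AWS_REGION are read unchecked. A hedged alternative that validates all three Bedrock settings up front, using the same variable names the diff introduces:

import os

REQUIRED_BEDROCK_VARS = ("AWS_AK", "AWS_SK", "AWS_REGION")

# Collect any unset or empty variables and fail with a single message.
missing = [name for name in REQUIRED_BEDROCK_VARS if not os.environ.get(name)]
if missing:
    raise Exception(
        "Missing AWS Bedrock settings in backend/.env: " + ", ".join(missing)
    )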
@@ -5,6 +5,7 @@ export enum CodeGenerationModel {
   GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
   GPT_4_VISION = "gpt_4_vision",
   CLAUDE_3_SONNET = "claude_3_sonnet",
+  AWS_CLAUDE_3_SONNET = "aws_claude_3_sonnet",
 }
 
 // Will generate a static error if a model in the enum above is not in the descriptions
@@ -15,4 +16,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
   "gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
   gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
   claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
+  aws_claude_3_sonnet: { name: "Claude 3 Sonnet (AWS Bedrock)", inBeta: false },
 };