AWS Bedrock support - adjust code style

This commit is contained in:
yoreland 2024-06-09 00:01:32 +08:00
parent 7032342f91
commit afd34dfaad
4 changed files with 179 additions and 184 deletions

View File

@@ -4,8 +4,11 @@
 import os
 ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
+# AWS
 AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
 AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
+AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME", "us-west-2")
 # Debugging-related
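
A quick way to sanity-check these settings outside the app is to list the Anthropic models visible in the configured region. This is an illustrative sketch, not part of the commit; it assumes boto3 is installed and the credentials are valid, and it uses the control-plane "bedrock" client rather than the "bedrock-runtime" client the app uses for inference:

import os
import boto3

# Read the same environment variables the config module reads
access_key = os.environ.get("AWS_ACCESS_KEY")
secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
region_name = os.environ.get("AWS_REGION_NAME", "us-west-2")

bedrock = boto3.client(
    service_name="bedrock",
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_access_key,
    region_name=region_name,
)

# Fails fast if the credentials or region are wrong
models = bedrock.list_foundation_models(byProvider="Anthropic")
print([m["modelId"] for m in models["modelSummaries"]])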

View File

@@ -1,4 +1,6 @@
 import os
+from config import AWS_REGION_NAME
 from config import AWS_ACCESS_KEY
 from config import AWS_SECRET_ACCESS_KEY
 from config import ANTHROPIC_API_KEY
@@ -14,6 +16,7 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
     anthropic_api_key = ANTHROPIC_API_KEY
     aws_access_key = AWS_ACCESS_KEY
     aws_secret_access_key = AWS_SECRET_ACCESS_KEY
+    aws_region_name = AWS_REGION_NAME
     openai_base_url = None

     async def process_chunk(content: str):
@@ -34,6 +37,7 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
             prompt_messages,
             access_key=aws_access_key,
             secret_access_key=aws_secret_access_key,
+            aws_region_name=aws_region_name,
             callback=lambda x: process_chunk(x),
         )
     else:
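
For orientation, this is how the eval entry point might be driven against Bedrock after this change. A hypothetical call, grounded only in the generate_code_core signature shown in the hunk headers above; the module's import path is not visible in this diff, and the stack value is assumed for illustration:

import asyncio

from llm import Llm
# generate_code_core comes from the eval module patched above;
# its import path is not shown in this diff.

async def main() -> None:
    completion = await generate_code_core(
        image_url="data:image/png;base64,iVBOR...",  # truncated placeholder screenshot
        stack="html_tailwind",  # assumed Stack value, for illustration only
        model=Llm.CLAUDE_3_SONNET_BEDROCK,  # routes to the Bedrock streaming path
    )
    print(completion)

asyncio.run(main())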

View File

@@ -17,6 +17,7 @@ class Llm(Enum):
     GPT_4_VISION = "gpt-4-vision-preview"
     GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
     CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+    CLAUDE_3_SONNET_BEDROCK = "anthropic.claude-3-sonnet-20240229-v1:0"
     CLAUDE_3_OPUS = "claude-3-opus-20240229"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
@@ -125,61 +126,42 @@
     return response.content[0].text

-async def stream_claude_response_aws_bedrock(
-    messages: List[ChatCompletionMessageParam],
-    access_key: str,
-    secret_access_key: str,
-    callback: Callable[[str], Awaitable[None]],
-) -> str:
+def initialize_bedrock_client(access_key: str, secret_access_key: str, aws_region_name: str):
     try:
         # Initialize the Bedrock Runtime client
         bedrock_runtime = boto3.client(
             service_name='bedrock-runtime',
             aws_access_key_id=access_key,
             aws_secret_access_key=secret_access_key,
+            region_name=aws_region_name,
         )
+        return bedrock_runtime
+    except ClientError as err:
+        message = err.response["Error"]["Message"]
+        print(f"A client error occurred: {message}")
+    except Exception as err:
+        print("An error occurred!")
+        raise err

-        # Set model parameters
-        model_id = 'anthropic.claude-3-sonnet-20240229-v1:0'
-        max_tokens = 4096
-        content_type = 'application/json'
-        accept = '*/*'
-
-        # Translate OpenAI messages to Claude messages
-        system_prompt = cast(str, messages[0].get("content"))
-        claude_messages = [dict(message) for message in messages[1:]]
-        for message in claude_messages:
-            if not isinstance(message["content"], list):
-                continue
-            for content in message["content"]:  # type: ignore
-                if content["type"] == "image_url":
-                    content["type"] = "image"
-                    # Extract base64 data and media type from data URL
-                    # Example base64 data URL: data:image/png;base64,iVBOR...
-                    image_data_url = cast(str, content["image_url"]["url"])
-                    media_type = image_data_url.split(";")[0].split(":")[1]
-                    base64_data = image_data_url.split(",")[1]
-                    # Remove OpenAI parameter
-                    del content["image_url"]
-                    content["source"] = {
-                        "type": "base64",
-                        "media_type": media_type,
-                        "data": base64_data,
-                    }
+async def stream_bedrock_response(
+    bedrock_runtime,
+    messages: List[dict],
+    system_prompt: str,
+    model_id: str,
+    max_tokens: int,
+    content_type: str,
+    accept: str,
+    temperature: float,
+    callback: Callable[[str], Awaitable[None]],
+) -> str:
+    try:
         # Prepare the request body
         body = json.dumps({
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": max_tokens,
-            "messages": claude_messages,
+            "messages": messages,
             "system": system_prompt,
-            "temperature": 0.0
+            "temperature": temperature
         })

         # Invoke the Bedrock Runtime API with response stream
@@ -213,6 +195,60 @@
         print("An error occurred!")
         raise err

+async def stream_claude_response_aws_bedrock(
+    messages: List[dict],
+    access_key: str,
+    secret_access_key: str,
+    aws_region_name: str,
+    callback: Callable[[str], Awaitable[None]],
+) -> str:
+    bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)
+
+    # Set model parameters
+    model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
+    max_tokens = 4096
+    content_type = 'application/json'
+    accept = '*/*'
+    temperature = 0.0
+
+    # Translate OpenAI messages to Claude messages
+    system_prompt = cast(str, messages[0].get("content"))
+    claude_messages = [dict(message) for message in messages[1:]]
+    for message in claude_messages:
+        if not isinstance(message["content"], list):
+            continue
+        for content in message["content"]:  # type: ignore
+            if content["type"] == "image_url":
+                content["type"] = "image"
+                # Extract base64 data and media type from data URL
+                # Example base64 data URL: data:image/png;base64,iVBOR...
+                image_data_url = cast(str, content["image_url"]["url"])
+                media_type = image_data_url.split(";")[0].split(":")[1]
+                base64_data = image_data_url.split(",")[1]
+                # Remove OpenAI parameter
+                del content["image_url"]
+                content["source"] = {
+                    "type": "base64",
+                    "media_type": media_type,
+                    "data": base64_data,
+                }
+
+    return await stream_bedrock_response(
+        bedrock_runtime,
+        claude_messages,
+        system_prompt,
+        model_id,
+        max_tokens,
+        content_type,
+        accept,
+        temperature,
+        callback,
+    )

 async def stream_claude_response_native(
     system_prompt: str,
     messages: list[Any],
@@ -306,22 +342,15 @@ async def stream_claude_response_native_aws_bedrock(
     messages: list[Any],
     access_key: str,
     secret_access_key: str,
+    aws_region_name: str,
     callback: Callable[[str], Awaitable[None]],
     include_thinking: bool = False,
-    model: Llm = Llm.CLAUDE_3_OPUS,
+    model: Llm = Llm.CLAUDE_3_SONNET_BEDROCK,
 ) -> str:
-    try:
-        # Initialize the Bedrock Runtime client
-        bedrock_runtime = boto3.client(
-            service_name='bedrock-runtime',
-            aws_access_key_id=access_key,
-            aws_secret_access_key=secret_access_key,
-        )
+    bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)

     # Set model parameters
-    # model_id = model.value
-    model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
+    model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
     max_tokens = 4096
     content_type = 'application/json'
     accept = '*/*'
@@ -334,10 +363,6 @@ async def stream_claude_response_native_aws_bedrock(
     prefix = "<thinking>"
     response = None

-    # For debugging
-    full_stream = ""
-    debug_file_writer = DebugFileWriter()
-
     while current_pass_num <= max_passes:
         current_pass_num += 1
@@ -348,41 +373,17 @@
             else messages
         )

-        pprint_prompt(messages_to_send)
-
-        # Prepare the request body
-        body = json.dumps({
-            "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": max_tokens,
-            "messages": messages_to_send,
-            "system": system_prompt,
-            "temperature": temperature
-        })
-
-        # Invoke the Bedrock Runtime API with response stream
-        response = bedrock_runtime.invoke_model_with_response_stream(
-            body=body,
-            modelId=model_id,
-            accept=accept,
-            contentType=content_type,
+        response_text = await stream_bedrock_response(
+            bedrock_runtime,
+            messages_to_send,
+            system_prompt,
+            model_id,
+            max_tokens,
+            content_type,
+            accept,
+            temperature,
+            callback,
         )

-        stream = response.get("body")
-        # Stream the response
-        response_text = ""
-        if stream:
-            for event in stream:
-                chunk = event.get("chunk")
-                if chunk:
-                    data = chunk.get("bytes").decode()
-                    chunk_obj = json.loads(data)
-                    if chunk_obj["type"] == "content_block_delta":
-                        text = chunk_obj["delta"]["text"]
-                        response_text += text
-                        print(text, end="", flush=True)
-                        full_stream += text
-                        await callback(text)
-
         # Set up messages array for next pass
         messages += [
@@ -393,18 +394,4 @@
             },
         ]

-        # print(
-        #     f"Token usage: Input Tokens: {response.usage.input_tokens}, Output Tokens: {response.usage.output_tokens}"
-        # )
-
-        if not response:
-            raise Exception("No HTML response found in AI response")
-        else:
-            return full_stream
-    except ClientError as err:
-        message = err.response["Error"]["Message"]
-        print(f"A client error occurred: {message}")
-    except Exception as err:
-        print("An error occurred!")
-        raise err
+    return response_text
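
The OpenAI-to-Claude message translation that moved into the new wrapper is easiest to follow with concrete data. A self-contained sketch of the same logic, using a hypothetical message with a truncated base64 payload:

# One OpenAI-style message with an image attachment (hypothetical data)
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Write the HTML for this screenshot."},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}},
    ],
}

for content in message["content"]:
    if content["type"] == "image_url":
        content["type"] = "image"
        # Split the data URL into media type and payload
        image_data_url = content["image_url"]["url"]
        media_type = image_data_url.split(";")[0].split(":")[1]  # "image/png"
        base64_data = image_data_url.split(",")[1]  # "iVBOR..."
        # Replace the OpenAI field with the Anthropic one
        del content["image_url"]
        content["source"] = {
            "type": "base64",
            "media_type": media_type,
            "data": base64_data,
        }

# The image block is now in the Anthropic/Bedrock shape:
# {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}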

View File

@@ -2,7 +2,7 @@ import os
 import traceback
 from fastapi import APIRouter, WebSocket
 import openai
-from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY
+from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
 from custom_types import InputMode
 from llm import (
     Llm,
@@ -25,7 +25,6 @@ from prompts.types import Stack
 from video.utils import extract_tag_content, assemble_claude_prompt_video
 from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
-
 router = APIRouter()
@@ -239,6 +238,7 @@ async def stream_code(websocket: WebSocket):
             messages=prompt_messages,  # type: ignore
             access_key=AWS_ACCESS_KEY,
             secret_access_key=AWS_SECRET_ACCESS_KEY,
+            aws_region_name=AWS_REGION_NAME,
             callback=lambda x: process_chunk(x),
             model=Llm.CLAUDE_3_OPUS,
             include_thinking=True,
@@ -261,6 +261,7 @@ async def stream_code(websocket: WebSocket):
             prompt_messages,  # type: ignore
             access_key=AWS_ACCESS_KEY,
             secret_access_key=AWS_SECRET_ACCESS_KEY,
+            aws_region_name=AWS_REGION_NAME,
             callback=lambda x: process_chunk(x),
         )
         exact_llm_version = code_generation_model
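
Taken together, the routes now only need to pass AWS_REGION_NAME alongside the existing keys. A hedged usage sketch of the refactored non-native entry point, using only names introduced in this diff, with a print callback standing in for the route's WebSocket writer:

import asyncio

from config import AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
from llm import stream_claude_response_aws_bedrock

async def main() -> None:
    # First message supplies the system prompt; plain-string content
    # is passed through the translation loop untouched.
    prompt_messages = [
        {"role": "system", "content": "You write clean, minimal HTML."},
        {"role": "user", "content": "Say hello."},
    ]

    async def on_chunk(text: str) -> None:
        print(text, end="", flush=True)  # stand-in for the WebSocket send in stream_code

    await stream_claude_response_aws_bedrock(
        prompt_messages,
        access_key=AWS_ACCESS_KEY,
        secret_access_key=AWS_SECRET_ACCESS_KEY,
        aws_region_name=AWS_REGION_NAME,
        callback=on_chunk,
    )

asyncio.run(main())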