AWS Bedrock support

This commit is contained in:
yoreland 2024-04-28 08:57:52 +08:00
parent f9c4dd9c7c
commit 7032342f91
4 changed files with 305 additions and 75 deletions

View File

@ -4,6 +4,8 @@
import os
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
# Debugging-related

View File

@ -1,7 +1,9 @@
import os
from config import AWS_ACCESS_KEY
from config import AWS_SECRET_ACCESS_KEY
from config import ANTHROPIC_API_KEY
from llm import Llm, stream_claude_response, stream_openai_response
from llm import Llm, stream_claude_response, stream_openai_response, stream_claude_response_aws_bedrock
from prompts import assemble_prompt
from prompts.types import Stack
@ -10,20 +12,30 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
anthropic_api_key = ANTHROPIC_API_KEY
aws_access_key = AWS_ACCESS_KEY
aws_secret_access_key = AWS_SECRET_ACCESS_KEY
openai_base_url = None
async def process_chunk(content: str):
pass
if model == Llm.CLAUDE_3_SONNET:
if not anthropic_api_key:
raise Exception("Anthropic API key not found")
if not anthropic_api_key and not aws_access_key and not aws_secret_access_key:
raise Exception("Anthropic API key or AWS Access Key not found")
completion = await stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
)
if anthropic_api_key:
completion = await stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
)
else:
completion = await stream_claude_response_aws_bedrock(
prompt_messages,
access_key=aws_access_key,
secret_access_key=aws_secret_access_key,
callback=lambda x: process_chunk(x),
)
else:
if not openai_api_key:
raise Exception("OpenAI API key not found")

View File

@ -5,7 +5,10 @@ from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from config import IS_DEBUG_ENABLED
from debug.DebugFileWriter import DebugFileWriter
import json
import boto3
from typing import List
from botocore.exceptions import ClientError
from utils import pprint_prompt
@ -29,11 +32,11 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
async def stream_openai_response(
messages: List[ChatCompletionMessageParam],
api_key: str,
base_url: str | None,
callback: Callable[[str], Awaitable[None]],
model: Llm,
messages: List[ChatCompletionMessageParam],
api_key: str,
base_url: str | None,
callback: Callable[[str], Awaitable[None]],
model: Llm,
) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
@ -65,9 +68,9 @@ async def stream_openai_response(
# TODO: Have a seperate function that translates OpenAI messages to Claude messages
async def stream_claude_response(
messages: List[ChatCompletionMessageParam],
api_key: str,
callback: Callable[[str], Awaitable[None]],
messages: List[ChatCompletionMessageParam],
api_key: str,
callback: Callable[[str], Awaitable[None]],
) -> str:
client = AsyncAnthropic(api_key=api_key)
@ -105,11 +108,11 @@ async def stream_claude_response(
# Stream Claude response
async with client.messages.stream(
model=model.value,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt,
messages=claude_messages, # type: ignore
model=model.value,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt,
messages=claude_messages, # type: ignore
) as stream:
async for text in stream.text_stream:
await callback(text)
@ -123,13 +126,100 @@ async def stream_claude_response(
return response.content[0].text
async def stream_claude_response_aws_bedrock(
messages: List[ChatCompletionMessageParam],
access_key: str,
secret_access_key: str,
callback: Callable[[str], Awaitable[None]],
) -> str:
try:
# Initialize the Bedrock Runtime client
bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
)
# Set model parameters
model_id = 'anthropic.claude-3-sonnet-20240229-v1:0'
max_tokens = 4096
content_type = 'application/json'
accept = '*/*'
# Translate OpenAI messages to Claude messages
system_prompt = cast(str, messages[0].get("content"))
claude_messages = [dict(message) for message in messages[1:]]
for message in claude_messages:
if not isinstance(message["content"], list):
continue
for content in message["content"]: # type: ignore
if content["type"] == "image_url":
content["type"] = "image"
# Extract base64 data and media type from data URL
# Example base64 data URL: data:image/png;base64,iVBOR...
image_data_url = cast(str, content["image_url"]["url"])
media_type = image_data_url.split(";")[0].split(":")[1]
base64_data = image_data_url.split(",")[1]
# Remove OpenAI parameter
del content["image_url"]
content["source"] = {
"type": "base64",
"media_type": media_type,
"data": base64_data,
}
# Prepare the request body
body = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": max_tokens,
"messages": claude_messages,
"system": system_prompt,
"temperature": 0.0
})
# Invoke the Bedrock Runtime API with response stream
response = bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=accept,
contentType=content_type,
)
stream = response.get("body")
# Stream the response
final_message = ""
if stream:
for event in stream:
chunk = event.get("chunk")
if chunk:
data = chunk.get("bytes").decode()
chunk_obj = json.loads(data)
if chunk_obj["type"] == "content_block_delta":
text = chunk_obj["delta"]["text"]
await callback(text)
final_message += text
return final_message
except ClientError as err:
message = err.response["Error"]["Message"]
print(f"A client error occurred: {message}")
except Exception as err:
print("An error occurred!")
raise err
async def stream_claude_response_native(
system_prompt: str,
messages: list[Any],
api_key: str,
callback: Callable[[str], Awaitable[None]],
include_thinking: bool = False,
model: Llm = Llm.CLAUDE_3_OPUS,
system_prompt: str,
messages: list[Any],
api_key: str,
callback: Callable[[str], Awaitable[None]],
include_thinking: bool = False,
model: Llm = Llm.CLAUDE_3_OPUS,
) -> str:
client = AsyncAnthropic(api_key=api_key)
@ -162,11 +252,11 @@ async def stream_claude_response_native(
pprint_prompt(messages_to_send)
async with client.messages.stream(
model=model.value,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt,
messages=messages_to_send, # type: ignore
model=model.value,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt,
messages=messages_to_send, # type: ignore
) as stream:
async for text in stream.text_stream:
print(text, end="", flush=True)
@ -210,3 +300,111 @@ async def stream_claude_response_native(
raise Exception("No HTML response found in AI response")
else:
return response.content[0].text
async def stream_claude_response_native_aws_bedrock(
system_prompt: str,
messages: list[Any],
access_key: str,
secret_access_key: str,
callback: Callable[[str], Awaitable[None]],
include_thinking: bool = False,
model: Llm = Llm.CLAUDE_3_OPUS,
) -> str:
try:
# Initialize the Bedrock Runtime client
bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
)
# Set model parameters
# model_id = model.value
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
max_tokens = 4096
content_type = 'application/json'
accept = '*/*'
temperature = 0.0
# Multi-pass flow
current_pass_num = 1
max_passes = 2
prefix = "<thinking>"
response = None
# For debugging
full_stream = ""
debug_file_writer = DebugFileWriter()
while current_pass_num <= max_passes:
current_pass_num += 1
# Set up message depending on whether we have a <thinking> prefix
messages_to_send = (
messages + [{"role": "assistant", "content": prefix}]
if include_thinking
else messages
)
pprint_prompt(messages_to_send)
# Prepare the request body
body = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": max_tokens,
"messages": messages_to_send,
"system": system_prompt,
"temperature": temperature
})
# Invoke the Bedrock Runtime API with response stream
response = bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=accept,
contentType=content_type,
)
stream = response.get("body")
# Stream the response
response_text = ""
if stream:
for event in stream:
chunk = event.get("chunk")
if chunk:
data = chunk.get("bytes").decode()
chunk_obj = json.loads(data)
if chunk_obj["type"] == "content_block_delta":
text = chunk_obj["delta"]["text"]
response_text += text
print(text, end="", flush=True)
full_stream += text
await callback(text)
# Set up messages array for next pass
messages += [
{"role": "assistant", "content": str(prefix) + response_text},
{
"role": "user",
"content": "You've done a good job with a first draft. Improve this further based on the original instructions so that the app is fully functional and looks like the original video of the app we're trying to replicate.",
},
]
# print(
# f"Token usage: Input Tokens: {response.usage.input_tokens}, Output Tokens: {response.usage.output_tokens}"
# )
if not response:
raise Exception("No HTML response found in AI response")
else:
return full_stream
except ClientError as err:
message = err.response["Error"]["Message"]
print(f"A client error occurred: {message}")
except Exception as err:
print("An error occurred!")
raise err

View File

@ -2,14 +2,14 @@ import os
import traceback
from fastapi import APIRouter, WebSocket
import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY
from custom_types import InputMode
from llm import (
Llm,
convert_frontend_str_to_llm,
stream_claude_response,
stream_claude_response_native,
stream_openai_response,
stream_openai_response, stream_claude_response_aws_bedrock, stream_claude_response_native_aws_bedrock,
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
@ -55,7 +55,7 @@ async def stream_code(websocket: WebSocket):
print("Incoming websocket connection...")
async def throw_error(
message: str,
message: str,
):
await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
@ -110,8 +110,8 @@ async def stream_code(websocket: WebSocket):
print("Using OpenAI API key from environment variable")
if not openai_api_key and (
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
code_generation_model == Llm.GPT_4_VISION
or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
):
print("OpenAI API key not found")
await throw_error(
@ -218,33 +218,51 @@ async def stream_code(websocket: WebSocket):
else:
try:
if validated_input_mode == "video":
if not ANTHROPIC_API_KEY:
if not ANTHROPIC_API_KEY and not AWS_ACCESS_KEY and not AWS_SECRET_ACCESS_KEY:
await throw_error(
"Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
"Video only works with Anthropic models. Neither Anthropic API key or AWS Access Key found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env"
)
raise Exception("No Anthropic key")
completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
if ANTHROPIC_API_KEY:
completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
else:
completion = await stream_claude_response_native_aws_bedrock(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
access_key=AWS_ACCESS_KEY,
secret_access_key=AWS_SECRET_ACCESS_KEY,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == Llm.CLAUDE_3_SONNET:
if not ANTHROPIC_API_KEY:
if not ANTHROPIC_API_KEY and not AWS_ACCESS_KEY and not AWS_SECRET_ACCESS_KEY:
await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
"No Anthropic API key or AWS Access Key found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env"
)
raise Exception("No Anthropic key")
completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
)
if ANTHROPIC_API_KEY:
completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
)
else:
completion = await stream_claude_response_aws_bedrock(
prompt_messages, # type: ignore
access_key=AWS_ACCESS_KEY,
secret_access_key=AWS_SECRET_ACCESS_KEY,
callback=lambda x: process_chunk(x),
)
exact_llm_version = code_generation_model
else:
completion = await stream_openai_response(
@ -258,35 +276,35 @@ async def stream_code(websocket: WebSocket):
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
"Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard."
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
"Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard."
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
)
return await throw_error(error_message)
except openai.NotFoundError as e:
print("[GENERATE_CODE] Model not found", e)
error_message = (
e.message
+ ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md"
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
e.message
+ ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md"
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
)
return await throw_error(error_message)
except openai.RateLimitError as e:
print("[GENERATE_CODE] Rate limit exceeded", e)
error_message = (
"OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'"
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
"OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'"
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
)
return await throw_error(error_message)