AWS Bedrock support
This commit is contained in:
parent
f9c4dd9c7c
commit
7032342f91
@ -4,6 +4,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
||||||
|
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
|
||||||
|
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
|
||||||
|
|
||||||
# Debugging-related
|
# Debugging-related
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
from config import AWS_ACCESS_KEY
|
||||||
|
from config import AWS_SECRET_ACCESS_KEY
|
||||||
from config import ANTHROPIC_API_KEY
|
from config import ANTHROPIC_API_KEY
|
||||||
|
|
||||||
from llm import Llm, stream_claude_response, stream_openai_response
|
from llm import Llm, stream_claude_response, stream_openai_response, stream_claude_response_aws_bedrock
|
||||||
from prompts import assemble_prompt
|
from prompts import assemble_prompt
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
|
|
||||||
@ -10,20 +12,30 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
|
|||||||
prompt_messages = assemble_prompt(image_url, stack)
|
prompt_messages = assemble_prompt(image_url, stack)
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
anthropic_api_key = ANTHROPIC_API_KEY
|
anthropic_api_key = ANTHROPIC_API_KEY
|
||||||
|
aws_access_key = AWS_ACCESS_KEY
|
||||||
|
aws_secret_access_key = AWS_SECRET_ACCESS_KEY
|
||||||
openai_base_url = None
|
openai_base_url = None
|
||||||
|
|
||||||
async def process_chunk(content: str):
|
async def process_chunk(content: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if model == Llm.CLAUDE_3_SONNET:
|
if model == Llm.CLAUDE_3_SONNET:
|
||||||
if not anthropic_api_key:
|
if not anthropic_api_key and not aws_access_key and not aws_secret_access_key:
|
||||||
raise Exception("Anthropic API key not found")
|
raise Exception("Anthropic API key or AWS Access Key not found")
|
||||||
|
|
||||||
|
if anthropic_api_key:
|
||||||
completion = await stream_claude_response(
|
completion = await stream_claude_response(
|
||||||
prompt_messages,
|
prompt_messages,
|
||||||
api_key=anthropic_api_key,
|
api_key=anthropic_api_key,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
completion = await stream_claude_response_aws_bedrock(
|
||||||
|
prompt_messages,
|
||||||
|
access_key=aws_access_key,
|
||||||
|
secret_access_key=aws_secret_access_key,
|
||||||
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if not openai_api_key:
|
if not openai_api_key:
|
||||||
raise Exception("OpenAI API key not found")
|
raise Exception("OpenAI API key not found")
|
||||||
|
|||||||
200
backend/llm.py
200
backend/llm.py
@ -5,7 +5,10 @@ from openai import AsyncOpenAI
|
|||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||||
from config import IS_DEBUG_ENABLED
|
from config import IS_DEBUG_ENABLED
|
||||||
from debug.DebugFileWriter import DebugFileWriter
|
from debug.DebugFileWriter import DebugFileWriter
|
||||||
|
import json
|
||||||
|
import boto3
|
||||||
|
from typing import List
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
from utils import pprint_prompt
|
from utils import pprint_prompt
|
||||||
|
|
||||||
|
|
||||||
@ -123,6 +126,93 @@ async def stream_claude_response(
|
|||||||
return response.content[0].text
|
return response.content[0].text
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_claude_response_aws_bedrock(
    messages: List[ChatCompletionMessageParam],
    access_key: str,
    secret_access_key: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a Claude 3 Sonnet completion through the AWS Bedrock runtime.

    The first entry of ``messages`` is used as the system prompt; the
    remaining entries are translated from the OpenAI chat format to the
    Claude message format (``image_url`` parts become base64 ``image`` parts).

    Args:
        messages: OpenAI-style chat messages. ``messages[0]`` must carry the
            system prompt in its ``content`` field. The input is not mutated.
        access_key: AWS access key id passed to ``boto3.client``.
        secret_access_key: AWS secret access key passed to ``boto3.client``.
        callback: Awaited once per streamed text delta, in arrival order.

    Returns:
        The concatenation of every streamed text delta.

    Raises:
        botocore.exceptions.ClientError: Logged and re-raised when Bedrock
            rejects the request. It is no longer swallowed, which used to make
            this function return ``None`` despite the ``-> str`` annotation.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    # Function-scope import so the block does not rely on a file-level import.
    import asyncio

    try:
        # Initialize the Bedrock Runtime client. No region is passed here, so
        # boto3 falls back to its own configuration to resolve one.
        bedrock_runtime = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_access_key,
        )

        # Set model parameters
        model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
        max_tokens = 4096
        content_type = "application/json"
        accept = "*/*"

        # Translate OpenAI messages to Claude messages. New dicts/lists are
        # built for everything that changes: dict(message) is only a shallow
        # copy, so editing content parts in place would rewrite the caller's
        # prompt messages.
        system_prompt = cast(str, messages[0].get("content"))
        claude_messages = []
        for message in messages[1:]:
            claude_message = dict(message)
            parts = claude_message.get("content")
            if isinstance(parts, list):
                converted_parts = []
                for part in parts:
                    if part["type"] != "image_url":
                        converted_parts.append(part)
                        continue

                    # Extract base64 data and media type from data URL
                    # Example base64 data URL: data:image/png;base64,iVBOR...
                    image_data_url = cast(str, part["image_url"]["url"])
                    media_type = image_data_url.split(";")[0].split(":")[1]
                    base64_data = image_data_url.split(",")[1]

                    # Drop the OpenAI-only "image_url" key, keep anything else.
                    converted = {
                        key: value
                        for key, value in part.items()
                        if key != "image_url"
                    }
                    converted["type"] = "image"
                    converted["source"] = {
                        "type": "base64",
                        "media_type": media_type,
                        "data": base64_data,
                    }
                    converted_parts.append(converted)
                claude_message["content"] = converted_parts
            claude_messages.append(claude_message)

        # Prepare the request body
        body = json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "messages": claude_messages,
                "system": system_prompt,
                "temperature": 0.0,
            }
        )

        # boto3 is synchronous; run the request and every blocking read of the
        # event stream in a worker thread so the event loop stays responsive.
        response = await asyncio.to_thread(
            bedrock_runtime.invoke_model_with_response_stream,
            body=body,
            modelId=model_id,
            accept=accept,
            contentType=content_type,
        )
        stream = response.get("body")

        # Stream the response
        text_parts = []
        if stream:
            events = iter(stream)
            while True:
                # None is a safe "exhausted" sentinel: real events are dicts.
                event = await asyncio.to_thread(next, events, None)
                if event is None:
                    break

                chunk = event.get("chunk")
                if not chunk:
                    continue

                chunk_obj = json.loads(chunk.get("bytes").decode())
                if chunk_obj.get("type") != "content_block_delta":
                    continue

                delta = chunk_obj.get("delta") or {}
                if "text" not in delta:
                    # Not a text delta; nothing to forward.
                    continue

                text = delta["text"]
                await callback(text)
                text_parts.append(text)

        return "".join(text_parts)

    except ClientError as err:
        message = err.response["Error"]["Message"]
        print(f"A client error occurred: {message}")
        raise
    except Exception:
        print("An error occurred!")
        raise
|
||||||
|
|
||||||
async def stream_claude_response_native(
|
async def stream_claude_response_native(
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
messages: list[Any],
|
messages: list[Any],
|
||||||
@ -210,3 +300,111 @@ async def stream_claude_response_native(
|
|||||||
raise Exception("No HTML response found in AI response")
|
raise Exception("No HTML response found in AI response")
|
||||||
else:
|
else:
|
||||||
return response.content[0].text
|
return response.content[0].text
|
||||||
|
|
||||||
|
async def stream_claude_response_native_aws_bedrock(
    system_prompt: str,
    messages: list[Any],
    access_key: str,
    secret_access_key: str,
    callback: Callable[[str], Awaitable[None]],
    include_thinking: bool = False,
    model: Llm = Llm.CLAUDE_3_OPUS,
) -> str:
    """Run the two-pass Claude flow through the AWS Bedrock runtime.

    Each pass sends the conversation so far, streams the reply to
    ``callback``, then appends the reply plus a fixed "improve this" user
    message to the conversation used for the next pass.

    Args:
        system_prompt: System prompt sent with every pass.
        messages: Claude-format messages for the first pass. The list is
            copied; the caller's list is not modified.
        access_key: AWS access key id passed to ``boto3.client``.
        secret_access_key: AWS secret access key passed to ``boto3.client``.
        callback: Awaited once per streamed text delta, across all passes.
        include_thinking: When True, every request ends with an assistant
            message containing ``"<thinking>"`` as a prefill.
        model: Currently unused; see the NOTE(review) on ``model_id`` below.

    Returns:
        The text streamed during the final pass, matching
        ``stream_claude_response_native``, which returns only its last
        response rather than all passes concatenated.

    Raises:
        botocore.exceptions.ClientError: Logged and re-raised when Bedrock
            rejects the request. It is no longer swallowed, which used to make
            this function return ``None`` despite the ``-> str`` annotation.
        Exception: ``"No HTML response found in AI response"`` if no pass
            ran; any other failure is logged and re-raised unchanged.
    """
    # Function-scope import so the block does not rely on a file-level import.
    import asyncio

    try:
        # Initialize the Bedrock Runtime client. No region is passed here, so
        # boto3 falls back to its own configuration to resolve one.
        bedrock_runtime = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_access_key,
        )

        # Set model parameters
        # NOTE(review): `model` is ignored and Sonnet is hard-coded, although
        # callers pass Llm.CLAUDE_3_OPUS. Whether model.value is a valid
        # Bedrock model id cannot be confirmed from here - TODO confirm and
        # map Llm members to Bedrock ids.
        model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
        max_tokens = 4096
        content_type = "application/json"
        accept = "*/*"
        temperature = 0.0

        # Multi-pass flow
        current_pass_num = 1
        max_passes = 2

        prefix = "<thinking>"

        # Work on a copy so the follow-up turns appended after each pass do
        # not leak into the caller's list.
        conversation = list(messages)
        last_response_text = None

        # For debugging
        # Not used below; kept because the constructor may prepare debug
        # output - TODO confirm before removing.
        debug_file_writer = DebugFileWriter()

        while current_pass_num <= max_passes:
            current_pass_num += 1

            # Set up message depending on whether we have a <thinking> prefix
            messages_to_send = (
                conversation + [{"role": "assistant", "content": prefix}]
                if include_thinking
                else conversation
            )

            pprint_prompt(messages_to_send)

            # Prepare the request body
            body = json.dumps(
                {
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": max_tokens,
                    "messages": messages_to_send,
                    "system": system_prompt,
                    "temperature": temperature,
                }
            )

            # boto3 is synchronous; run the request and every blocking read of
            # the event stream in a worker thread so the event loop stays
            # responsive.
            response = await asyncio.to_thread(
                bedrock_runtime.invoke_model_with_response_stream,
                body=body,
                modelId=model_id,
                accept=accept,
                contentType=content_type,
            )
            stream = response.get("body")

            # Stream the response
            pass_parts = []
            if stream:
                events = iter(stream)
                while True:
                    # None is a safe "exhausted" sentinel: events are dicts.
                    event = await asyncio.to_thread(next, events, None)
                    if event is None:
                        break

                    chunk = event.get("chunk")
                    if not chunk:
                        continue

                    chunk_obj = json.loads(chunk.get("bytes").decode())
                    if chunk_obj.get("type") != "content_block_delta":
                        continue

                    delta = chunk_obj.get("delta") or {}
                    if "text" not in delta:
                        # Not a text delta; nothing to forward.
                        continue

                    text = delta["text"]
                    pass_parts.append(text)
                    print(text, end="", flush=True)
                    await callback(text)

            response_text = "".join(pass_parts)
            last_response_text = response_text

            # Set up messages array for next pass. The prefill only belongs in
            # the recorded assistant turn when it was actually sent.
            assistant_text = (
                prefix + response_text if include_thinking else response_text
            )
            conversation += [
                {"role": "assistant", "content": assistant_text},
                {
                    "role": "user",
                    "content": "You've done a good job with a first draft. Improve this further based on the original instructions so that the app is fully functional and looks like the original video of the app we're trying to replicate.",
                },
            ]

        if last_response_text is None:
            raise Exception("No HTML response found in AI response")
        return last_response_text

    except ClientError as err:
        message = err.response["Error"]["Message"]
        print(f"A client error occurred: {message}")
        raise
    except Exception:
        print("An error occurred!")
        raise
|
||||||
@ -2,14 +2,14 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY
|
||||||
from custom_types import InputMode
|
from custom_types import InputMode
|
||||||
from llm import (
|
from llm import (
|
||||||
Llm,
|
Llm,
|
||||||
convert_frontend_str_to_llm,
|
convert_frontend_str_to_llm,
|
||||||
stream_claude_response,
|
stream_claude_response,
|
||||||
stream_claude_response_native,
|
stream_claude_response_native,
|
||||||
stream_openai_response,
|
stream_openai_response, stream_claude_response_aws_bedrock, stream_claude_response_native_aws_bedrock,
|
||||||
)
|
)
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
@ -218,12 +218,13 @@ async def stream_code(websocket: WebSocket):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if validated_input_mode == "video":
|
if validated_input_mode == "video":
|
||||||
if not ANTHROPIC_API_KEY:
|
if not ANTHROPIC_API_KEY and not AWS_ACCESS_KEY and not AWS_SECRET_ACCESS_KEY:
|
||||||
await throw_error(
|
await throw_error(
|
||||||
"Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
|
"Video only works with Anthropic models. Neither Anthropic API key or AWS Access Key found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env"
|
||||||
)
|
)
|
||||||
raise Exception("No Anthropic key")
|
raise Exception("No Anthropic key")
|
||||||
|
|
||||||
|
if ANTHROPIC_API_KEY:
|
||||||
completion = await stream_claude_response_native(
|
completion = await stream_claude_response_native(
|
||||||
system_prompt=VIDEO_PROMPT,
|
system_prompt=VIDEO_PROMPT,
|
||||||
messages=prompt_messages, # type: ignore
|
messages=prompt_messages, # type: ignore
|
||||||
@ -232,19 +233,36 @@ async def stream_code(websocket: WebSocket):
|
|||||||
model=Llm.CLAUDE_3_OPUS,
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
include_thinking=True,
|
include_thinking=True,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
completion = await stream_claude_response_native_aws_bedrock(
|
||||||
|
system_prompt=VIDEO_PROMPT,
|
||||||
|
messages=prompt_messages, # type: ignore
|
||||||
|
access_key=AWS_ACCESS_KEY,
|
||||||
|
secret_access_key=AWS_SECRET_ACCESS_KEY,
|
||||||
|
callback=lambda x: process_chunk(x),
|
||||||
|
model=Llm.CLAUDE_3_OPUS,
|
||||||
|
include_thinking=True,
|
||||||
|
)
|
||||||
exact_llm_version = Llm.CLAUDE_3_OPUS
|
exact_llm_version = Llm.CLAUDE_3_OPUS
|
||||||
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
elif code_generation_model == Llm.CLAUDE_3_SONNET:
|
||||||
if not ANTHROPIC_API_KEY:
|
if not ANTHROPIC_API_KEY and not AWS_ACCESS_KEY and not AWS_SECRET_ACCESS_KEY:
|
||||||
await throw_error(
|
await throw_error(
|
||||||
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
|
"No Anthropic API key or AWS Access Key found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env"
|
||||||
)
|
)
|
||||||
raise Exception("No Anthropic key")
|
raise Exception("No Anthropic key")
|
||||||
|
if ANTHROPIC_API_KEY:
|
||||||
completion = await stream_claude_response(
|
completion = await stream_claude_response(
|
||||||
prompt_messages, # type: ignore
|
prompt_messages, # type: ignore
|
||||||
api_key=ANTHROPIC_API_KEY,
|
api_key=ANTHROPIC_API_KEY,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
completion = await stream_claude_response_aws_bedrock(
|
||||||
|
prompt_messages, # type: ignore
|
||||||
|
access_key=AWS_ACCESS_KEY,
|
||||||
|
secret_access_key=AWS_SECRET_ACCESS_KEY,
|
||||||
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
exact_llm_version = code_generation_model
|
exact_llm_version = code_generation_model
|
||||||
else:
|
else:
|
||||||
completion = await stream_openai_response(
|
completion = await stream_openai_response(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user