AWS Bedrock support - adjust code style
This commit is contained in: parent 7032342f91 · commit afd34dfaad
@@ -4,8 +4,11 @@
import os

ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)

# AWS
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME", "us-west-2")

# Debugging-related
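
For illustration, a minimal sketch (not part of this commit) of how a caller might gate the Bedrock code path on these settings; bedrock_configured is a hypothetical helper:

    # Hypothetical helper: both AWS keys must be present to use Bedrock;
    # AWS_REGION_NAME already falls back to "us-west-2" above.
    def bedrock_configured() -> bool:
        return bool(AWS_ACCESS_KEY and AWS_SECRET_ACCESS_KEY)
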
@@ -1,4 +1,6 @@
import os

from config import AWS_REGION_NAME
from config import AWS_ACCESS_KEY
from config import AWS_SECRET_ACCESS_KEY
from config import ANTHROPIC_API_KEY
@@ -14,6 +16,7 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
    anthropic_api_key = ANTHROPIC_API_KEY
    aws_access_key = AWS_ACCESS_KEY
    aws_secret_access_key = AWS_SECRET_ACCESS_KEY
    aws_region_name = AWS_REGION_NAME
    openai_base_url = None

    async def process_chunk(content: str):
@@ -34,6 +37,7 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
            prompt_messages,
            access_key=aws_access_key,
            secret_access_key=aws_secret_access_key,
            aws_region_name=aws_region_name,
            callback=lambda x: process_chunk(x),
        )
    else:
backend/llm.py
@@ -17,6 +17,7 @@ class Llm(Enum):
    GPT_4_VISION = "gpt-4-vision-preview"
    GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_SONNET_BEDROCK = "anthropic.claude-3-sonnet-20240229-v1:0"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
@@ -32,11 +33,11 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:


async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    base_url: str | None,
    callback: Callable[[str], Awaitable[None]],
    model: Llm,
) -> str:
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
@@ -68,9 +69,9 @@


# TODO: Have a separate function that translates OpenAI messages to Claude messages
async def stream_claude_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:

    client = AsyncAnthropic(api_key=api_key)
@@ -108,11 +109,11 @@

    # Stream Claude response
    async with client.messages.stream(
        model=model.value,
        max_tokens=max_tokens,
        temperature=temperature,
        system=system_prompt,
        messages=claude_messages,  # type: ignore
    ) as stream:
        async for text in stream.text_stream:
            await callback(text)
@@ -125,61 +126,42 @@

    return response.content[0].text


def initialize_bedrock_client(access_key: str, secret_access_key: str, aws_region_name: str):
    try:
        # Initialize the Bedrock Runtime client
        bedrock_runtime = boto3.client(
            service_name='bedrock-runtime',
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_access_key,
            region_name=aws_region_name,
        )
        return bedrock_runtime
    except ClientError as err:
        message = err.response["Error"]["Message"]
        print(f"A client error occurred: {message}")
    except Exception as err:
        print("An error occurred!")
        raise err
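
Note that initialize_bedrock_client prints and implicitly returns None when a ClientError is caught, so a caller that wants to fail fast might guard like this (a hypothetical sketch, not in the commit):

    # Hypothetical guard: the helper returns None on a caught ClientError.
    runtime = initialize_bedrock_client(AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME)
    if runtime is None:
        raise RuntimeError("Could not initialize the Bedrock Runtime client")
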
async def stream_bedrock_response(
    bedrock_runtime,
    messages: List[dict],
    system_prompt: str,
    model_id: str,
    max_tokens: int,
    content_type: str,
    accept: str,
    temperature: float,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    try:
        # Prepare the request body
        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_tokens,
            "messages": messages,
            "system": system_prompt,
            "temperature": temperature
        })

        # Invoke the Bedrock Runtime API with response stream
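
The hunk is truncated here. Based on the inline streaming code this commit factors out of stream_claude_response_native_aws_bedrock, the body presumably continues along these lines (a sketch, not the verbatim committed code):

        # Invoke the model and decode the Bedrock event stream chunk by chunk.
        response = bedrock_runtime.invoke_model_with_response_stream(
            body=body,
            modelId=model_id,
            accept=accept,
            contentType=content_type,
        )
        stream = response.get("body")

        response_text = ""
        if stream:
            for event in stream:
                chunk = event.get("chunk")
                if chunk:
                    chunk_obj = json.loads(chunk.get("bytes").decode())
                    if chunk_obj["type"] == "content_block_delta":
                        text = chunk_obj["delta"]["text"]
                        response_text += text
                        await callback(text)
        return response_text
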
@@ -213,13 +195,67 @@ async def stream_claude_response_aws_bedrock(
        print("An error occurred!")
        raise err


async def stream_claude_response_aws_bedrock(
    messages: List[dict],
    access_key: str,
    secret_access_key: str,
    aws_region_name: str,
    callback: Callable[[str], Awaitable[None]],
    include_thinking: bool = False,
    model: Llm = Llm.CLAUDE_3_OPUS,
) -> str:
    bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)

    # Set model parameters
    model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
    max_tokens = 4096
    content_type = 'application/json'
    accept = '*/*'
    temperature = 0.0

    # Translate OpenAI messages to Claude messages
    system_prompt = cast(str, messages[0].get("content"))
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                content["type"] = "image"

                # Extract base64 data and media type from data URL
                # Example base64 data URL: data:image/png;base64,iVBOR...
                image_data_url = cast(str, content["image_url"]["url"])
                media_type = image_data_url.split(";")[0].split(":")[1]
                base64_data = image_data_url.split(",")[1]

                # Remove OpenAI parameter
                del content["image_url"]

                content["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                }

    return await stream_bedrock_response(
        bedrock_runtime,
        claude_messages,
        system_prompt,
        model_id,
        max_tokens,
        content_type,
        accept,
        temperature,
        callback,
    )
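
To make the translation concrete: each OpenAI-style image part is rewritten in place into Anthropic's base64 source format (illustrative, truncated values):

    # Before (OpenAI format):
    #   {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
    # After the loop above (Claude format):
    #   {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}
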


async def stream_claude_response_native(
    system_prompt: str,
    messages: list[Any],
    api_key: str,
    callback: Callable[[str], Awaitable[None]],
    include_thinking: bool = False,
    model: Llm = Llm.CLAUDE_3_OPUS,
) -> str:

    client = AsyncAnthropic(api_key=api_key)
@@ -252,11 +288,11 @@

    pprint_prompt(messages_to_send)

    async with client.messages.stream(
        model=model.value,
        max_tokens=max_tokens,
        temperature=temperature,
        system=system_prompt,
        messages=messages_to_send,  # type: ignore
    ) as stream:
        async for text in stream.text_stream:
            print(text, end="", flush=True)
@@ -306,105 +342,56 @@ async def stream_claude_response_native_aws_bedrock(
    messages: list[Any],
    access_key: str,
    secret_access_key: str,
    aws_region_name: str,
    callback: Callable[[str], Awaitable[None]],
    include_thinking: bool = False,
    model: Llm = Llm.CLAUDE_3_SONNET_BEDROCK,
) -> str:
    bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)

    # Set model parameters
    model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
    max_tokens = 4096
    content_type = 'application/json'
    accept = '*/*'
    temperature = 0.0

    # Multi-pass flow
    current_pass_num = 1
    max_passes = 2

    prefix = "<thinking>"
    response = None

    while current_pass_num <= max_passes:
        current_pass_num += 1

        # Set up message depending on whether we have a <thinking> prefix
        messages_to_send = (
            messages + [{"role": "assistant", "content": prefix}]
            if include_thinking
            else messages
        )

        response_text = await stream_bedrock_response(
            bedrock_runtime,
            messages_to_send,
            system_prompt,
            model_id,
            max_tokens,
            content_type,
            accept,
            temperature,
            callback,
        )

        # Set up messages array for next pass
        messages += [
            {"role": "assistant", "content": str(prefix) + response_text},
            {
                "role": "user",
                "content": "You've done a good job with a first draft. Improve this further based on the original instructions so that the app is fully functional and looks like the original video of the app we're trying to replicate.",
            },
        ]
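
In effect the loop makes at most two passes: the first generates a draft, the second asks the model to improve it. Schematically, the messages list entering the second pass looks like this (illustrative placeholder content):

    messages = [
        {"role": "user", "content": "..."},  # original prompt (instructions plus screenshots/frames)
        {"role": "assistant", "content": "<thinking>" + "...first draft..."},
        {"role": "user", "content": "You've done a good job with a first draft. Improve this further..."},
    ]
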

    return response_text
@@ -2,7 +2,7 @@ import os
import traceback
from fastapi import APIRouter, WebSocket
import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
from custom_types import InputMode
from llm import (
    Llm,

@@ -25,7 +25,6 @@ from prompts.types import Stack
from video.utils import extract_tag_content, assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore


router = APIRouter()
@@ -110,8 +109,8 @@ async def stream_code(websocket: WebSocket):
        print("Using OpenAI API key from environment variable")

    if not openai_api_key and (
        code_generation_model == Llm.GPT_4_VISION
        or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
    ):
        print("OpenAI API key not found")
        await throw_error(
@@ -239,6 +238,7 @@ async def stream_code(websocket: WebSocket):
                    messages=prompt_messages,  # type: ignore
                    access_key=AWS_ACCESS_KEY,
                    secret_access_key=AWS_SECRET_ACCESS_KEY,
                    aws_region_name=AWS_REGION_NAME,
                    callback=lambda x: process_chunk(x),
                    model=Llm.CLAUDE_3_OPUS,
                    include_thinking=True,
@@ -261,6 +261,7 @@ async def stream_code(websocket: WebSocket):
                    prompt_messages,  # type: ignore
                    access_key=AWS_ACCESS_KEY,
                    secret_access_key=AWS_SECRET_ACCESS_KEY,
                    aws_region_name=AWS_REGION_NAME,
                    callback=lambda x: process_chunk(x),
                )
            exact_llm_version = code_generation_model
@@ -276,35 +277,35 @@ async def stream_code(websocket: WebSocket):
    except openai.AuthenticationError as e:
        print("[GENERATE_CODE] Authentication failed", e)
        error_message = (
            "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard."
            + (
                " Alternatively, you can purchase code generation credits directly on this website."
                if IS_PROD
                else ""
            )
        )
        return await throw_error(error_message)
    except openai.NotFoundError as e:
        print("[GENERATE_CODE] Model not found", e)
        error_message = (
            e.message
            + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md"
            + (
                " Alternatively, you can purchase code generation credits directly on this website."
                if IS_PROD
                else ""
            )
        )
        return await throw_error(error_message)
    except openai.RateLimitError as e:
        print("[GENERATE_CODE] Rate limit exceeded", e)
        error_message = (
            "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'"
            + (
                " Alternatively, you can purchase code generation credits directly on this website."
                if IS_PROD
                else ""
            )
        )
        return await throw_error(error_message)