From 7032342f919b3ca5ef06412b0d55cbb1431a1bbf Mon Sep 17 00:00:00 2001 From: yoreland Date: Sun, 28 Apr 2024 08:57:52 +0800 Subject: [PATCH] AWS Bedrock support --- backend/config.py | 2 + backend/evals/core.py | 28 ++-- backend/llm.py | 248 ++++++++++++++++++++++++++++---- backend/routes/generate_code.py | 102 +++++++------ 4 files changed, 305 insertions(+), 75 deletions(-) diff --git a/backend/config.py b/backend/config.py index 05592b0..971c830 100644 --- a/backend/config.py +++ b/backend/config.py @@ -4,6 +4,8 @@ import os ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) +AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None) +AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None) # Debugging-related diff --git a/backend/evals/core.py b/backend/evals/core.py index 5e05362..081c3b3 100644 --- a/backend/evals/core.py +++ b/backend/evals/core.py @@ -1,7 +1,9 @@ import os +from config import AWS_ACCESS_KEY +from config import AWS_SECRET_ACCESS_KEY from config import ANTHROPIC_API_KEY -from llm import Llm, stream_claude_response, stream_openai_response +from llm import Llm, stream_claude_response, stream_openai_response, stream_claude_response_aws_bedrock from prompts import assemble_prompt from prompts.types import Stack @@ -10,20 +12,30 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str: prompt_messages = assemble_prompt(image_url, stack) openai_api_key = os.environ.get("OPENAI_API_KEY") anthropic_api_key = ANTHROPIC_API_KEY + aws_access_key = AWS_ACCESS_KEY + aws_secret_access_key = AWS_SECRET_ACCESS_KEY openai_base_url = None async def process_chunk(content: str): pass if model == Llm.CLAUDE_3_SONNET: - if not anthropic_api_key: - raise Exception("Anthropic API key not found") + if not anthropic_api_key and not aws_access_key and not aws_secret_access_key: + raise Exception("Anthropic API key or AWS Access Key not found") - completion = await stream_claude_response( - prompt_messages, - api_key=anthropic_api_key, - callback=lambda x: process_chunk(x), - ) + if anthropic_api_key: + completion = await stream_claude_response( + prompt_messages, + api_key=anthropic_api_key, + callback=lambda x: process_chunk(x), + ) + else: + completion = await stream_claude_response_aws_bedrock( + prompt_messages, + access_key=aws_access_key, + secret_access_key=aws_secret_access_key, + callback=lambda x: process_chunk(x), + ) else: if not openai_api_key: raise Exception("OpenAI API key not found") diff --git a/backend/llm.py b/backend/llm.py index 3d653b2..997d94f 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -5,7 +5,10 @@ from openai import AsyncOpenAI from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk from config import IS_DEBUG_ENABLED from debug.DebugFileWriter import DebugFileWriter - +import json +import boto3 +from typing import List +from botocore.exceptions import ClientError from utils import pprint_prompt @@ -29,11 +32,11 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm: async def stream_openai_response( - messages: List[ChatCompletionMessageParam], - api_key: str, - base_url: str | None, - callback: Callable[[str], Awaitable[None]], - model: Llm, + messages: List[ChatCompletionMessageParam], + api_key: str, + base_url: str | None, + callback: Callable[[str], Awaitable[None]], + model: Llm, ) -> str: client = AsyncOpenAI(api_key=api_key, base_url=base_url) @@ -65,9 +68,9 @@ async def stream_openai_response( # TODO: Have a seperate function that translates OpenAI messages to Claude messages async def stream_claude_response( - messages: List[ChatCompletionMessageParam], - api_key: str, - callback: Callable[[str], Awaitable[None]], + messages: List[ChatCompletionMessageParam], + api_key: str, + callback: Callable[[str], Awaitable[None]], ) -> str: client = AsyncAnthropic(api_key=api_key) @@ -105,11 +108,11 @@ async def stream_claude_response( # Stream Claude response async with client.messages.stream( - model=model.value, - max_tokens=max_tokens, - temperature=temperature, - system=system_prompt, - messages=claude_messages, # type: ignore + model=model.value, + max_tokens=max_tokens, + temperature=temperature, + system=system_prompt, + messages=claude_messages, # type: ignore ) as stream: async for text in stream.text_stream: await callback(text) @@ -123,13 +126,100 @@ async def stream_claude_response( return response.content[0].text +async def stream_claude_response_aws_bedrock( + messages: List[ChatCompletionMessageParam], + access_key: str, + secret_access_key: str, + callback: Callable[[str], Awaitable[None]], +) -> str: + + try: + # Initialize the Bedrock Runtime client + bedrock_runtime = boto3.client( + service_name='bedrock-runtime', + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + # Set model parameters + model_id = 'anthropic.claude-3-sonnet-20240229-v1:0' + max_tokens = 4096 + content_type = 'application/json' + accept = '*/*' + + # Translate OpenAI messages to Claude messages + system_prompt = cast(str, messages[0].get("content")) + claude_messages = [dict(message) for message in messages[1:]] + for message in claude_messages: + if not isinstance(message["content"], list): + continue + + for content in message["content"]: # type: ignore + if content["type"] == "image_url": + content["type"] = "image" + + # Extract base64 data and media type from data URL + # Example base64 data URL: data:image/png;base64,iVBOR... + image_data_url = cast(str, content["image_url"]["url"]) + media_type = image_data_url.split(";")[0].split(":")[1] + base64_data = image_data_url.split(",")[1] + + # Remove OpenAI parameter + del content["image_url"] + + content["source"] = { + "type": "base64", + "media_type": media_type, + "data": base64_data, + } + + # Prepare the request body + body = json.dumps({ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": max_tokens, + "messages": claude_messages, + "system": system_prompt, + "temperature": 0.0 + }) + + # Invoke the Bedrock Runtime API with response stream + response = bedrock_runtime.invoke_model_with_response_stream( + body=body, + modelId=model_id, + accept=accept, + contentType=content_type, + ) + stream = response.get("body") + + # Stream the response + final_message = "" + if stream: + for event in stream: + chunk = event.get("chunk") + if chunk: + data = chunk.get("bytes").decode() + chunk_obj = json.loads(data) + if chunk_obj["type"] == "content_block_delta": + text = chunk_obj["delta"]["text"] + await callback(text) + final_message += text + + return final_message + + except ClientError as err: + message = err.response["Error"]["Message"] + print(f"A client error occurred: {message}") + except Exception as err: + print("An error occurred!") + raise err + async def stream_claude_response_native( - system_prompt: str, - messages: list[Any], - api_key: str, - callback: Callable[[str], Awaitable[None]], - include_thinking: bool = False, - model: Llm = Llm.CLAUDE_3_OPUS, + system_prompt: str, + messages: list[Any], + api_key: str, + callback: Callable[[str], Awaitable[None]], + include_thinking: bool = False, + model: Llm = Llm.CLAUDE_3_OPUS, ) -> str: client = AsyncAnthropic(api_key=api_key) @@ -162,11 +252,11 @@ async def stream_claude_response_native( pprint_prompt(messages_to_send) async with client.messages.stream( - model=model.value, - max_tokens=max_tokens, - temperature=temperature, - system=system_prompt, - messages=messages_to_send, # type: ignore + model=model.value, + max_tokens=max_tokens, + temperature=temperature, + system=system_prompt, + messages=messages_to_send, # type: ignore ) as stream: async for text in stream.text_stream: print(text, end="", flush=True) @@ -210,3 +300,111 @@ async def stream_claude_response_native( raise Exception("No HTML response found in AI response") else: return response.content[0].text + +async def stream_claude_response_native_aws_bedrock( + system_prompt: str, + messages: list[Any], + access_key: str, + secret_access_key: str, + callback: Callable[[str], Awaitable[None]], + include_thinking: bool = False, + model: Llm = Llm.CLAUDE_3_OPUS, +) -> str: + + try: + # Initialize the Bedrock Runtime client + bedrock_runtime = boto3.client( + service_name='bedrock-runtime', + aws_access_key_id=access_key, + aws_secret_access_key=secret_access_key, + ) + + # Set model parameters + # model_id = model.value + model_id = "anthropic.claude-3-sonnet-20240229-v1:0" + max_tokens = 4096 + content_type = 'application/json' + accept = '*/*' + temperature = 0.0 + + # Multi-pass flow + current_pass_num = 1 + max_passes = 2 + + prefix = "" + response = None + + # For debugging + full_stream = "" + debug_file_writer = DebugFileWriter() + + while current_pass_num <= max_passes: + current_pass_num += 1 + + # Set up message depending on whether we have a prefix + messages_to_send = ( + messages + [{"role": "assistant", "content": prefix}] + if include_thinking + else messages + ) + + pprint_prompt(messages_to_send) + + # Prepare the request body + body = json.dumps({ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": max_tokens, + "messages": messages_to_send, + "system": system_prompt, + "temperature": temperature + }) + + # Invoke the Bedrock Runtime API with response stream + response = bedrock_runtime.invoke_model_with_response_stream( + body=body, + modelId=model_id, + accept=accept, + contentType=content_type, + ) + stream = response.get("body") + + # Stream the response + response_text = "" + if stream: + for event in stream: + chunk = event.get("chunk") + if chunk: + data = chunk.get("bytes").decode() + chunk_obj = json.loads(data) + if chunk_obj["type"] == "content_block_delta": + text = chunk_obj["delta"]["text"] + response_text += text + print(text, end="", flush=True) + full_stream += text + await callback(text) + + + # Set up messages array for next pass + messages += [ + {"role": "assistant", "content": str(prefix) + response_text}, + { + "role": "user", + "content": "You've done a good job with a first draft. Improve this further based on the original instructions so that the app is fully functional and looks like the original video of the app we're trying to replicate.", + }, + ] + + # print( + # f"Token usage: Input Tokens: {response.usage.input_tokens}, Output Tokens: {response.usage.output_tokens}" + # ) + + if not response: + raise Exception("No HTML response found in AI response") + else: + return full_stream + + except ClientError as err: + message = err.response["Error"]["Message"] + print(f"A client error occurred: {message}") + except Exception as err: + print("An error occurred!") + raise err \ No newline at end of file diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index fa5c7a5..b7ceb10 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -2,14 +2,14 @@ import os import traceback from fastapi import APIRouter, WebSocket import openai -from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE +from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY from custom_types import InputMode from llm import ( Llm, convert_frontend_str_to_llm, stream_claude_response, stream_claude_response_native, - stream_openai_response, + stream_openai_response, stream_claude_response_aws_bedrock, stream_claude_response_native_aws_bedrock, ) from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion @@ -55,7 +55,7 @@ async def stream_code(websocket: WebSocket): print("Incoming websocket connection...") async def throw_error( - message: str, + message: str, ): await websocket.send_json({"type": "error", "value": message}) await websocket.close(APP_ERROR_WEB_SOCKET_CODE) @@ -110,8 +110,8 @@ async def stream_code(websocket: WebSocket): print("Using OpenAI API key from environment variable") if not openai_api_key and ( - code_generation_model == Llm.GPT_4_VISION - or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 + code_generation_model == Llm.GPT_4_VISION + or code_generation_model == Llm.GPT_4_TURBO_2024_04_09 ): print("OpenAI API key not found") await throw_error( @@ -218,33 +218,51 @@ async def stream_code(websocket: WebSocket): else: try: if validated_input_mode == "video": - if not ANTHROPIC_API_KEY: + if not ANTHROPIC_API_KEY and not AWS_ACCESS_KEY and not AWS_SECRET_ACCESS_KEY: await throw_error( - "Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env" + "Video only works with Anthropic models. Neither Anthropic API key or AWS Access Key found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env" ) raise Exception("No Anthropic key") - completion = await stream_claude_response_native( - system_prompt=VIDEO_PROMPT, - messages=prompt_messages, # type: ignore - api_key=ANTHROPIC_API_KEY, - callback=lambda x: process_chunk(x), - model=Llm.CLAUDE_3_OPUS, - include_thinking=True, - ) + if ANTHROPIC_API_KEY: + completion = await stream_claude_response_native( + system_prompt=VIDEO_PROMPT, + messages=prompt_messages, # type: ignore + api_key=ANTHROPIC_API_KEY, + callback=lambda x: process_chunk(x), + model=Llm.CLAUDE_3_OPUS, + include_thinking=True, + ) + else: + completion = await stream_claude_response_native_aws_bedrock( + system_prompt=VIDEO_PROMPT, + messages=prompt_messages, # type: ignore + access_key=AWS_ACCESS_KEY, + secret_access_key=AWS_SECRET_ACCESS_KEY, + callback=lambda x: process_chunk(x), + model=Llm.CLAUDE_3_OPUS, + include_thinking=True, + ) exact_llm_version = Llm.CLAUDE_3_OPUS elif code_generation_model == Llm.CLAUDE_3_SONNET: - if not ANTHROPIC_API_KEY: + if not ANTHROPIC_API_KEY and not AWS_ACCESS_KEY and not AWS_SECRET_ACCESS_KEY: await throw_error( - "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env" + "No Anthropic API key or AWS Access Key found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env" ) raise Exception("No Anthropic key") - - completion = await stream_claude_response( - prompt_messages, # type: ignore - api_key=ANTHROPIC_API_KEY, - callback=lambda x: process_chunk(x), - ) + if ANTHROPIC_API_KEY: + completion = await stream_claude_response( + prompt_messages, # type: ignore + api_key=ANTHROPIC_API_KEY, + callback=lambda x: process_chunk(x), + ) + else: + completion = await stream_claude_response_aws_bedrock( + prompt_messages, # type: ignore + access_key=AWS_ACCESS_KEY, + secret_access_key=AWS_SECRET_ACCESS_KEY, + callback=lambda x: process_chunk(x), + ) exact_llm_version = code_generation_model else: completion = await stream_openai_response( @@ -258,35 +276,35 @@ async def stream_code(websocket: WebSocket): except openai.AuthenticationError as e: print("[GENERATE_CODE] Authentication failed", e) error_message = ( - "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard." - + ( - " Alternatively, you can purchase code generation credits directly on this website." - if IS_PROD - else "" - ) + "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard." + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) ) return await throw_error(error_message) except openai.NotFoundError as e: print("[GENERATE_CODE] Model not found", e) error_message = ( - e.message - + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md" - + ( - " Alternatively, you can purchase code generation credits directly on this website." - if IS_PROD - else "" - ) + e.message + + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md" + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) ) return await throw_error(error_message) except openai.RateLimitError as e: print("[GENERATE_CODE] Rate limit exceeded", e) error_message = ( - "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'" - + ( - " Alternatively, you can purchase code generation credits directly on this website." - if IS_PROD - else "" - ) + "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'" + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) ) return await throw_error(error_message)