This commit is contained in:
Naman Dhingra 2024-06-06 12:02:13 +08:00 committed by GitHub
commit 3da74be3c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 449 additions and 3 deletions

View File

@ -4,6 +4,7 @@
import os import os
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", None)
# Debugging-related # Debugging-related

View File

@ -5,6 +5,9 @@ from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from config import IS_DEBUG_ENABLED from config import IS_DEBUG_ENABLED
from debug.DebugFileWriter import DebugFileWriter from debug.DebugFileWriter import DebugFileWriter
import google.generativeai as genai
import base64
from utils import pprint_prompt from utils import pprint_prompt
@ -17,6 +20,7 @@ class Llm(Enum):
CLAUDE_3_SONNET = "claude-3-sonnet-20240229" CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_OPUS = "claude-3-opus-20240229" CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307" CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
GEMINI_1_5_FLASH = "gemini-1.5-flash"
# Will throw errors if you send a garbage string # Will throw errors if you send a garbage string
@ -216,3 +220,68 @@ async def stream_claude_response_native(
raise Exception("No HTML response found in AI response") raise Exception("No HTML response found in AI response")
else: else:
return response.content[0].text return response.content[0].text
async def stream_gemini_response(
messages: List[ChatCompletionMessageParam],
api_key: str,
callback: Callable[[str], Awaitable[None]],
) -> str:
genai.configure(api_key=api_key)
model = genai.GenerativeModel("gemini-1.5-flash")
# Translate OpenAI messages to Gemini messages
gemini_messages = []
for message in messages:
if isinstance(message["content"], str):
gemini_messages.append(
{
"role": (
"model" if message["role"] == "system" else message["role"]
),
"parts": [message["content"]],
}
)
elif isinstance(message["content"], list):
for content in message["content"]:
if content["type"] == "text":
gemini_messages.append(
{
"role": (
"model"
if message["role"] == "system"
else message["role"]
),
"parts": [content["text"]],
}
)
elif content["type"] == "image_url":
# Extract base64 data and media type from data URL
# Example base64 data URL: data:image/png;base64,iVBOR...
image_data_url = cast(str, content["image_url"]["url"])
media_type = image_data_url.split(";")[0].split(":")[1]
base64_data = image_data_url.split(",")[1]
content["mime_type"] = media_type
content["data"] = base64.b64decode(base64_data)
# Remove OpenAI parameter
del content["image_url"]
del content["type"]
gemini_messages.append(
{"role": "user", "parts": {"inline": content}}
)
response = model.generate_content(gemini_messages, stream=True)
for chunk in response:
await callback(chunk.text)
if response._result and response._result.candidates:
candidate = response._result.candidates[0]
if candidate.content and candidate.content.parts:
part = candidate.content.parts[0]
if part.text:
return part.text
return ""

349
backend/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "anthropic" name = "anthropic"
@ -67,6 +67,17 @@ charset-normalizer = ["charset-normalizer"]
html5lib = ["html5lib"] html5lib = ["html5lib"]
lxml = ["lxml"] lxml = ["lxml"]
[[package]]
name = "cachetools"
version = "5.3.3"
description = "Extensible memoizing collections and decorators"
optional = false
python-versions = ">=3.7"
files = [
{file = "cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945"},
{file = "cachetools-5.3.3.tar.gz", hash = "sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105"},
]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2024.2.2" version = "2024.2.2"
@ -332,6 +343,224 @@ smb = ["smbprotocol"]
ssh = ["paramiko"] ssh = ["paramiko"]
tqdm = ["tqdm"] tqdm = ["tqdm"]
[[package]]
name = "google-ai-generativelanguage"
version = "0.6.4"
description = "Google Ai Generativelanguage API client library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-ai-generativelanguage-0.6.4.tar.gz", hash = "sha256:1750848c12af96cb24ae1c3dd05e4bfe24867dc4577009ed03e1042d8421e874"},
{file = "google_ai_generativelanguage-0.6.4-py3-none-any.whl", hash = "sha256:730e471aa549797118fb1c88421ba1957741433ada575cf5dd08d3aebf903ab1"},
]
[package.dependencies]
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
[[package]]
name = "google-api-core"
version = "2.19.0"
description = "Google API client core library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-api-core-2.19.0.tar.gz", hash = "sha256:cf1b7c2694047886d2af1128a03ae99e391108a08804f87cfd35970e49c9cd10"},
{file = "google_api_core-2.19.0-py3-none-any.whl", hash = "sha256:8661eec4078c35428fd3f69a2c7ee29e342896b70f01d1a1cbcb334372dd6251"},
]
[package.dependencies]
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
[[package]]
name = "google-api-python-client"
version = "2.130.0"
description = "Google API Client Library for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-api-python-client-2.130.0.tar.gz", hash = "sha256:2bba3122b82a649c677b8a694b8e2bbf2a5fbf3420265caf3343bb88e2e9f0ae"},
{file = "google_api_python_client-2.130.0-py2.py3-none-any.whl", hash = "sha256:7d45a28d738628715944a9c9d73e8696e7e03ac50b7de87f5e3035cefa94ed3a"},
]
[package.dependencies]
google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0"
google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0"
google-auth-httplib2 = ">=0.2.0,<1.0.0"
httplib2 = ">=0.19.0,<1.dev0"
uritemplate = ">=3.0.1,<5"
[[package]]
name = "google-auth"
version = "2.29.0"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-auth-2.29.0.tar.gz", hash = "sha256:672dff332d073227550ffc7457868ac4218d6c500b155fe6cc17d2b13602c360"},
{file = "google_auth-2.29.0-py2.py3-none-any.whl", hash = "sha256:d452ad095688cd52bae0ad6fafe027f6a6d6f560e810fec20914e17a09526415"},
]
[package.dependencies]
cachetools = ">=2.0.0,<6.0"
pyasn1-modules = ">=0.2.1"
rsa = ">=3.1.4,<5"
[package.extras]
aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
reauth = ["pyu2f (>=0.1.5)"]
requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
[[package]]
name = "google-auth-httplib2"
version = "0.2.0"
description = "Google Authentication Library: httplib2 transport"
optional = false
python-versions = "*"
files = [
{file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"},
{file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"},
]
[package.dependencies]
google-auth = "*"
httplib2 = ">=0.19.0"
[[package]]
name = "google-generativeai"
version = "0.5.4"
description = "Google Generative AI High level API client library and tools."
optional = false
python-versions = ">=3.9"
files = [
{file = "google_generativeai-0.5.4-py3-none-any.whl", hash = "sha256:036d63ee35e7c8aedceda4f81c390a5102808af09ff3a6e57e27ed0be0708f3c"},
]
[package.dependencies]
google-ai-generativelanguage = "0.6.4"
google-api-core = "*"
google-api-python-client = "*"
google-auth = ">=2.15.0"
protobuf = "*"
pydantic = "*"
tqdm = "*"
typing-extensions = "*"
[package.extras]
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
[[package]]
name = "googleapis-common-protos"
version = "1.63.0"
description = "Common protobufs used in Google APIs"
optional = false
python-versions = ">=3.7"
files = [
{file = "googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e"},
{file = "googleapis_common_protos-1.63.0-py2.py3-none-any.whl", hash = "sha256:ae45f75702f7c08b541f750854a678bd8f534a1a6bace6afe975f1d0a82d6632"},
]
[package.dependencies]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
[[package]]
name = "grpcio"
version = "1.64.0"
description = "HTTP/2-based RPC framework"
optional = false
python-versions = ">=3.8"
files = [
{file = "grpcio-1.64.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3b09c3d9de95461214a11d82cc0e6a46a6f4e1f91834b50782f932895215e5db"},
{file = "grpcio-1.64.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7e013428ab472892830287dd082b7d129f4d8afef49227a28223a77337555eaa"},
{file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:02cc9cc3f816d30f7993d0d408043b4a7d6a02346d251694d8ab1f78cc723e7e"},
{file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f5de082d936e0208ce8db9095821361dfa97af8767a6607ae71425ac8ace15c"},
{file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b7bf346391dffa182fba42506adf3a84f4a718a05e445b37824136047686a1"},
{file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b2cbdfba18408389a1371f8c2af1659119e1831e5ed24c240cae9e27b4abc38d"},
{file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:aca4f15427d2df592e0c8f3d38847e25135e4092d7f70f02452c0e90d6a02d6d"},
{file = "grpcio-1.64.0-cp310-cp310-win32.whl", hash = "sha256:7c1f5b2298244472bcda49b599be04579f26425af0fd80d3f2eb5fd8bc84d106"},
{file = "grpcio-1.64.0-cp310-cp310-win_amd64.whl", hash = "sha256:73f84f9e5985a532e47880b3924867de16fa1aa513fff9b26106220c253c70c5"},
{file = "grpcio-1.64.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2a18090371d138a57714ee9bffd6c9c9cb2e02ce42c681aac093ae1e7189ed21"},
{file = "grpcio-1.64.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:59c68df3a934a586c3473d15956d23a618b8f05b5e7a3a904d40300e9c69cbf0"},
{file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b52e1ec7185512103dd47d41cf34ea78e7a7361ba460187ddd2416b480e0938c"},
{file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d598b5d5e2c9115d7fb7e2cb5508d14286af506a75950762aa1372d60e41851"},
{file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01615bbcae6875eee8091e6b9414072f4e4b00d8b7e141f89635bdae7cf784e5"},
{file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0b2dfe6dcace264807d9123d483d4c43274e3f8c39f90ff51de538245d7a4145"},
{file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7f17572dc9acd5e6dfd3014d10c0b533e9f79cd9517fc10b0225746f4c24b58e"},
{file = "grpcio-1.64.0-cp311-cp311-win32.whl", hash = "sha256:6ec5ed15b4ffe56e2c6bc76af45e6b591c9be0224b3fb090adfb205c9012367d"},
{file = "grpcio-1.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:597191370951b477b7a1441e1aaa5cacebeb46a3b0bd240ec3bb2f28298c7553"},
{file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"},
{file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"},
{file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"},
{file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"},
{file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"},
{file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"},
{file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"},
{file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"},
{file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"},
{file = "grpcio-1.64.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c46fb6bfca17bfc49f011eb53416e61472fa96caa0979b4329176bdd38cbbf2a"},
{file = "grpcio-1.64.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3d2004e85cf5213995d09408501f82c8534700d2babeb81dfdba2a3bff0bb396"},
{file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6d5541eb460d73a07418524fb64dcfe0adfbcd32e2dac0f8f90ce5b9dd6c046c"},
{file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f279ad72dd7d64412e10f2443f9f34872a938c67387863c4cd2fb837f53e7d2"},
{file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85fda90b81da25993aa47fae66cae747b921f8f6777550895fb62375b776a231"},
{file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a053584079b793a54bece4a7d1d1b5c0645bdbee729215cd433703dc2532f72b"},
{file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:579dd9fb11bc73f0de061cab5f8b2def21480fd99eb3743ed041ad6a1913ee2f"},
{file = "grpcio-1.64.0-cp38-cp38-win32.whl", hash = "sha256:23b6887bb21d77649d022fa1859e05853fdc2e60682fd86c3db652a555a282e0"},
{file = "grpcio-1.64.0-cp38-cp38-win_amd64.whl", hash = "sha256:753cb58683ba0c545306f4e17dabf468d29cb6f6b11832e1e432160bb3f8403c"},
{file = "grpcio-1.64.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:2186d76a7e383e1466e0ea2b0febc343ffeae13928c63c6ec6826533c2d69590"},
{file = "grpcio-1.64.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0f30596cdcbed3c98024fb4f1d91745146385b3f9fd10c9f2270cbfe2ed7ed91"},
{file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:d9171f025a196f5bcfec7e8e7ffb7c3535f7d60aecd3503f9e250296c7cfc150"},
{file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf4c8daed18ae2be2f1fc7d613a76ee2a2e28fdf2412d5c128be23144d28283d"},
{file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3550493ac1d23198d46dc9c9b24b411cef613798dc31160c7138568ec26bc9b4"},
{file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3161a8f8bb38077a6470508c1a7301cd54301c53b8a34bb83e3c9764874ecabd"},
{file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e8fabe2cc57a369638ab1ad8e6043721014fdf9a13baa7c0e35995d3a4a7618"},
{file = "grpcio-1.64.0-cp39-cp39-win32.whl", hash = "sha256:31890b24d47b62cc27da49a462efe3d02f3c120edb0e6c46dcc0025506acf004"},
{file = "grpcio-1.64.0-cp39-cp39-win_amd64.whl", hash = "sha256:5a56797dea8c02e7d3a85dfea879f286175cf4d14fbd9ab3ef2477277b927baa"},
{file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"},
]
[package.extras]
protobuf = ["grpcio-tools (>=1.64.0)"]
[[package]]
name = "grpcio-status"
version = "1.62.2"
description = "Status proto mapping for gRPC"
optional = false
python-versions = ">=3.6"
files = [
{file = "grpcio-status-1.62.2.tar.gz", hash = "sha256:62e1bfcb02025a1cd73732a2d33672d3e9d0df4d21c12c51e0bbcaf09bab742a"},
{file = "grpcio_status-1.62.2-py3-none-any.whl", hash = "sha256:206ddf0eb36bc99b033f03b2c8e95d319f0044defae9b41ae21408e7e0cda48f"},
]
[package.dependencies]
googleapis-common-protos = ">=1.5.5"
grpcio = ">=1.62.2"
protobuf = ">=4.21.6"
[[package]] [[package]]
name = "h11" name = "h11"
version = "0.14.0" version = "0.14.0"
@ -364,6 +593,20 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"] socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.25.0)"] trio = ["trio (>=0.22.0,<0.25.0)"]
[[package]]
name = "httplib2"
version = "0.22.0"
description = "A comprehensive HTTP client library."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"},
{file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"},
]
[package.dependencies]
pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}
[[package]] [[package]]
name = "httpx" name = "httpx"
version = "0.25.2" version = "0.25.2"
@ -771,6 +1014,68 @@ files = [
[package.dependencies] [package.dependencies]
tqdm = "*" tqdm = "*"
[[package]]
name = "proto-plus"
version = "1.23.0"
description = "Beautiful, Pythonic protocol buffers."
optional = false
python-versions = ">=3.6"
files = [
{file = "proto-plus-1.23.0.tar.gz", hash = "sha256:89075171ef11988b3fa157f5dbd8b9cf09d65fffee97e29ce403cd8defba19d2"},
{file = "proto_plus-1.23.0-py3-none-any.whl", hash = "sha256:a829c79e619e1cf632de091013a4173deed13a55f326ef84f05af6f50ff4c82c"},
]
[package.dependencies]
protobuf = ">=3.19.0,<5.0.0dev"
[package.extras]
testing = ["google-api-core[grpc] (>=1.31.5)"]
[[package]]
name = "protobuf"
version = "4.25.3"
description = ""
optional = false
python-versions = ">=3.8"
files = [
{file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
{file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
{file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
{file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
{file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
{file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"},
{file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"},
{file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"},
{file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"},
{file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
{file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
]
[[package]]
name = "pyasn1"
version = "0.6.0"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"},
{file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"},
]
[[package]]
name = "pyasn1-modules"
version = "0.4.0"
description = "A collection of ASN.1-based protocols modules"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1_modules-0.4.0-py3-none-any.whl", hash = "sha256:be04f15b66c206eed667e0bb5ab27e2b1855ea54a842e5037738099e8ca4ae0b"},
{file = "pyasn1_modules-0.4.0.tar.gz", hash = "sha256:831dbcea1b177b28c9baddf4c6d1013c24c3accd14a1873fffaa6a2e905f17b6"},
]
[package.dependencies]
pyasn1 = ">=0.4.6,<0.7.0"
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "1.10.14" version = "1.10.14"
@ -823,6 +1128,20 @@ typing-extensions = ">=4.2.0"
dotenv = ["python-dotenv (>=0.10.4)"] dotenv = ["python-dotenv (>=0.10.4)"]
email = ["email-validator (>=1.0.3)"] email = ["email-validator (>=1.0.3)"]
[[package]]
name = "pyparsing"
version = "3.1.2"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
optional = false
python-versions = ">=3.6.8"
files = [
{file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
{file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
]
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]] [[package]]
name = "pyright" name = "pyright"
version = "1.1.354" version = "1.1.354"
@ -902,6 +1221,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -957,6 +1277,20 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rsa"
version = "4.9"
description = "Pure-Python RSA implementation"
optional = false
python-versions = ">=3.6,<4"
files = [
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
]
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]] [[package]]
name = "setuptools" name = "setuptools"
version = "69.2.0" version = "69.2.0"
@ -1181,6 +1515,17 @@ files = [
{file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
] ]
[[package]]
name = "uritemplate"
version = "4.1.1"
description = "Implementation of RFC 6570 URI Templates"
optional = false
python-versions = ">=3.6"
files = [
{file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"},
{file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"},
]
[[package]] [[package]]
name = "urllib3" name = "urllib3"
version = "2.2.1" version = "2.2.1"
@ -1321,4 +1666,4 @@ files = [
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "274fed55cf4a2f4e9954b3d196c103b72225409c6050759a939f0ca197ae3f79" content-hash = "e8e657a1d7bbded69757bded07b7fe0754935ef615c5703560ff74e3fb4889e2"

View File

@ -17,6 +17,7 @@ httpx = "^0.25.1"
pre-commit = "^3.6.2" pre-commit = "^3.6.2"
anthropic = "^0.18.0" anthropic = "^0.18.0"
moviepy = "^1.0.3" moviepy = "^1.0.3"
google-generativeai = "^0.5.4"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pytest = "^7.4.3" pytest = "^7.4.3"

View File

@ -2,7 +2,7 @@ import os
import traceback import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import ANTHROPIC_API_KEY, GOOGLE_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode from custom_types import InputMode
from llm import ( from llm import (
Llm, Llm,
@ -10,7 +10,9 @@ from llm import (
stream_claude_response, stream_claude_response,
stream_claude_response_native, stream_claude_response_native,
stream_openai_response, stream_openai_response,
stream_gemini_response,
) )
from utils import make_json_serializable
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List, Union, cast, get_args from typing import Dict, List, Union, cast, get_args
@ -43,6 +45,9 @@ def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: st
# Generate a unique filename using the current timestamp within the logs directory # Generate a unique filename using the current timestamp within the logs directory
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
# Convert the byte datatype in the prompt to string
prompt_messages = make_json_serializable(prompt_messages)
# Write the messages dict into a new file for each run # Write the messages dict into a new file for each run
with open(filename, "w") as f: with open(filename, "w") as f:
f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
@ -258,6 +263,19 @@ async def stream_code(websocket: WebSocket):
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
) )
exact_llm_version = code_generation_model exact_llm_version = code_generation_model
elif code_generation_model == Llm.GEMINI_1_5_FLASH:
if not GOOGLE_API_KEY:
await throw_error(
"No Google API key found. Please add the environment variable GOOGLE_API_KEY to backend/.env"
)
raise Exception("No Google key")
completion = await stream_gemini_response(
prompt_messages, # type: ignore
api_key=GOOGLE_API_KEY,
callback=lambda x: process_chunk(x),
)
exact_llm_version = code_generation_model
else: else:
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, # type: ignore prompt_messages, # type: ignore

View File

@ -2,6 +2,7 @@ import copy
import json import json
from typing import List from typing import List
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
import base64
def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]): def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]):
@ -28,3 +29,12 @@ def truncate_data_strings(data: List[ChatCompletionMessageParam]): # type: igno
cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore
return cloned_data # type: ignore return cloned_data # type: ignore
def make_json_serializable(content):
if isinstance(content, bytes):
return base64.b64encode(content).decode('utf-8')
elif isinstance(content, dict):
return {key: make_json_serializable(value) for key, value in content.items()}
elif isinstance(content, list):
return [make_json_serializable(item) for item in content]
return content

View File

@ -5,6 +5,7 @@ export enum CodeGenerationModel {
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09", GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
GPT_4_VISION = "gpt_4_vision", GPT_4_VISION = "gpt_4_vision",
CLAUDE_3_SONNET = "claude_3_sonnet", CLAUDE_3_SONNET = "claude_3_sonnet",
GEMINI_1_5_FLASH = "gemini-1.5-flash",
} }
// Will generate a static error if a model in the enum above is not in the descriptions // Will generate a static error if a model in the enum above is not in the descriptions
@ -15,4 +16,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false }, "gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false }, gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false }, claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
"gemini-1.5-flash": { name: "Gemini 1.5 Flash", inBeta: false },
}; };