Implement SaaS storage of multiple generations

This commit is contained in:
Abi Raja 2024-09-12 12:30:12 +02:00
parent 960e905a73
commit f9aa14b566
4 changed files with 38 additions and 30 deletions

View File

@ -7,15 +7,15 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: check-yaml - id: check-yaml
- id: check-added-large-files - id: check-added-large-files
- repo: local # - repo: local
hooks: # hooks:
- id: poetry-pytest # - id: poetry-pytest
name: Run pytest with Poetry # name: Run pytest with Poetry
entry: poetry run --directory backend pytest # entry: poetry run --directory backend pytest
language: system # language: system
pass_filenames: false # pass_filenames: false
always_run: true # always_run: true
files: ^backend/ # files: ^backend/
# - id: poetry-pyright # - id: poetry-pyright
# name: Run pyright with Poetry # name: Run pyright with Poetry
# entry: poetry run --directory backend pyright # entry: poetry run --directory backend pyright

12
backend/poetry.lock generated
View File

@ -647,13 +647,13 @@ socks = ["socksio (==1.*)"]
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.24.6" version = "0.24.7"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "huggingface_hub-0.24.6-py3-none-any.whl", hash = "sha256:a990f3232aa985fe749bc9474060cbad75e8b2f115f6665a9fda5b9c97818970"}, {file = "huggingface_hub-0.24.7-py3-none-any.whl", hash = "sha256:a212c555324c8a7b1ffdd07266bb7e7d69ca71aa238d27b7842d65e9a26ac3e5"},
{file = "huggingface_hub-0.24.6.tar.gz", hash = "sha256:cc2579e761d070713eaa9c323e3debe39d5b464ae3a7261c39a9195b27bb8000"}, {file = "huggingface_hub-0.24.7.tar.gz", hash = "sha256:0ad8fb756e2831da0ac0491175b960f341fe06ebcf80ed6f8728313f95fc0207"},
] ]
[package.dependencies] [package.dependencies]
@ -1293,13 +1293,13 @@ email = ["email-validator (>=1.0.3)"]
[[package]] [[package]]
name = "pyright" name = "pyright"
version = "1.1.379" version = "1.1.380"
description = "Command line wrapper for pyright" description = "Command line wrapper for pyright"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "pyright-1.1.379-py3-none-any.whl", hash = "sha256:01954811ac71db8646f50de1577576dc275ffb891a9e7324350e676cf6df323f"}, {file = "pyright-1.1.380-py3-none-any.whl", hash = "sha256:a6404392053d8848bacc7aebcbd9d318bb46baf1a1a000359305481920f43879"},
{file = "pyright-1.1.379.tar.gz", hash = "sha256:6f426cb6443786fa966b930c23ad1941c8cb9fe672e4589daea8d80bb34193ea"}, {file = "pyright-1.1.380.tar.gz", hash = "sha256:e6ceb1a5f7e9f03106e0aa1d6fbb4d97735a5e7ffb59f3de6b2db590baf935b2"},
] ]
[package.dependencies] [package.dependencies]

View File

@ -295,6 +295,7 @@ async def stream_code(websocket: WebSocket):
if SHOULD_MOCK_AI_RESPONSE: if SHOULD_MOCK_AI_RESPONSE:
completions = [await mock_completion(process_chunk, input_mode=input_mode)] completions = [await mock_completion(process_chunk, input_mode=input_mode)]
variant_models = [Llm.GPT_4O_2024_05_13]
else: else:
try: try:
if input_mode == "video": if input_mode == "video":
@ -317,16 +318,26 @@ async def stream_code(websocket: WebSocket):
include_thinking=True, include_thinking=True,
) )
] ]
variant_models = [Llm.CLAUDE_3_OPUS]
else: else:
# Depending on the presence and absence of various keys, # Depending on the presence and absence of various keys,
# we decide which models to run # we decide which models to run
variant_models = [] variant_models = []
if openai_api_key and anthropic_api_key: if openai_api_key and anthropic_api_key:
variant_models = ["openai", "anthropic"] variant_models = [
Llm.GPT_4O_2024_05_13,
Llm.CLAUDE_3_5_SONNET_2024_06_20,
]
elif openai_api_key: elif openai_api_key:
variant_models = ["openai", "openai"] variant_models = [
Llm.GPT_4O_2024_05_13,
Llm.GPT_4O_2024_05_13,
]
elif anthropic_api_key: elif anthropic_api_key:
variant_models = ["anthropic", "anthropic"] variant_models = [
Llm.CLAUDE_3_5_SONNET_2024_06_20,
Llm.CLAUDE_3_5_SONNET_2024_06_20,
]
else: else:
await throw_error( await throw_error(
"No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server." "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog. If you add it to .env, make sure to restart the backend server."
@ -335,7 +346,7 @@ async def stream_code(websocket: WebSocket):
tasks: list[Coroutine[Any, Any, str]] = [] tasks: list[Coroutine[Any, Any, str]] = []
for index, model in enumerate(variant_models): for index, model in enumerate(variant_models):
if model == "openai": if model == Llm.GPT_4O_2024_05_13:
if openai_api_key is None: if openai_api_key is None:
await throw_error("OpenAI API key is missing.") await throw_error("OpenAI API key is missing.")
raise Exception("OpenAI API key is missing.") raise Exception("OpenAI API key is missing.")
@ -349,7 +360,7 @@ async def stream_code(websocket: WebSocket):
model=Llm.GPT_4O_2024_05_13, model=Llm.GPT_4O_2024_05_13,
) )
) )
elif model == "anthropic": elif model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
if anthropic_api_key is None: if anthropic_api_key is None:
await throw_error("Anthropic API key is missing.") await throw_error("Anthropic API key is missing.")
raise Exception("Anthropic API key is missing.") raise Exception("Anthropic API key is missing.")
@ -412,15 +423,11 @@ async def stream_code(websocket: WebSocket):
if IS_PROD: if IS_PROD:
# Catch any errors from sending to SaaS backend and continue # Catch any errors from sending to SaaS backend and continue
try: try:
# TODO*
# assert exact_llm_version is not None, "exact_llm_version is not set"
await send_to_saas_backend( await send_to_saas_backend(
prompt_messages, prompt_messages,
# TODO*: Store both completions completions,
completions[0],
payment_method=payment_method, payment_method=payment_method,
# TODO* llm_versions=variant_models,
llm_version=Llm.GPT_4O_2024_05_13,
stack=stack, stack=stack,
is_imported_from_code=bool(params.get("isImportedFromCode", False)), is_imported_from_code=bool(params.get("isImportedFromCode", False)),
includes_result_image=bool(params.get("resultImage", False)), includes_result_image=bool(params.get("resultImage", False)),

View File

@ -20,9 +20,9 @@ class PaymentMethod(Enum):
async def send_to_saas_backend( async def send_to_saas_backend(
prompt_messages: List[ChatCompletionMessageParam], prompt_messages: List[ChatCompletionMessageParam],
completion: str, completions: list[str],
llm_versions: list[Llm],
payment_method: PaymentMethod, payment_method: PaymentMethod,
llm_version: Llm,
stack: Stack, stack: Stack,
is_imported_from_code: bool, is_imported_from_code: bool,
includes_result_image: bool, includes_result_image: bool,
@ -36,9 +36,9 @@ async def send_to_saas_backend(
data = json.dumps( data = json.dumps(
{ {
"prompt": json.dumps(prompt_messages), "prompt": json.dumps(prompt_messages),
"completion": completion, "completions": completions,
"payment_method": payment_method.value, "payment_method": payment_method.value,
"llm_version": llm_version.value, "llm_versions": [llm_version.value for llm_version in llm_versions],
"stack": stack, "stack": stack,
"is_imported_from_code": is_imported_from_code, "is_imported_from_code": is_imported_from_code,
"includes_result_image": includes_result_image, "includes_result_image": includes_result_image,
@ -52,5 +52,6 @@ async def send_to_saas_backend(
} }
response = await client.post(url, content=data, headers=headers) response = await client.post(url, content=data, headers=headers)
response.raise_for_status()
response_data = response.json() response_data = response.json()
return response_data return response_data