pass user_id and api secret to saas API for storage

This commit is contained in:
Abi Raja 2024-09-18 15:20:50 +02:00
parent 76defbb1c2
commit 1bf3340502
3 changed files with 11 additions and 6 deletions

View File

@ -30,3 +30,4 @@ PLATFORM_ANTHROPIC_API_KEY = os.environ.get("PLATFORM_ANTHROPIC_API_KEY", "")
PLATFORM_SCREENSHOTONE_API_KEY = os.environ.get("PLATFORM_SCREENSHOTONE_API_KEY", "") PLATFORM_SCREENSHOTONE_API_KEY = os.environ.get("PLATFORM_SCREENSHOTONE_API_KEY", "")
BACKEND_SAAS_URL = os.environ.get("BACKEND_SAAS_URL", "") BACKEND_SAAS_URL = os.environ.get("BACKEND_SAAS_URL", "")
BACKEND_SAAS_API_SECRET = os.environ.get("BACKEND_SAAS_API_SECRET", "")

View File

@ -22,11 +22,11 @@ from llm import (
stream_openai_response, stream_openai_response,
) )
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List, cast, get_args from typing import Dict, cast, get_args
from image_generation.core import generate_images from image_generation.core import generate_images
from routes.logging_utils import PaymentMethod, send_to_saas_backend from routes.logging_utils import PaymentMethod, send_to_saas_backend
from routes.saas_utils import does_user_have_subscription_credits from routes.saas_utils import does_user_have_subscription_credits
from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args from typing import Any, Callable, Coroutine, Dict, Literal, cast, get_args
from image_generation.core import generate_images from image_generation.core import generate_images
from prompts import create_prompt from prompts import create_prompt
from prompts.claude_prompts import VIDEO_PROMPT from prompts.claude_prompts import VIDEO_PROMPT
@ -91,6 +91,7 @@ async def perform_image_generation(
@dataclass @dataclass
class ExtractedParams: class ExtractedParams:
user_id: str
stack: Stack stack: Stack
input_mode: InputMode input_mode: InputMode
code_generation_model: Llm code_generation_model: Llm
@ -198,6 +199,7 @@ async def extract_params(
) )
return ExtractedParams( return ExtractedParams(
user_id="fake_user_id",
stack=validated_stack, stack=validated_stack,
input_mode=validated_input_mode, input_mode=validated_input_mode,
code_generation_model=code_generation_model, code_generation_model=code_generation_model,
@ -259,6 +261,7 @@ async def stream_code(websocket: WebSocket):
print("Received params") print("Received params")
extracted_params = await extract_params(params, throw_error) extracted_params = await extract_params(params, throw_error)
user_id = extracted_params.user_id
stack = extracted_params.stack stack = extracted_params.stack
input_mode = extracted_params.input_mode input_mode = extracted_params.input_mode
code_generation_model = extracted_params.code_generation_model code_generation_model = extracted_params.code_generation_model
@ -451,6 +454,7 @@ async def stream_code(websocket: WebSocket):
# Catch any errors from sending to SaaS backend and continue # Catch any errors from sending to SaaS backend and continue
try: try:
await send_to_saas_backend( await send_to_saas_backend(
user_id,
prompt_messages, prompt_messages,
completions, completions,
payment_method=payment_method, payment_method=payment_method,
@ -459,7 +463,6 @@ async def stream_code(websocket: WebSocket):
is_imported_from_code=bool(params.get("isImportedFromCode", False)), is_imported_from_code=bool(params.get("isImportedFromCode", False)),
includes_result_image=bool(params.get("resultImage", False)), includes_result_image=bool(params.get("resultImage", False)),
input_mode=input_mode, input_mode=input_mode,
auth_token=params["authToken"],
) )
except Exception as e: except Exception as e:
print("Error sending to SaaS backend", e) print("Error sending to SaaS backend", e)

View File

@ -4,7 +4,7 @@ from openai.types.chat import ChatCompletionMessageParam
from typing import List from typing import List
import json import json
from config import BACKEND_SAAS_URL, IS_PROD from config import BACKEND_SAAS_API_SECRET, BACKEND_SAAS_URL, IS_PROD
from custom_types import InputMode from custom_types import InputMode
from llm import Llm from llm import Llm
from prompts.types import Stack from prompts.types import Stack
@ -19,6 +19,7 @@ class PaymentMethod(Enum):
async def send_to_saas_backend( async def send_to_saas_backend(
user_id: str,
prompt_messages: List[ChatCompletionMessageParam], prompt_messages: List[ChatCompletionMessageParam],
completions: list[str], completions: list[str],
llm_versions: list[Llm], llm_versions: list[Llm],
@ -27,7 +28,6 @@ async def send_to_saas_backend(
is_imported_from_code: bool, is_imported_from_code: bool,
includes_result_image: bool, includes_result_image: bool,
input_mode: InputMode, input_mode: InputMode,
auth_token: str | None = None,
): ):
if IS_PROD: if IS_PROD:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@ -35,6 +35,7 @@ async def send_to_saas_backend(
data = json.dumps( data = json.dumps(
{ {
"user_id": user_id,
"prompt": json.dumps(prompt_messages), "prompt": json.dumps(prompt_messages),
"completions": completions, "completions": completions,
"payment_method": payment_method.value, "payment_method": payment_method.value,
@ -48,7 +49,7 @@ async def send_to_saas_backend(
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {auth_token}", # Add the auth token to the headers "Authorization": f"Bearer {BACKEND_SAAS_API_SECRET}", # Add the auth token to the headers
} }
response = await client.post(url, content=data, headers=headers, timeout=10) response = await client.post(url, content=data, headers=headers, timeout=10)