diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index d59b8d1..580df4d 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -12,7 +12,8 @@ from prompts import assemble_imported_code_prompt, assemble_prompt from access_token import validate_access_token from datetime import datetime import json -from routes.logging_utils import send_to_saas_backend +from routes.logging_utils import PaymentMethod, send_to_saas_backend +from routes.saas_utils import does_user_have_subscription_credits from utils import pprint_prompt # type: ignore @@ -62,6 +63,9 @@ async def stream_code(websocket: WebSocket): generated_code_config = params["generatedCodeConfig"] print(f"Generating {generated_code_config} code") + # Track how this generation is being paid for + payment_method: PaymentMethod = PaymentMethod.UNKNOWN + # Get the OpenAI API key from the request. Fall back to environment variable if not provided. # If neither is provided, we throw an error. openai_api_key = None @@ -69,6 +73,7 @@ async def stream_code(websocket: WebSocket): print("Access code - using platform API key") res = await validate_access_token(params["accessCode"]) if res["success"]: + payment_method = PaymentMethod.ACCESS_CODE openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY") else: await websocket.send_json( @@ -79,10 +84,32 @@ async def stream_code(websocket: WebSocket): ) return else: + auth_token = params.get("authToken") + if auth_token: + # TODO: Rename does_user_have_subscription_credits + res = await does_user_have_subscription_credits(auth_token) + if res.status == "not_subscriber": + # Keep going for non-subscriber users + pass + elif res.status == "subscriber_has_credits": + payment_method = PaymentMethod.SUBSCRIPTION + openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY") + elif res.status == "subscriber_has_no_credits": + return await throw_error( + "Your subscription has run out of monthly credits. Contact support and we can add more credits to your account for free." + ) + else: + return await throw_error("Unknown error occurred. Contact support.") + + else: + # Log but keep going for users + print("Missing auth token") + if params["openAiApiKey"]: openai_api_key = params["openAiApiKey"] + payment_method = PaymentMethod.OPENAI_API_KEY print("Using OpenAI API key from client-side settings dialog") - else: + elif not openai_api_key: openai_api_key = os.environ.get("OPENAI_API_KEY") if openai_api_key: print("Using OpenAI API key from environment variable") @@ -238,7 +265,12 @@ async def stream_code(websocket: WebSocket): if IS_PROD: # Catch any errors from sending to SaaS backend and continue try: - await send_to_saas_backend(prompt_messages, completion, params["authToken"]) + await send_to_saas_backend( + prompt_messages, + completion, + payment_method=payment_method, + auth_token=params["authToken"], + ) except Exception as e: print("Error sending to SaaS backend", e) diff --git a/backend/routes/logging_utils.py b/backend/routes/logging_utils.py index b13ce39..8d16558 100644 --- a/backend/routes/logging_utils.py +++ b/backend/routes/logging_utils.py @@ -1,3 +1,4 @@ +from enum import Enum import httpx from openai.types.chat import ChatCompletionMessageParam from typing import List @@ -6,20 +7,30 @@ import json from config import IS_PROD +class PaymentMethod(Enum): + LEGACY = "legacy" + UNKNOWN = "unknown" + OPENAI_API_KEY = "openai_api_key" + SUBSCRIPTION = "subscription" + ACCESS_CODE = "access_code" + + async def send_to_saas_backend( prompt_messages: List[ChatCompletionMessageParam], completion: str, + payment_method: PaymentMethod, auth_token: str | None = None, ): if IS_PROD: async with httpx.AsyncClient() as client: - url = "https://screenshot-to-code-saas.onrender.com/generations/store" - # url = "http://localhost:8001/generations/store" + # url = "https://screenshot-to-code-saas.onrender.com/generations/store" + url = "http://localhost:8001/generations/store" data = json.dumps( { "prompt": json.dumps(prompt_messages), "completion": completion, + "payment_method": payment_method.value, } ) diff --git a/backend/routes/saas_utils.py b/backend/routes/saas_utils.py new file mode 100644 index 0000000..94ce340 --- /dev/null +++ b/backend/routes/saas_utils.py @@ -0,0 +1,23 @@ +import httpx +from pydantic import BaseModel + + +class SubscriptionCreditsResponse(BaseModel): + status: str + + +async def does_user_have_subscription_credits( + auth_token: str, +): + async with httpx.AsyncClient() as client: + # url = "https://screenshot-to-code-saas.onrender.com/credits/has_credits" + url = "http://localhost:8001/credits/has_credits" + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {auth_token}", + } + + response = await client.post(url, headers=headers) + parsed_response = SubscriptionCreditsResponse.parse_obj(response.json()) + return parsed_response