support subscription credits and store payment method for each generation

This commit is contained in:
Abi Raja 2023-12-19 11:40:00 -05:00
parent 730e58da72
commit 1a5f05d574
3 changed files with 71 additions and 5 deletions

View File

@ -12,7 +12,8 @@ from prompts import assemble_imported_code_prompt, assemble_prompt
from access_token import validate_access_token
from datetime import datetime
import json
from routes.logging_utils import send_to_saas_backend
from routes.logging_utils import PaymentMethod, send_to_saas_backend
from routes.saas_utils import does_user_have_subscription_credits
from utils import pprint_prompt # type: ignore
@ -62,6 +63,9 @@ async def stream_code(websocket: WebSocket):
generated_code_config = params["generatedCodeConfig"]
print(f"Generating {generated_code_config} code")
# Track how this generation is being paid for
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error.
openai_api_key = None
@ -69,6 +73,7 @@ async def stream_code(websocket: WebSocket):
print("Access code - using platform API key")
res = await validate_access_token(params["accessCode"])
if res["success"]:
payment_method = PaymentMethod.ACCESS_CODE
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
else:
await websocket.send_json(
@ -79,10 +84,32 @@ async def stream_code(websocket: WebSocket):
)
return
else:
auth_token = params.get("authToken")
if auth_token:
# TODO: Rename does_user_have_subscription_credits
res = await does_user_have_subscription_credits(auth_token)
if res.status == "not_subscriber":
# Keep going for non-subscriber users
pass
elif res.status == "subscriber_has_credits":
payment_method = PaymentMethod.SUBSCRIPTION
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
elif res.status == "subscriber_has_no_credits":
return await throw_error(
"Your subscription has run out of monthly credits. Contact support and we can add more credits to your account for free."
)
else:
return await throw_error("Unknown error occurred. Contact support.")
else:
# Log but keep going for users
print("Missing auth token")
if params["openAiApiKey"]:
openai_api_key = params["openAiApiKey"]
payment_method = PaymentMethod.OPENAI_API_KEY
print("Using OpenAI API key from client-side settings dialog")
else:
elif not openai_api_key:
openai_api_key = os.environ.get("OPENAI_API_KEY")
if openai_api_key:
print("Using OpenAI API key from environment variable")
@ -238,7 +265,12 @@ async def stream_code(websocket: WebSocket):
if IS_PROD:
# Catch any errors from sending to SaaS backend and continue
try:
await send_to_saas_backend(prompt_messages, completion, params["authToken"])
await send_to_saas_backend(
prompt_messages,
completion,
payment_method=payment_method,
auth_token=params["authToken"],
)
except Exception as e:
print("Error sending to SaaS backend", e)

View File

@ -1,3 +1,4 @@
from enum import Enum
import httpx
from openai.types.chat import ChatCompletionMessageParam
from typing import List
@ -6,20 +7,30 @@ import json
from config import IS_PROD
class PaymentMethod(Enum):
LEGACY = "legacy"
UNKNOWN = "unknown"
OPENAI_API_KEY = "openai_api_key"
SUBSCRIPTION = "subscription"
ACCESS_CODE = "access_code"
async def send_to_saas_backend(
prompt_messages: List[ChatCompletionMessageParam],
completion: str,
payment_method: PaymentMethod,
auth_token: str | None = None,
):
if IS_PROD:
async with httpx.AsyncClient() as client:
url = "https://screenshot-to-code-saas.onrender.com/generations/store"
# url = "http://localhost:8001/generations/store"
# url = "https://screenshot-to-code-saas.onrender.com/generations/store"
url = "http://localhost:8001/generations/store"
data = json.dumps(
{
"prompt": json.dumps(prompt_messages),
"completion": completion,
"payment_method": payment_method.value,
}
)

View File

@ -0,0 +1,23 @@
import httpx
from pydantic import BaseModel
class SubscriptionCreditsResponse(BaseModel):
status: str
async def does_user_have_subscription_credits(
auth_token: str,
):
async with httpx.AsyncClient() as client:
# url = "https://screenshot-to-code-saas.onrender.com/credits/has_credits"
url = "http://localhost:8001/credits/has_credits"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {auth_token}",
}
response = await client.post(url, headers=headers)
parsed_response = SubscriptionCreditsResponse.parse_obj(response.json())
return parsed_response