support subscription credits and store payment method for each generation
This commit is contained in:
parent
730e58da72
commit
1a5f05d574
@ -12,7 +12,8 @@ from prompts import assemble_imported_code_prompt, assemble_prompt
|
||||
from access_token import validate_access_token
|
||||
from datetime import datetime
|
||||
import json
|
||||
from routes.logging_utils import send_to_saas_backend
|
||||
from routes.logging_utils import PaymentMethod, send_to_saas_backend
|
||||
from routes.saas_utils import does_user_have_subscription_credits
|
||||
|
||||
from utils import pprint_prompt # type: ignore
|
||||
|
||||
@ -62,6 +63,9 @@ async def stream_code(websocket: WebSocket):
|
||||
generated_code_config = params["generatedCodeConfig"]
|
||||
print(f"Generating {generated_code_config} code")
|
||||
|
||||
# Track how this generation is being paid for
|
||||
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
||||
|
||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||
# If neither is provided, we throw an error.
|
||||
openai_api_key = None
|
||||
@ -69,6 +73,7 @@ async def stream_code(websocket: WebSocket):
|
||||
print("Access code - using platform API key")
|
||||
res = await validate_access_token(params["accessCode"])
|
||||
if res["success"]:
|
||||
payment_method = PaymentMethod.ACCESS_CODE
|
||||
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
|
||||
else:
|
||||
await websocket.send_json(
|
||||
@ -79,10 +84,32 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
return
|
||||
else:
|
||||
auth_token = params.get("authToken")
|
||||
if auth_token:
|
||||
# TODO: Rename does_user_have_subscription_credits
|
||||
res = await does_user_have_subscription_credits(auth_token)
|
||||
if res.status == "not_subscriber":
|
||||
# Keep going for non-subscriber users
|
||||
pass
|
||||
elif res.status == "subscriber_has_credits":
|
||||
payment_method = PaymentMethod.SUBSCRIPTION
|
||||
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
|
||||
elif res.status == "subscriber_has_no_credits":
|
||||
return await throw_error(
|
||||
"Your subscription has run out of monthly credits. Contact support and we can add more credits to your account for free."
|
||||
)
|
||||
else:
|
||||
return await throw_error("Unknown error occurred. Contact support.")
|
||||
|
||||
else:
|
||||
# Log but keep going for users
|
||||
print("Missing auth token")
|
||||
|
||||
if params["openAiApiKey"]:
|
||||
openai_api_key = params["openAiApiKey"]
|
||||
payment_method = PaymentMethod.OPENAI_API_KEY
|
||||
print("Using OpenAI API key from client-side settings dialog")
|
||||
else:
|
||||
elif not openai_api_key:
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if openai_api_key:
|
||||
print("Using OpenAI API key from environment variable")
|
||||
@ -238,7 +265,12 @@ async def stream_code(websocket: WebSocket):
|
||||
if IS_PROD:
|
||||
# Catch any errors from sending to SaaS backend and continue
|
||||
try:
|
||||
await send_to_saas_backend(prompt_messages, completion, params["authToken"])
|
||||
await send_to_saas_backend(
|
||||
prompt_messages,
|
||||
completion,
|
||||
payment_method=payment_method,
|
||||
auth_token=params["authToken"],
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error sending to SaaS backend", e)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
import httpx
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from typing import List
|
||||
@ -6,20 +7,30 @@ import json
|
||||
from config import IS_PROD
|
||||
|
||||
|
||||
class PaymentMethod(Enum):
|
||||
LEGACY = "legacy"
|
||||
UNKNOWN = "unknown"
|
||||
OPENAI_API_KEY = "openai_api_key"
|
||||
SUBSCRIPTION = "subscription"
|
||||
ACCESS_CODE = "access_code"
|
||||
|
||||
|
||||
async def send_to_saas_backend(
|
||||
prompt_messages: List[ChatCompletionMessageParam],
|
||||
completion: str,
|
||||
payment_method: PaymentMethod,
|
||||
auth_token: str | None = None,
|
||||
):
|
||||
if IS_PROD:
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = "https://screenshot-to-code-saas.onrender.com/generations/store"
|
||||
# url = "http://localhost:8001/generations/store"
|
||||
# url = "https://screenshot-to-code-saas.onrender.com/generations/store"
|
||||
url = "http://localhost:8001/generations/store"
|
||||
|
||||
data = json.dumps(
|
||||
{
|
||||
"prompt": json.dumps(prompt_messages),
|
||||
"completion": completion,
|
||||
"payment_method": payment_method.value,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
23
backend/routes/saas_utils.py
Normal file
23
backend/routes/saas_utils.py
Normal file
@ -0,0 +1,23 @@
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SubscriptionCreditsResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
async def does_user_have_subscription_credits(
|
||||
auth_token: str,
|
||||
):
|
||||
async with httpx.AsyncClient() as client:
|
||||
# url = "https://screenshot-to-code-saas.onrender.com/credits/has_credits"
|
||||
url = "http://localhost:8001/credits/has_credits"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {auth_token}",
|
||||
}
|
||||
|
||||
response = await client.post(url, headers=headers)
|
||||
parsed_response = SubscriptionCreditsResponse.parse_obj(response.json())
|
||||
return parsed_response
|
||||
Loading…
Reference in New Issue
Block a user