support subscription credits and store payment method for each generation
This commit is contained in:
parent
730e58da72
commit
1a5f05d574
@ -12,7 +12,8 @@ from prompts import assemble_imported_code_prompt, assemble_prompt
|
|||||||
from access_token import validate_access_token
|
from access_token import validate_access_token
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
from routes.logging_utils import send_to_saas_backend
|
from routes.logging_utils import PaymentMethod, send_to_saas_backend
|
||||||
|
from routes.saas_utils import does_user_have_subscription_credits
|
||||||
|
|
||||||
from utils import pprint_prompt # type: ignore
|
from utils import pprint_prompt # type: ignore
|
||||||
|
|
||||||
@ -62,6 +63,9 @@ async def stream_code(websocket: WebSocket):
|
|||||||
generated_code_config = params["generatedCodeConfig"]
|
generated_code_config = params["generatedCodeConfig"]
|
||||||
print(f"Generating {generated_code_config} code")
|
print(f"Generating {generated_code_config} code")
|
||||||
|
|
||||||
|
# Track how this generation is being paid for
|
||||||
|
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
||||||
|
|
||||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||||
# If neither is provided, we throw an error.
|
# If neither is provided, we throw an error.
|
||||||
openai_api_key = None
|
openai_api_key = None
|
||||||
@ -69,6 +73,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
print("Access code - using platform API key")
|
print("Access code - using platform API key")
|
||||||
res = await validate_access_token(params["accessCode"])
|
res = await validate_access_token(params["accessCode"])
|
||||||
if res["success"]:
|
if res["success"]:
|
||||||
|
payment_method = PaymentMethod.ACCESS_CODE
|
||||||
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
|
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
|
||||||
else:
|
else:
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
@ -79,10 +84,32 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
auth_token = params.get("authToken")
|
||||||
|
if auth_token:
|
||||||
|
# TODO: Rename does_user_have_subscription_credits
|
||||||
|
res = await does_user_have_subscription_credits(auth_token)
|
||||||
|
if res.status == "not_subscriber":
|
||||||
|
# Keep going for non-subscriber users
|
||||||
|
pass
|
||||||
|
elif res.status == "subscriber_has_credits":
|
||||||
|
payment_method = PaymentMethod.SUBSCRIPTION
|
||||||
|
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
|
||||||
|
elif res.status == "subscriber_has_no_credits":
|
||||||
|
return await throw_error(
|
||||||
|
"Your subscription has run out of monthly credits. Contact support and we can add more credits to your account for free."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await throw_error("Unknown error occurred. Contact support.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Log but keep going for users
|
||||||
|
print("Missing auth token")
|
||||||
|
|
||||||
if params["openAiApiKey"]:
|
if params["openAiApiKey"]:
|
||||||
openai_api_key = params["openAiApiKey"]
|
openai_api_key = params["openAiApiKey"]
|
||||||
|
payment_method = PaymentMethod.OPENAI_API_KEY
|
||||||
print("Using OpenAI API key from client-side settings dialog")
|
print("Using OpenAI API key from client-side settings dialog")
|
||||||
else:
|
elif not openai_api_key:
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
if openai_api_key:
|
if openai_api_key:
|
||||||
print("Using OpenAI API key from environment variable")
|
print("Using OpenAI API key from environment variable")
|
||||||
@ -238,7 +265,12 @@ async def stream_code(websocket: WebSocket):
|
|||||||
if IS_PROD:
|
if IS_PROD:
|
||||||
# Catch any errors from sending to SaaS backend and continue
|
# Catch any errors from sending to SaaS backend and continue
|
||||||
try:
|
try:
|
||||||
await send_to_saas_backend(prompt_messages, completion, params["authToken"])
|
await send_to_saas_backend(
|
||||||
|
prompt_messages,
|
||||||
|
completion,
|
||||||
|
payment_method=payment_method,
|
||||||
|
auth_token=params["authToken"],
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error sending to SaaS backend", e)
|
print("Error sending to SaaS backend", e)
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from enum import Enum
|
||||||
import httpx
|
import httpx
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -6,20 +7,30 @@ import json
|
|||||||
from config import IS_PROD
|
from config import IS_PROD
|
||||||
|
|
||||||
|
|
||||||
|
class PaymentMethod(Enum):
|
||||||
|
LEGACY = "legacy"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
OPENAI_API_KEY = "openai_api_key"
|
||||||
|
SUBSCRIPTION = "subscription"
|
||||||
|
ACCESS_CODE = "access_code"
|
||||||
|
|
||||||
|
|
||||||
async def send_to_saas_backend(
|
async def send_to_saas_backend(
|
||||||
prompt_messages: List[ChatCompletionMessageParam],
|
prompt_messages: List[ChatCompletionMessageParam],
|
||||||
completion: str,
|
completion: str,
|
||||||
|
payment_method: PaymentMethod,
|
||||||
auth_token: str | None = None,
|
auth_token: str | None = None,
|
||||||
):
|
):
|
||||||
if IS_PROD:
|
if IS_PROD:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
url = "https://screenshot-to-code-saas.onrender.com/generations/store"
|
# url = "https://screenshot-to-code-saas.onrender.com/generations/store"
|
||||||
# url = "http://localhost:8001/generations/store"
|
url = "http://localhost:8001/generations/store"
|
||||||
|
|
||||||
data = json.dumps(
|
data = json.dumps(
|
||||||
{
|
{
|
||||||
"prompt": json.dumps(prompt_messages),
|
"prompt": json.dumps(prompt_messages),
|
||||||
"completion": completion,
|
"completion": completion,
|
||||||
|
"payment_method": payment_method.value,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
23
backend/routes/saas_utils.py
Normal file
23
backend/routes/saas_utils.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionCreditsResponse(BaseModel):
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
async def does_user_have_subscription_credits(
|
||||||
|
auth_token: str,
|
||||||
|
):
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# url = "https://screenshot-to-code-saas.onrender.com/credits/has_credits"
|
||||||
|
url = "http://localhost:8001/credits/has_credits"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {auth_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(url, headers=headers)
|
||||||
|
parsed_response = SubscriptionCreditsResponse.parse_obj(response.json())
|
||||||
|
return parsed_response
|
||||||
Loading…
Reference in New Issue
Block a user