support all forms of payment (subscription, api key, etc.)

This commit is contained in:
Abi Raja 2024-09-10 14:28:16 +02:00
parent c107f4eda5
commit 86f6781682
2 changed files with 20 additions and 29 deletions

View File

@ -25,6 +25,8 @@ IS_PROD = os.environ.get("IS_PROD", False)
# Hosted version only
PLATFORM_OPENAI_API_KEY = os.environ.get("PLATFORM_OPENAI_API_KEY", "")
PLATFORM_ANTHROPIC_API_KEY = os.environ.get("PLATFORM_ANTHROPIC_API_KEY", "")
PLATFORM_SCREENSHOTONE_API_KEY = os.environ.get("PLATFORM_SCREENSHOTONE_API_KEY", "")
BACKEND_SAAS_URL = os.environ.get("BACKEND_SAAS_URL", "")

View File

@ -5,11 +5,11 @@ from fastapi import APIRouter, WebSocket
import openai
from codegen.utils import extract_html_content
from config import (
ANTHROPIC_API_KEY,
IS_PROD,
NUM_VARIANTS,
OPENAI_API_KEY,
OPENAI_BASE_URL,
PLATFORM_ANTHROPIC_API_KEY,
PLATFORM_OPENAI_API_KEY,
REPLICATE_API_KEY,
SHOULD_MOCK_AI_RESPONSE,
)
@ -128,12 +128,14 @@ async def extract_params(
await throw_error(f"Invalid model: {code_generation_model_str}")
raise ValueError(f"Invalid model: {code_generation_model_str}")
# Read the auth token from the request (on the hosted version)
auth_token = params.get("authToken")
if not auth_token:
await throw_error("You need to be logged in to use screenshot to code")
raise Exception("No auth token")
openai_api_key = None
anthropic_api_key = None
# Track how this generation is being paid for
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
@ -151,7 +153,8 @@ async def extract_params(
if res.status == "subscriber_has_credits"
else PaymentMethod.TRIAL
)
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
openai_api_key = PLATFORM_OPENAI_API_KEY
anthropic_api_key = PLATFORM_ANTHROPIC_API_KEY
print("Subscription - using platform API key")
elif res.status == "subscriber_has_no_credits":
await throw_error(
@ -161,32 +164,20 @@ async def extract_params(
await throw_error("Unknown error occurred. Contact support.")
raise Exception("Unknown error occurred when checking subscription credits")
# If we still don't have an API key, use the user's API key from environment variable
# Use the user's OpenAI API key from the settings dialog if they are not a subscriber
if not openai_api_key:
openai_api_key = get_from_settings_dialog_or_env(
params, "openAiApiKey", OPENAI_API_KEY
)
payment_method = PaymentMethod.OPENAI_API_KEY
openai_api_key = get_from_settings_dialog_or_env(params, "openAiApiKey", None)
if openai_api_key:
# TODO
print("Using OpenAI API key from environment variable")
payment_method = PaymentMethod.OPENAI_API_KEY
print("Using OpenAI API key from user's settings dialog")
# if not openai_api_key and (
# code_generation_model == Llm.GPT_4_VISION
# or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
# or code_generation_model == Llm.GPT_4O_2024_05_13
# ):
# print("OpenAI API key not found")
# await throw_error(
# "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
# )
# raise Exception("No OpenAI API key found")
print("Payment method: ", payment_method)
# TODO: Do not allow usage of key
# If neither is provided, we throw an error later only if Claude is used.
anthropic_api_key = get_from_settings_dialog_or_env(
params, "anthropicApiKey", ANTHROPIC_API_KEY
)
if payment_method is PaymentMethod.UNKNOWN:
await throw_error(
"Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
)
raise Exception("No payment method found")
# Base URL for OpenAI API
openai_base_url: str | None = None
@ -327,7 +318,6 @@ async def stream_code(websocket: WebSocket):
)
]
else:
# Depending on the presence and absence of various keys,
# we decide which models to run
variant_models = []
@ -343,7 +333,7 @@ async def stream_code(websocket: WebSocket):
)
raise Exception("No OpenAI or Anthropic key")
tasks: List[Coroutine[Any, Any, str]] = []
tasks: list[Coroutine[Any, Any, str]] = []
for index, model in enumerate(variant_models):
if model == "openai":
if openai_api_key is None:
@ -430,7 +420,7 @@ async def stream_code(websocket: WebSocket):
completions[0],
payment_method=payment_method,
# TODO*
llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20,
llm_version=Llm.GPT_4O_2024_05_13,
stack=stack,
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
includes_result_image=bool(params.get("resultImage", False)),
@ -441,7 +431,6 @@ async def stream_code(websocket: WebSocket):
print("Error sending to SaaS backend", e)
## Image Generation
for index, _ in enumerate(completions):
await send_message("status", "Generating images...", index)