support all forms of payment (subscription, api key, etc.)
This commit is contained in:
parent
c107f4eda5
commit
86f6781682
@ -25,6 +25,8 @@ IS_PROD = os.environ.get("IS_PROD", False)
|
|||||||
|
|
||||||
# Hosted version only
|
# Hosted version only
|
||||||
|
|
||||||
|
PLATFORM_OPENAI_API_KEY = os.environ.get("PLATFORM_OPENAI_API_KEY", "")
|
||||||
|
PLATFORM_ANTHROPIC_API_KEY = os.environ.get("PLATFORM_ANTHROPIC_API_KEY", "")
|
||||||
PLATFORM_SCREENSHOTONE_API_KEY = os.environ.get("PLATFORM_SCREENSHOTONE_API_KEY", "")
|
PLATFORM_SCREENSHOTONE_API_KEY = os.environ.get("PLATFORM_SCREENSHOTONE_API_KEY", "")
|
||||||
|
|
||||||
BACKEND_SAAS_URL = os.environ.get("BACKEND_SAAS_URL", "")
|
BACKEND_SAAS_URL = os.environ.get("BACKEND_SAAS_URL", "")
|
||||||
|
|||||||
@ -5,11 +5,11 @@ from fastapi import APIRouter, WebSocket
|
|||||||
import openai
|
import openai
|
||||||
from codegen.utils import extract_html_content
|
from codegen.utils import extract_html_content
|
||||||
from config import (
|
from config import (
|
||||||
ANTHROPIC_API_KEY,
|
|
||||||
IS_PROD,
|
IS_PROD,
|
||||||
NUM_VARIANTS,
|
NUM_VARIANTS,
|
||||||
OPENAI_API_KEY,
|
|
||||||
OPENAI_BASE_URL,
|
OPENAI_BASE_URL,
|
||||||
|
PLATFORM_ANTHROPIC_API_KEY,
|
||||||
|
PLATFORM_OPENAI_API_KEY,
|
||||||
REPLICATE_API_KEY,
|
REPLICATE_API_KEY,
|
||||||
SHOULD_MOCK_AI_RESPONSE,
|
SHOULD_MOCK_AI_RESPONSE,
|
||||||
)
|
)
|
||||||
@ -128,12 +128,14 @@ async def extract_params(
|
|||||||
await throw_error(f"Invalid model: {code_generation_model_str}")
|
await throw_error(f"Invalid model: {code_generation_model_str}")
|
||||||
raise ValueError(f"Invalid model: {code_generation_model_str}")
|
raise ValueError(f"Invalid model: {code_generation_model_str}")
|
||||||
|
|
||||||
|
# Read the auth token from the request (on the hosted version)
|
||||||
auth_token = params.get("authToken")
|
auth_token = params.get("authToken")
|
||||||
if not auth_token:
|
if not auth_token:
|
||||||
await throw_error("You need to be logged in to use screenshot to code")
|
await throw_error("You need to be logged in to use screenshot to code")
|
||||||
raise Exception("No auth token")
|
raise Exception("No auth token")
|
||||||
|
|
||||||
openai_api_key = None
|
openai_api_key = None
|
||||||
|
anthropic_api_key = None
|
||||||
|
|
||||||
# Track how this generation is being paid for
|
# Track how this generation is being paid for
|
||||||
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
payment_method: PaymentMethod = PaymentMethod.UNKNOWN
|
||||||
@ -151,7 +153,8 @@ async def extract_params(
|
|||||||
if res.status == "subscriber_has_credits"
|
if res.status == "subscriber_has_credits"
|
||||||
else PaymentMethod.TRIAL
|
else PaymentMethod.TRIAL
|
||||||
)
|
)
|
||||||
openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
|
openai_api_key = PLATFORM_OPENAI_API_KEY
|
||||||
|
anthropic_api_key = PLATFORM_ANTHROPIC_API_KEY
|
||||||
print("Subscription - using platform API key")
|
print("Subscription - using platform API key")
|
||||||
elif res.status == "subscriber_has_no_credits":
|
elif res.status == "subscriber_has_no_credits":
|
||||||
await throw_error(
|
await throw_error(
|
||||||
@ -161,32 +164,20 @@ async def extract_params(
|
|||||||
await throw_error("Unknown error occurred. Contact support.")
|
await throw_error("Unknown error occurred. Contact support.")
|
||||||
raise Exception("Unknown error occurred when checking subscription credits")
|
raise Exception("Unknown error occurred when checking subscription credits")
|
||||||
|
|
||||||
# If we still don't have an API key, use the user's API key from environment variable
|
# Use the user's OpenAI API key from the settings dialog if they are not a subscriber
|
||||||
if not openai_api_key:
|
if not openai_api_key:
|
||||||
openai_api_key = get_from_settings_dialog_or_env(
|
openai_api_key = get_from_settings_dialog_or_env(params, "openAiApiKey", None)
|
||||||
params, "openAiApiKey", OPENAI_API_KEY
|
|
||||||
)
|
|
||||||
payment_method = PaymentMethod.OPENAI_API_KEY
|
|
||||||
if openai_api_key:
|
if openai_api_key:
|
||||||
# TODO
|
payment_method = PaymentMethod.OPENAI_API_KEY
|
||||||
print("Using OpenAI API key from environment variable")
|
print("Using OpenAI API key from user's settings dialog")
|
||||||
|
|
||||||
# if not openai_api_key and (
|
print("Payment method: ", payment_method)
|
||||||
# code_generation_model == Llm.GPT_4_VISION
|
|
||||||
# or code_generation_model == Llm.GPT_4_TURBO_2024_04_09
|
|
||||||
# or code_generation_model == Llm.GPT_4O_2024_05_13
|
|
||||||
# ):
|
|
||||||
# print("OpenAI API key not found")
|
|
||||||
# await throw_error(
|
|
||||||
# "Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
|
|
||||||
# )
|
|
||||||
# raise Exception("No OpenAI API key found")
|
|
||||||
|
|
||||||
# TODO: Do not allow usage of key
|
if payment_method is PaymentMethod.UNKNOWN:
|
||||||
# If neither is provided, we throw an error later only if Claude is used.
|
await throw_error(
|
||||||
anthropic_api_key = get_from_settings_dialog_or_env(
|
"Please subscribe to a paid plan to generate code. If you are a subscriber and seeing this error, please contact support."
|
||||||
params, "anthropicApiKey", ANTHROPIC_API_KEY
|
|
||||||
)
|
)
|
||||||
|
raise Exception("No payment method found")
|
||||||
|
|
||||||
# Base URL for OpenAI API
|
# Base URL for OpenAI API
|
||||||
openai_base_url: str | None = None
|
openai_base_url: str | None = None
|
||||||
@ -327,7 +318,6 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# Depending on the presence and absence of various keys,
|
# Depending on the presence and absence of various keys,
|
||||||
# we decide which models to run
|
# we decide which models to run
|
||||||
variant_models = []
|
variant_models = []
|
||||||
@ -343,7 +333,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
raise Exception("No OpenAI or Anthropic key")
|
raise Exception("No OpenAI or Anthropic key")
|
||||||
|
|
||||||
tasks: List[Coroutine[Any, Any, str]] = []
|
tasks: list[Coroutine[Any, Any, str]] = []
|
||||||
for index, model in enumerate(variant_models):
|
for index, model in enumerate(variant_models):
|
||||||
if model == "openai":
|
if model == "openai":
|
||||||
if openai_api_key is None:
|
if openai_api_key is None:
|
||||||
@ -430,7 +420,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
completions[0],
|
completions[0],
|
||||||
payment_method=payment_method,
|
payment_method=payment_method,
|
||||||
# TODO*
|
# TODO*
|
||||||
llm_version=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
llm_version=Llm.GPT_4O_2024_05_13,
|
||||||
stack=stack,
|
stack=stack,
|
||||||
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
|
is_imported_from_code=bool(params.get("isImportedFromCode", False)),
|
||||||
includes_result_image=bool(params.get("resultImage", False)),
|
includes_result_image=bool(params.get("resultImage", False)),
|
||||||
@ -441,7 +431,6 @@ async def stream_code(websocket: WebSocket):
|
|||||||
print("Error sending to SaaS backend", e)
|
print("Error sending to SaaS backend", e)
|
||||||
|
|
||||||
## Image Generation
|
## Image Generation
|
||||||
|
|
||||||
for index, _ in enumerate(completions):
|
for index, _ in enumerate(completions):
|
||||||
await send_message("status", "Generating images...", index)
|
await send_message("status", "Generating images...", index)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user