diff --git a/backend/config.py b/backend/config.py index 19bdca4..30da119 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,6 +5,7 @@ import os OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None) # Image generation (optional) REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index 48709d5..8b8d45a 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -1,5 +1,5 @@ -import os import asyncio +from dataclasses import dataclass from fastapi import APIRouter, WebSocket import openai from codegen.utils import extract_html_content @@ -7,6 +7,7 @@ from config import ( ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, + OPENAI_BASE_URL, REPLICATE_API_KEY, SHOULD_MOCK_AI_RESPONSE, ) @@ -21,12 +22,13 @@ from llm import ( ) from fs_logging.core import write_logs from mock_llm import mock_completion -from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args +from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args from image_generation.core import generate_images from prompts import create_prompt from prompts.claude_prompts import VIDEO_PROMPT from prompts.types import Stack -from utils import pprint_prompt + +# from utils import pprint_prompt from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore @@ -83,10 +85,95 @@ async def perform_image_generation( ) +@dataclass +class ExtractedParams: + stack: Stack + input_mode: InputMode + code_generation_model: Llm + should_generate_images: bool + openai_api_key: str | None + anthropic_api_key: str | None + openai_base_url: str | None + + +async def extract_params( + params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]] +) -> ExtractedParams: + # Read the code config settings (stack) from the request. + generated_code_config = params.get("generatedCodeConfig", "") + if generated_code_config not in get_args(Stack): + await throw_error(f"Invalid generated code config: {generated_code_config}") + raise ValueError(f"Invalid generated code config: {generated_code_config}") + validated_stack = cast(Stack, generated_code_config) + + # Validate the input mode + input_mode = params.get("inputMode") + if input_mode not in get_args(InputMode): + await throw_error(f"Invalid input mode: {input_mode}") + raise ValueError(f"Invalid input mode: {input_mode}") + validated_input_mode = cast(InputMode, input_mode) + + # Read the model from the request. Fall back to default if not provided. + code_generation_model_str = params.get( + "codeGenerationModel", Llm.GPT_4O_2024_05_13.value + ) + try: + code_generation_model = convert_frontend_str_to_llm(code_generation_model_str) + except ValueError: + await throw_error(f"Invalid model: {code_generation_model_str}") + raise ValueError(f"Invalid model: {code_generation_model_str}") + + openai_api_key = get_from_settings_dialog_or_env( + params, "openAiApiKey", OPENAI_API_KEY + ) + + # If neither is provided, we throw an error later only if Claude is used. + anthropic_api_key = get_from_settings_dialog_or_env( + params, "anthropicApiKey", ANTHROPIC_API_KEY + ) + + # Base URL for OpenAI API + openai_base_url: str | None = None + # Disable user-specified OpenAI Base URL in prod + if not IS_PROD: + openai_base_url = get_from_settings_dialog_or_env( + params, "openAiBaseURL", OPENAI_BASE_URL + ) + if not openai_base_url: + print("Using official OpenAI URL") + + # Get the image generation flag from the request. Fall back to True if not provided. + should_generate_images = bool(params.get("isImageGenerationEnabled", True)) + + return ExtractedParams( + stack=validated_stack, + input_mode=validated_input_mode, + code_generation_model=code_generation_model, + should_generate_images=should_generate_images, + openai_api_key=openai_api_key, + anthropic_api_key=anthropic_api_key, + openai_base_url=openai_base_url, + ) + + +def get_from_settings_dialog_or_env( + params: dict[str, str], key: str, env_var: str | None +) -> str | None: + value = params.get(key) + if value: + print(f"Using {key} from client-side settings dialog") + return value + + if env_var: + print(f"Using {key} from environment variable") + return env_var + + return None + + @router.websocket("/generate-code") async def stream_code(websocket: WebSocket): await websocket.accept() - print("Incoming websocket connection...") ## Communication protocol setup @@ -112,89 +199,37 @@ async def stream_code(websocket: WebSocket): {"type": type, "value": value, "variantIndex": variantIndex} ) - ## Parameter validation + ## Parameter extract and validation # TODO: Are the values always strings? - params: Dict[str, str] = await websocket.receive_json() + params: dict[str, str] = await websocket.receive_json() print("Received params") - # Read the code config settings (stack) from the request. - generated_code_config = params.get("generatedCodeConfig", "") - if not generated_code_config in get_args(Stack): - await throw_error(f"Invalid generated code config: {generated_code_config}") - raise Exception(f"Invalid generated code config: {generated_code_config}") - # Cast the variable to the Stack type - valid_stack = cast(Stack, generated_code_config) - - # Validate the input mode - input_mode = params.get("inputMode") - if not input_mode in get_args(InputMode): - await throw_error(f"Invalid input mode: {input_mode}") - raise Exception(f"Invalid input mode: {input_mode}") - # Cast the variable to the right type - validated_input_mode = cast(InputMode, input_mode) - - # Read the model from the request. Fall back to default if not provided. - code_generation_model_str = params.get( - "codeGenerationModel", Llm.GPT_4O_2024_05_13.value - ) - try: - code_generation_model = convert_frontend_str_to_llm(code_generation_model_str) - except: - await throw_error(f"Invalid model: {code_generation_model_str}") - raise Exception(f"Invalid model: {code_generation_model_str}") + extracted_params = await extract_params(params, throw_error) + # TODO(*): Rename to stack and input_mode + valid_stack = extracted_params.stack + validated_input_mode = extracted_params.input_mode + code_generation_model = extracted_params.code_generation_model + openai_api_key = extracted_params.openai_api_key + openai_base_url = extracted_params.openai_base_url + anthropic_api_key = extracted_params.anthropic_api_key + should_generate_images = extracted_params.should_generate_images # Auto-upgrade usage of older models code_generation_model = auto_upgrade_model(code_generation_model) + print( - f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." + f"Generating {valid_stack} code in {validated_input_mode} mode using {code_generation_model}..." ) - # Get the OpenAI API key from the request. Fall back to environment variable if not provided. - # If neither is provided, we throw an error. - openai_api_key = params.get("openAiApiKey") - if openai_api_key: - print("Using OpenAI API key from client-side settings dialog") - else: - openai_api_key = OPENAI_API_KEY - if openai_api_key: - print("Using OpenAI API key from environment variable") - + # TODO(*): Do I still need this? if not openai_api_key and is_openai_model(code_generation_model): await throw_error( "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." ) return - # Get the Anthropic API key from the request. Fall back to environment variable if not provided. - # If neither is provided, we throw an error later only if Claude is used. - anthropic_api_key = params.get("anthropicApiKey") - if anthropic_api_key: - print("Using Anthropic API key from client-side settings dialog") - else: - anthropic_api_key = ANTHROPIC_API_KEY - if anthropic_api_key: - print("Using Anthropic API key from environment variable") - - # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. - openai_base_url: Union[str, None] = None - # Disable user-specified OpenAI Base URL in prod - if not os.environ.get("IS_PROD"): - openai_base_url = params.get("openAiBaseURL") - if openai_base_url: - print("Using OpenAI Base URL from client-side settings dialog") - else: - openai_base_url = os.environ.get("OPENAI_BASE_URL") - if openai_base_url: - print("Using OpenAI Base URL from environment variable") - - if not openai_base_url: - print("Using official OpenAI URL") - - # Get the image generation flag from the request. Fall back to True if not provided. - should_generate_images = bool(params.get("isImageGenerationEnabled", True)) - - # TODO(*): Print with send_message instead of print statements + # TODO(*): Don't assume number of variants await send_message("status", "Generating code...", 0) await send_message("status", "Generating code...", 1) @@ -213,7 +248,7 @@ async def stream_code(websocket: WebSocket): ) raise - pprint_prompt(prompt_messages) # type: ignore + # pprint_prompt(prompt_messages) # type: ignore ### Code generation