move parameter extraction to separate fn

This commit is contained in:
Abi Raja 2024-07-31 15:46:53 -04:00
parent 823bd2e249
commit c76c7c202a
2 changed files with 108 additions and 72 deletions

View File

@ -5,6 +5,7 @@ import os
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
# Image generation (optional) # Image generation (optional)
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None) REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)

View File

@ -1,5 +1,5 @@
import os
import asyncio import asyncio
from dataclasses import dataclass
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from codegen.utils import extract_html_content from codegen.utils import extract_html_content
@ -7,6 +7,7 @@ from config import (
ANTHROPIC_API_KEY, ANTHROPIC_API_KEY,
IS_PROD, IS_PROD,
OPENAI_API_KEY, OPENAI_API_KEY,
OPENAI_BASE_URL,
REPLICATE_API_KEY, REPLICATE_API_KEY,
SHOULD_MOCK_AI_RESPONSE, SHOULD_MOCK_AI_RESPONSE,
) )
@ -21,12 +22,13 @@ from llm import (
) )
from fs_logging.core import write_logs from fs_logging.core import write_logs
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args
from image_generation.core import generate_images from image_generation.core import generate_images
from prompts import create_prompt from prompts import create_prompt
from prompts.claude_prompts import VIDEO_PROMPT from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack from prompts.types import Stack
from utils import pprint_prompt
# from utils import pprint_prompt
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
@ -83,10 +85,95 @@ async def perform_image_generation(
) )
@dataclass
class ExtractedParams:
stack: Stack
input_mode: InputMode
code_generation_model: Llm
should_generate_images: bool
openai_api_key: str | None
anthropic_api_key: str | None
openai_base_url: str | None
async def extract_params(
params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]]
) -> ExtractedParams:
# Read the code config settings (stack) from the request.
generated_code_config = params.get("generatedCodeConfig", "")
if generated_code_config not in get_args(Stack):
await throw_error(f"Invalid generated code config: {generated_code_config}")
raise ValueError(f"Invalid generated code config: {generated_code_config}")
validated_stack = cast(Stack, generated_code_config)
# Validate the input mode
input_mode = params.get("inputMode")
if input_mode not in get_args(InputMode):
await throw_error(f"Invalid input mode: {input_mode}")
raise ValueError(f"Invalid input mode: {input_mode}")
validated_input_mode = cast(InputMode, input_mode)
# Read the model from the request. Fall back to default if not provided.
code_generation_model_str = params.get(
"codeGenerationModel", Llm.GPT_4O_2024_05_13.value
)
try:
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
except ValueError:
await throw_error(f"Invalid model: {code_generation_model_str}")
raise ValueError(f"Invalid model: {code_generation_model_str}")
openai_api_key = get_from_settings_dialog_or_env(
params, "openAiApiKey", OPENAI_API_KEY
)
# If neither is provided, we throw an error later only if Claude is used.
anthropic_api_key = get_from_settings_dialog_or_env(
params, "anthropicApiKey", ANTHROPIC_API_KEY
)
# Base URL for OpenAI API
openai_base_url: str | None = None
# Disable user-specified OpenAI Base URL in prod
if not IS_PROD:
openai_base_url = get_from_settings_dialog_or_env(
params, "openAiBaseURL", OPENAI_BASE_URL
)
if not openai_base_url:
print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided.
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
return ExtractedParams(
stack=validated_stack,
input_mode=validated_input_mode,
code_generation_model=code_generation_model,
should_generate_images=should_generate_images,
openai_api_key=openai_api_key,
anthropic_api_key=anthropic_api_key,
openai_base_url=openai_base_url,
)
def get_from_settings_dialog_or_env(
params: dict[str, str], key: str, env_var: str | None
) -> str | None:
value = params.get(key)
if value:
print(f"Using {key} from client-side settings dialog")
return value
if env_var:
print(f"Using {key} from environment variable")
return env_var
return None
@router.websocket("/generate-code") @router.websocket("/generate-code")
async def stream_code(websocket: WebSocket): async def stream_code(websocket: WebSocket):
await websocket.accept() await websocket.accept()
print("Incoming websocket connection...") print("Incoming websocket connection...")
## Communication protocol setup ## Communication protocol setup
@ -112,89 +199,37 @@ async def stream_code(websocket: WebSocket):
{"type": type, "value": value, "variantIndex": variantIndex} {"type": type, "value": value, "variantIndex": variantIndex}
) )
## Parameter validation ## Parameter extract and validation
# TODO: Are the values always strings? # TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json() params: dict[str, str] = await websocket.receive_json()
print("Received params") print("Received params")
# Read the code config settings (stack) from the request. extracted_params = await extract_params(params, throw_error)
generated_code_config = params.get("generatedCodeConfig", "") # TODO(*): Rename to stack and input_mode
if not generated_code_config in get_args(Stack): valid_stack = extracted_params.stack
await throw_error(f"Invalid generated code config: {generated_code_config}") validated_input_mode = extracted_params.input_mode
raise Exception(f"Invalid generated code config: {generated_code_config}") code_generation_model = extracted_params.code_generation_model
# Cast the variable to the Stack type openai_api_key = extracted_params.openai_api_key
valid_stack = cast(Stack, generated_code_config) openai_base_url = extracted_params.openai_base_url
anthropic_api_key = extracted_params.anthropic_api_key
# Validate the input mode should_generate_images = extracted_params.should_generate_images
input_mode = params.get("inputMode")
if not input_mode in get_args(InputMode):
await throw_error(f"Invalid input mode: {input_mode}")
raise Exception(f"Invalid input mode: {input_mode}")
# Cast the variable to the right type
validated_input_mode = cast(InputMode, input_mode)
# Read the model from the request. Fall back to default if not provided.
code_generation_model_str = params.get(
"codeGenerationModel", Llm.GPT_4O_2024_05_13.value
)
try:
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
except:
await throw_error(f"Invalid model: {code_generation_model_str}")
raise Exception(f"Invalid model: {code_generation_model_str}")
# Auto-upgrade usage of older models # Auto-upgrade usage of older models
code_generation_model = auto_upgrade_model(code_generation_model) code_generation_model = auto_upgrade_model(code_generation_model)
print( print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..." f"Generating {valid_stack} code in {validated_input_mode} mode using {code_generation_model}..."
) )
# Get the OpenAI API key from the request. Fall back to environment variable if not provided. # TODO(*): Do I still need this?
# If neither is provided, we throw an error.
openai_api_key = params.get("openAiApiKey")
if openai_api_key:
print("Using OpenAI API key from client-side settings dialog")
else:
openai_api_key = OPENAI_API_KEY
if openai_api_key:
print("Using OpenAI API key from environment variable")
if not openai_api_key and is_openai_model(code_generation_model): if not openai_api_key and is_openai_model(code_generation_model):
await throw_error( await throw_error(
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
) )
return return
# Get the Anthropic API key from the request. Fall back to environment variable if not provided. # TODO(*): Don't assume number of variants
# If neither is provided, we throw an error later only if Claude is used.
anthropic_api_key = params.get("anthropicApiKey")
if anthropic_api_key:
print("Using Anthropic API key from client-side settings dialog")
else:
anthropic_api_key = ANTHROPIC_API_KEY
if anthropic_api_key:
print("Using Anthropic API key from environment variable")
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url: Union[str, None] = None
# Disable user-specified OpenAI Base URL in prod
if not os.environ.get("IS_PROD"):
openai_base_url = params.get("openAiBaseURL")
if openai_base_url:
print("Using OpenAI Base URL from client-side settings dialog")
else:
openai_base_url = os.environ.get("OPENAI_BASE_URL")
if openai_base_url:
print("Using OpenAI Base URL from environment variable")
if not openai_base_url:
print("Using official OpenAI URL")
# Get the image generation flag from the request. Fall back to True if not provided.
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
# TODO(*): Print with send_message instead of print statements
await send_message("status", "Generating code...", 0) await send_message("status", "Generating code...", 0)
await send_message("status", "Generating code...", 1) await send_message("status", "Generating code...", 1)
@ -213,7 +248,7 @@ async def stream_code(websocket: WebSocket):
) )
raise raise
pprint_prompt(prompt_messages) # type: ignore # pprint_prompt(prompt_messages) # type: ignore
### Code generation ### Code generation