move parameter extraction to separate fn
This commit is contained in:
parent
823bd2e249
commit
c76c7c202a
@ -5,6 +5,7 @@ import os
|
|||||||
|
|
||||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
||||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
||||||
|
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
|
||||||
|
|
||||||
# Image generation (optional)
|
# Image generation (optional)
|
||||||
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
|
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
from codegen.utils import extract_html_content
|
from codegen.utils import extract_html_content
|
||||||
@ -7,6 +7,7 @@ from config import (
|
|||||||
ANTHROPIC_API_KEY,
|
ANTHROPIC_API_KEY,
|
||||||
IS_PROD,
|
IS_PROD,
|
||||||
OPENAI_API_KEY,
|
OPENAI_API_KEY,
|
||||||
|
OPENAI_BASE_URL,
|
||||||
REPLICATE_API_KEY,
|
REPLICATE_API_KEY,
|
||||||
SHOULD_MOCK_AI_RESPONSE,
|
SHOULD_MOCK_AI_RESPONSE,
|
||||||
)
|
)
|
||||||
@ -21,12 +22,13 @@ from llm import (
|
|||||||
)
|
)
|
||||||
from fs_logging.core import write_logs
|
from fs_logging.core import write_logs
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Any, Coroutine, Dict, List, Literal, Union, cast, get_args
|
from typing import Any, Callable, Coroutine, Dict, List, Literal, cast, get_args
|
||||||
from image_generation.core import generate_images
|
from image_generation.core import generate_images
|
||||||
from prompts import create_prompt
|
from prompts import create_prompt
|
||||||
from prompts.claude_prompts import VIDEO_PROMPT
|
from prompts.claude_prompts import VIDEO_PROMPT
|
||||||
from prompts.types import Stack
|
from prompts.types import Stack
|
||||||
from utils import pprint_prompt
|
|
||||||
|
# from utils import pprint_prompt
|
||||||
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@ -83,10 +85,95 @@ async def perform_image_generation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExtractedParams:
|
||||||
|
stack: Stack
|
||||||
|
input_mode: InputMode
|
||||||
|
code_generation_model: Llm
|
||||||
|
should_generate_images: bool
|
||||||
|
openai_api_key: str | None
|
||||||
|
anthropic_api_key: str | None
|
||||||
|
openai_base_url: str | None
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_params(
|
||||||
|
params: Dict[str, str], throw_error: Callable[[str], Coroutine[Any, Any, None]]
|
||||||
|
) -> ExtractedParams:
|
||||||
|
# Read the code config settings (stack) from the request.
|
||||||
|
generated_code_config = params.get("generatedCodeConfig", "")
|
||||||
|
if generated_code_config not in get_args(Stack):
|
||||||
|
await throw_error(f"Invalid generated code config: {generated_code_config}")
|
||||||
|
raise ValueError(f"Invalid generated code config: {generated_code_config}")
|
||||||
|
validated_stack = cast(Stack, generated_code_config)
|
||||||
|
|
||||||
|
# Validate the input mode
|
||||||
|
input_mode = params.get("inputMode")
|
||||||
|
if input_mode not in get_args(InputMode):
|
||||||
|
await throw_error(f"Invalid input mode: {input_mode}")
|
||||||
|
raise ValueError(f"Invalid input mode: {input_mode}")
|
||||||
|
validated_input_mode = cast(InputMode, input_mode)
|
||||||
|
|
||||||
|
# Read the model from the request. Fall back to default if not provided.
|
||||||
|
code_generation_model_str = params.get(
|
||||||
|
"codeGenerationModel", Llm.GPT_4O_2024_05_13.value
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
|
||||||
|
except ValueError:
|
||||||
|
await throw_error(f"Invalid model: {code_generation_model_str}")
|
||||||
|
raise ValueError(f"Invalid model: {code_generation_model_str}")
|
||||||
|
|
||||||
|
openai_api_key = get_from_settings_dialog_or_env(
|
||||||
|
params, "openAiApiKey", OPENAI_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
# If neither is provided, we throw an error later only if Claude is used.
|
||||||
|
anthropic_api_key = get_from_settings_dialog_or_env(
|
||||||
|
params, "anthropicApiKey", ANTHROPIC_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base URL for OpenAI API
|
||||||
|
openai_base_url: str | None = None
|
||||||
|
# Disable user-specified OpenAI Base URL in prod
|
||||||
|
if not IS_PROD:
|
||||||
|
openai_base_url = get_from_settings_dialog_or_env(
|
||||||
|
params, "openAiBaseURL", OPENAI_BASE_URL
|
||||||
|
)
|
||||||
|
if not openai_base_url:
|
||||||
|
print("Using official OpenAI URL")
|
||||||
|
|
||||||
|
# Get the image generation flag from the request. Fall back to True if not provided.
|
||||||
|
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
||||||
|
|
||||||
|
return ExtractedParams(
|
||||||
|
stack=validated_stack,
|
||||||
|
input_mode=validated_input_mode,
|
||||||
|
code_generation_model=code_generation_model,
|
||||||
|
should_generate_images=should_generate_images,
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
anthropic_api_key=anthropic_api_key,
|
||||||
|
openai_base_url=openai_base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_from_settings_dialog_or_env(
|
||||||
|
params: dict[str, str], key: str, env_var: str | None
|
||||||
|
) -> str | None:
|
||||||
|
value = params.get(key)
|
||||||
|
if value:
|
||||||
|
print(f"Using {key} from client-side settings dialog")
|
||||||
|
return value
|
||||||
|
|
||||||
|
if env_var:
|
||||||
|
print(f"Using {key} from environment variable")
|
||||||
|
return env_var
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/generate-code")
|
@router.websocket("/generate-code")
|
||||||
async def stream_code(websocket: WebSocket):
|
async def stream_code(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
print("Incoming websocket connection...")
|
print("Incoming websocket connection...")
|
||||||
|
|
||||||
## Communication protocol setup
|
## Communication protocol setup
|
||||||
@ -112,89 +199,37 @@ async def stream_code(websocket: WebSocket):
|
|||||||
{"type": type, "value": value, "variantIndex": variantIndex}
|
{"type": type, "value": value, "variantIndex": variantIndex}
|
||||||
)
|
)
|
||||||
|
|
||||||
## Parameter validation
|
## Parameter extract and validation
|
||||||
|
|
||||||
# TODO: Are the values always strings?
|
# TODO: Are the values always strings?
|
||||||
params: Dict[str, str] = await websocket.receive_json()
|
params: dict[str, str] = await websocket.receive_json()
|
||||||
print("Received params")
|
print("Received params")
|
||||||
|
|
||||||
# Read the code config settings (stack) from the request.
|
extracted_params = await extract_params(params, throw_error)
|
||||||
generated_code_config = params.get("generatedCodeConfig", "")
|
# TODO(*): Rename to stack and input_mode
|
||||||
if not generated_code_config in get_args(Stack):
|
valid_stack = extracted_params.stack
|
||||||
await throw_error(f"Invalid generated code config: {generated_code_config}")
|
validated_input_mode = extracted_params.input_mode
|
||||||
raise Exception(f"Invalid generated code config: {generated_code_config}")
|
code_generation_model = extracted_params.code_generation_model
|
||||||
# Cast the variable to the Stack type
|
openai_api_key = extracted_params.openai_api_key
|
||||||
valid_stack = cast(Stack, generated_code_config)
|
openai_base_url = extracted_params.openai_base_url
|
||||||
|
anthropic_api_key = extracted_params.anthropic_api_key
|
||||||
# Validate the input mode
|
should_generate_images = extracted_params.should_generate_images
|
||||||
input_mode = params.get("inputMode")
|
|
||||||
if not input_mode in get_args(InputMode):
|
|
||||||
await throw_error(f"Invalid input mode: {input_mode}")
|
|
||||||
raise Exception(f"Invalid input mode: {input_mode}")
|
|
||||||
# Cast the variable to the right type
|
|
||||||
validated_input_mode = cast(InputMode, input_mode)
|
|
||||||
|
|
||||||
# Read the model from the request. Fall back to default if not provided.
|
|
||||||
code_generation_model_str = params.get(
|
|
||||||
"codeGenerationModel", Llm.GPT_4O_2024_05_13.value
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
|
|
||||||
except:
|
|
||||||
await throw_error(f"Invalid model: {code_generation_model_str}")
|
|
||||||
raise Exception(f"Invalid model: {code_generation_model_str}")
|
|
||||||
|
|
||||||
# Auto-upgrade usage of older models
|
# Auto-upgrade usage of older models
|
||||||
code_generation_model = auto_upgrade_model(code_generation_model)
|
code_generation_model = auto_upgrade_model(code_generation_model)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
|
f"Generating {valid_stack} code in {validated_input_mode} mode using {code_generation_model}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
# TODO(*): Do I still need this?
|
||||||
# If neither is provided, we throw an error.
|
|
||||||
openai_api_key = params.get("openAiApiKey")
|
|
||||||
if openai_api_key:
|
|
||||||
print("Using OpenAI API key from client-side settings dialog")
|
|
||||||
else:
|
|
||||||
openai_api_key = OPENAI_API_KEY
|
|
||||||
if openai_api_key:
|
|
||||||
print("Using OpenAI API key from environment variable")
|
|
||||||
|
|
||||||
if not openai_api_key and is_openai_model(code_generation_model):
|
if not openai_api_key and is_openai_model(code_generation_model):
|
||||||
await throw_error(
|
await throw_error(
|
||||||
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
|
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get the Anthropic API key from the request. Fall back to environment variable if not provided.
|
# TODO(*): Don't assume number of variants
|
||||||
# If neither is provided, we throw an error later only if Claude is used.
|
|
||||||
anthropic_api_key = params.get("anthropicApiKey")
|
|
||||||
if anthropic_api_key:
|
|
||||||
print("Using Anthropic API key from client-side settings dialog")
|
|
||||||
else:
|
|
||||||
anthropic_api_key = ANTHROPIC_API_KEY
|
|
||||||
if anthropic_api_key:
|
|
||||||
print("Using Anthropic API key from environment variable")
|
|
||||||
|
|
||||||
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
|
|
||||||
openai_base_url: Union[str, None] = None
|
|
||||||
# Disable user-specified OpenAI Base URL in prod
|
|
||||||
if not os.environ.get("IS_PROD"):
|
|
||||||
openai_base_url = params.get("openAiBaseURL")
|
|
||||||
if openai_base_url:
|
|
||||||
print("Using OpenAI Base URL from client-side settings dialog")
|
|
||||||
else:
|
|
||||||
openai_base_url = os.environ.get("OPENAI_BASE_URL")
|
|
||||||
if openai_base_url:
|
|
||||||
print("Using OpenAI Base URL from environment variable")
|
|
||||||
|
|
||||||
if not openai_base_url:
|
|
||||||
print("Using official OpenAI URL")
|
|
||||||
|
|
||||||
# Get the image generation flag from the request. Fall back to True if not provided.
|
|
||||||
should_generate_images = bool(params.get("isImageGenerationEnabled", True))
|
|
||||||
|
|
||||||
# TODO(*): Print with send_message instead of print statements
|
|
||||||
await send_message("status", "Generating code...", 0)
|
await send_message("status", "Generating code...", 0)
|
||||||
await send_message("status", "Generating code...", 1)
|
await send_message("status", "Generating code...", 1)
|
||||||
|
|
||||||
@ -213,7 +248,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
pprint_prompt(prompt_messages) # type: ignore
|
# pprint_prompt(prompt_messages) # type: ignore
|
||||||
|
|
||||||
### Code generation
|
### Code generation
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user