diff --git a/backend/config.py b/backend/config.py
index 30da119..2453631 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -3,6 +3,9 @@
 # TODO: Should only be set to true when value is 'True', not any abitrary truthy value
 import os
 
+NUM_VARIANTS = 2
+
+# LLM-related
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
 ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
 OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", None)
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index d7f46c5..719da53 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -6,6 +6,7 @@ from codegen.utils import extract_html_content
 from config import (
     ANTHROPIC_API_KEY,
     IS_PROD,
+    NUM_VARIANTS,
     OPENAI_API_KEY,
     OPENAI_BASE_URL,
     REPLICATE_API_KEY,
@@ -205,9 +206,8 @@ async def stream_code(websocket: WebSocket):
     print("Received params")
     extracted_params = await extract_params(params, throw_error)
 
-    # TODO(*): Rename to stack and input_mode
-    valid_stack = extracted_params.stack
-    validated_input_mode = extracted_params.input_mode
+    stack = extracted_params.stack
+    input_mode = extracted_params.input_mode
     code_generation_model = extracted_params.code_generation_model
     openai_api_key = extracted_params.openai_api_key
     openai_base_url = extracted_params.openai_base_url
@@ -218,7 +218,7 @@ async def stream_code(websocket: WebSocket):
     code_generation_model = auto_upgrade_model(code_generation_model)
 
     print(
-        f"Generating {valid_stack} code in {validated_input_mode} mode using {code_generation_model}..."
+        f"Generating {stack} code in {input_mode} mode using {code_generation_model}..."
     )
 
     # TODO(*): Do I still need this?
@@ -228,9 +228,8 @@ async def stream_code(websocket: WebSocket):
         )
         return
 
-    # TODO(*): Don't assume number of variants
-    await send_message("status", "Generating code...", 0)
-    await send_message("status", "Generating code...", 1)
+    for i in range(NUM_VARIANTS):
+        await send_message("status", "Generating code...", i)
 
     ### Prompt creation
 
@@ -238,9 +237,7 @@ async def stream_code(websocket: WebSocket):
     image_cache: Dict[str, str] = {}
 
     try:
-        prompt_messages, image_cache = await create_prompt(
-            params, valid_stack, validated_input_mode
-        )
+        prompt_messages, image_cache = await create_prompt(params, stack, input_mode)
     except:
         await throw_error(
             "Error assembling prompt. Contact support at support@picoapps.xyz"
@@ -255,12 +252,10 @@ async def stream_code(websocket: WebSocket):
         await send_message("chunk", content, variantIndex)
 
     if SHOULD_MOCK_AI_RESPONSE:
-        completions = [
-            await mock_completion(process_chunk, input_mode=validated_input_mode)
-        ]
+        completions = [await mock_completion(process_chunk, input_mode=input_mode)]
     else:
         try:
-            if validated_input_mode == "video":
+            if input_mode == "video":
                 if not anthropic_api_key:
                     await throw_error(
                         "Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
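
Note (not part of the patch): a minimal sketch of the pattern the new loop follows, assuming a send_message(type, value, variantIndex) coroutine like the one used above; everything except NUM_VARIANTS is a hypothetical stand-in, not code from this repository.

    from typing import Any, Callable, Coroutine

    NUM_VARIANTS = 2  # mirrors backend/config.py


    async def notify_variants(
        send_message: Callable[[str, str, int], Coroutine[Any, Any, None]],
        num_variants: int = NUM_VARIANTS,
    ) -> None:
        # One "status" message per variant index, replacing the hard-coded 0 and 1.
        for i in range(num_variants):
            await send_message("status", "Generating code...", i)


    if __name__ == "__main__":
        import asyncio

        async def fake_send(msg_type: str, value: str, variant_index: int) -> None:
            # Stand-in for the websocket sender; just prints what would be sent.
            print(msg_type, value, variant_index)

        asyncio.run(notify_variants(fake_send))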