import os
import asyncio
import traceback

from fastapi import APIRouter, WebSocket
import openai

from codegen.utils import extract_html_content
from config import ANTHROPIC_API_KEY, IS_PROD, OPENAI_API_KEY, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode
from llm import (
    Llm,
    convert_frontend_str_to_llm,
    is_openai_model,
    stream_claude_response,
    stream_claude_response_native,
    stream_openai_response,
)
from openai.types.chat import ChatCompletionMessageParam
from fs_logging.core import write_logs
from mock_llm import mock_completion
from typing import Any, Callable, Coroutine, Dict, List, Literal, Union, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from prompts.claude_prompts import VIDEO_PROMPT
from prompts.types import Stack
from utils import pprint_prompt
from video.utils import assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore

router = APIRouter()


async def create_prompt(
    params: Dict[str, str], stack: Stack, model: Llm, input_mode: InputMode
) -> tuple[list[ChatCompletionMessageParam], Dict[str, str]]:
    """Assemble the LLM prompt messages and the image cache for a request.

    Builds the message list differently depending on whether the generation
    started from imported code, is an update to an earlier generation, or is
    a video input (which replaces the prompt entirely).

    Args:
        params: Raw request parameters from the websocket client.
        stack: The validated output stack (e.g. HTML/Tailwind).
        model: The code-generation model (affects imported-code prompts).
        input_mode: "image" or "video".

    Returns:
        A tuple of (prompt messages, image cache). The image cache maps
        alt text to previously generated image URLs so updates don't
        regenerate images.
    """
    image_cache: Dict[str, str] = {}

    # If this generation started off with imported code, we need to assemble
    # the prompt differently
    if params.get("isImportedFromCode"):
        original_imported_code = params["history"][0]
        prompt_messages = assemble_imported_code_prompt(
            original_imported_code, stack, model
        )
        # History after the imported code alternates user/assistant turns,
        # starting with the user.
        # NOTE: annotate `message` once per loop — annotating the assignment in
        # both branches (as before) is a mypy "already defined" error.
        for index, text in enumerate(params["history"][1:]):
            message: ChatCompletionMessageParam
            if index % 2 == 0:
                message = {
                    "role": "user",
                    "content": text,
                }
            else:
                message = {
                    "role": "assistant",
                    "content": text,
                }
            prompt_messages.append(message)
    else:
        # Assemble the prompt for non-imported code
        if params.get("resultImage"):
            prompt_messages = assemble_prompt(
                params["image"], stack, params["resultImage"]
            )
        else:
            prompt_messages = assemble_prompt(params["image"], stack)

        if params["generationType"] == "update":
            # Transform the history tree into message format
            # TODO: Move this to frontend
            # For updates the history starts with the assistant's output.
            for index, text in enumerate(params["history"]):
                update_message: ChatCompletionMessageParam
                if index % 2 == 0:
                    update_message = {
                        "role": "assistant",
                        "content": text,
                    }
                else:
                    update_message = {
                        "role": "user",
                        "content": text,
                    }
                prompt_messages.append(update_message)
            # Reuse images from the last assistant completion (second-to-last
            # history entry) so updates keep previously generated assets.
            image_cache = create_alt_url_mapping(params["history"][-2])

    if input_mode == "video":
        # Video mode replaces the assembled prompt entirely with a
        # Claude-specific video prompt.
        video_data_url = params["image"]
        prompt_messages = await assemble_claude_prompt_video(video_data_url)

    return prompt_messages, image_cache


# Auto-upgrade usage of older models
def auto_upgrade_model(code_generation_model: Llm) -> Llm:
    """Map deprecated model choices to their current replacements.

    Returns the input model unchanged if it is not deprecated.
    """
    if code_generation_model in {Llm.GPT_4_VISION, Llm.GPT_4_TURBO_2024_04_09}:
        print(
            f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to GPT-4O-2024-05-13"
        )
        return Llm.GPT_4O_2024_05_13
    elif code_generation_model == Llm.CLAUDE_3_SONNET:
        print(
            f"Initial deprecated model: {code_generation_model}. Auto-updating code generation model to CLAUDE-3.5-SONNET-2024-06-20"
        )
        return Llm.CLAUDE_3_5_SONNET_2024_06_20
    return code_generation_model


# Generate images and return updated completions
async def process_completion(
    completion: str,
    index: int,
    should_generate_images: bool,
    openai_api_key: str | None,
    openai_base_url: str | None,
    image_cache: dict[str, str],
    send_message: Callable[
        [Literal["chunk", "status", "setCode", "error"], str, int],
        Coroutine[Any, Any, None],
    ],
):
    """Post-process one completion: optionally generate images, then send code.

    Best-effort: if image generation fails for any reason, the original
    (image-less) completion is still sent to the client rather than failing
    the whole generation.

    Args:
        completion: The generated HTML for one variant.
        index: Variant index, forwarded to the client with every message.
        should_generate_images: Whether image generation is enabled.
        openai_api_key: Key used for image generation (skipped if None).
        openai_base_url: Optional OpenAI-compatible base URL.
        image_cache: alt-text -> URL mapping of previously generated images.
        send_message: Coroutine used to push messages over the websocket.
    """
    try:
        if should_generate_images and openai_api_key:
            await send_message("status", "Generating images...", index)
            updated_html = await generate_images(
                completion,
                api_key=openai_api_key,
                base_url=openai_base_url,
                image_cache=image_cache,
            )
        else:
            updated_html = completion
        await send_message("setCode", updated_html, index)
        await send_message("status", "Code generation complete.", index)
    except Exception as e:
        traceback.print_exc()
        print(f"Image generation failed for variant {index}", e)
        # Deliberately degrade gracefully: ship the code without images.
        await send_message("setCode", completion, index)
        await send_message(
            "status", "Image generation failed but code is complete.", index
        )


@router.websocket("/generate-code")
async def stream_code(websocket: WebSocket):
    """Websocket endpoint that streams generated code variants to the client.

    Flow: validate request params, assemble the prompt, stream completions
    from one or two model variants in parallel, post-process (image
    generation), and close the socket. Errors are reported to the client via
    an "error" message before the socket is closed with an app error code.
    """
    await websocket.accept()
    print("Incoming websocket connection...")

    async def throw_error(
        message: str,
    ):
        # Report the error to the client, then close with the app error code.
        print(message)
        await websocket.send_json({"type": "error", "value": message})
        await websocket.close(APP_ERROR_WEB_SOCKET_CODE)

    async def send_message(
        type: Literal["chunk", "status", "setCode", "error"],
        value: str,
        variantIndex: int,
    ):
        await websocket.send_json(
            {"type": type, "value": value, "variantIndex": variantIndex}
        )

    # TODO: Are the values always strings?
    params: Dict[str, str] = await websocket.receive_json()
    print("Received params")

    # Read the code config settings (stack) from the request.
    generated_code_config = params.get("generatedCodeConfig", "")
    if generated_code_config not in get_args(Stack):
        await throw_error(f"Invalid generated code config: {generated_code_config}")
        raise Exception(f"Invalid generated code config: {generated_code_config}")
    # Cast the variable to the Stack type
    valid_stack = cast(Stack, generated_code_config)

    # Validate the input mode
    input_mode = params.get("inputMode")
    if input_mode not in get_args(InputMode):
        await throw_error(f"Invalid input mode: {input_mode}")
        raise Exception(f"Invalid input mode: {input_mode}")
    # Cast the variable to the right type
    validated_input_mode = cast(InputMode, input_mode)

    # Read the model from the request. Fall back to default if not provided.
    code_generation_model_str = params.get(
        "codeGenerationModel", Llm.GPT_4O_2024_05_13.value
    )
    try:
        code_generation_model = convert_frontend_str_to_llm(code_generation_model_str)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed.
        await throw_error(f"Invalid model: {code_generation_model_str}")
        raise Exception(f"Invalid model: {code_generation_model_str}")

    # Auto-upgrade usage of older models
    code_generation_model = auto_upgrade_model(code_generation_model)

    print(
        f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
    )

    # Get the OpenAI API key from the request. Fall back to environment variable if not provided.
    # If neither is provided, we throw an error.
    openai_api_key = params.get("openAiApiKey")
    if openai_api_key:
        print("Using OpenAI API key from client-side settings dialog")
    else:
        openai_api_key = OPENAI_API_KEY
        if openai_api_key:
            print("Using OpenAI API key from environment variable")

    if not openai_api_key and is_openai_model(code_generation_model):
        await throw_error(
            "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
        )
        return

    # Get the Anthropic API key from the request. Fall back to environment variable if not provided.
    # If neither is provided, we throw an error later only if Claude is used.
    anthropic_api_key = params.get("anthropicApiKey")
    if anthropic_api_key:
        print("Using Anthropic API key from client-side settings dialog")
    else:
        anthropic_api_key = ANTHROPIC_API_KEY
        if anthropic_api_key:
            print("Using Anthropic API key from environment variable")

    # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
    openai_base_url: Union[str, None] = None
    # Disable user-specified OpenAI Base URL in prod.
    # Use the imported IS_PROD config flag for consistency with the rest of
    # this module instead of re-reading the raw environment variable.
    if not IS_PROD:
        openai_base_url = params.get("openAiBaseURL")
        if openai_base_url:
            print("Using OpenAI Base URL from client-side settings dialog")
        else:
            openai_base_url = os.environ.get("OPENAI_BASE_URL")
            if openai_base_url:
                print("Using OpenAI Base URL from environment variable")

    if not openai_base_url:
        print("Using official OpenAI URL")

    # Get the image generation flag from the request. Fall back to True if not provided.
    should_generate_images = bool(params.get("isImageGenerationEnabled", True))

    print("generating code...")
    # TODO(*): Print with send_message instead of print statements
    await send_message("status", "Generating code...", 0)
    await send_message("status", "Generating code...", 1)

    # TODO(*): Move down
    async def process_chunk(content: str, variantIndex: int):
        await send_message("chunk", content, variantIndex)

    # Image cache for updates so that we don't have to regenerate images
    image_cache: Dict[str, str] = {}

    try:
        prompt_messages, image_cache = await create_prompt(
            params, valid_stack, code_generation_model, validated_input_mode
        )
    except Exception:
        # Narrowed from a bare `except:`; the original exception is
        # re-raised below after notifying the client.
        # TODO(*): This should use variantIndex
        await websocket.send_json(
            {
                "type": "error",
                "value": "Error assembling prompt. Contact support at support@picoapps.xyz",
            }
        )
        await websocket.close()
        raise

    pprint_prompt(prompt_messages)  # type: ignore

    if SHOULD_MOCK_AI_RESPONSE:
        completions = [
            await mock_completion(process_chunk, input_mode=validated_input_mode)
        ]
    else:
        try:
            if validated_input_mode == "video":
                # Video only works with Claude.
                if not anthropic_api_key:
                    await throw_error(
                        "Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
                    )
                    raise Exception("No Anthropic key")

                completions = [
                    await stream_claude_response_native(
                        system_prompt=VIDEO_PROMPT,
                        messages=prompt_messages,  # type: ignore
                        api_key=anthropic_api_key,
                        callback=lambda x: process_chunk(x, 0),
                        model=Llm.CLAUDE_3_OPUS,
                        include_thinking=True,
                    )
                ]
            else:
                # Depending on the presence and absence of various keys,
                # we decide which models to run
                variant_models: List[Literal["openai", "anthropic"]] = []
                if openai_api_key and anthropic_api_key:
                    variant_models = ["openai", "anthropic"]
                elif openai_api_key:
                    variant_models = ["openai", "openai"]
                elif anthropic_api_key:
                    variant_models = ["anthropic", "anthropic"]
                else:
                    await throw_error(
                        "No OpenAI or Anthropic API key found. Please add the environment variable OPENAI_API_KEY or ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
                    )
                    raise Exception("No OpenAI or Anthropic key")

                tasks: List[Coroutine[Any, Any, str]] = []
                for index, model in enumerate(variant_models):
                    if model == "openai":
                        if openai_api_key is None:
                            await throw_error("OpenAI API key is missing.")
                            raise Exception("OpenAI API key is missing.")
                        tasks.append(
                            stream_openai_response(
                                prompt_messages,
                                api_key=openai_api_key,
                                base_url=openai_base_url,
                                # Bind the current index as a default to avoid
                                # the late-binding closure pitfall.
                                callback=lambda x, i=index: process_chunk(x, i),
                                model=Llm.GPT_4O_2024_05_13,
                            )
                        )
                    elif model == "anthropic":
                        if anthropic_api_key is None:
                            await throw_error("Anthropic API key is missing.")
                            raise Exception("Anthropic API key is missing.")
                        tasks.append(
                            stream_claude_response(
                                prompt_messages,
                                api_key=anthropic_api_key,
                                callback=lambda x, i=index: process_chunk(x, i),
                                model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
                            )
                        )

                # Run all variants concurrently.
                completions = await asyncio.gather(*tasks)
                print("Models used for generation: ", variant_models)
        except openai.AuthenticationError as e:
            print("[GENERATE_CODE] Authentication failed", e)
            error_message = (
                "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard."
                + (
                    " Alternatively, you can purchase code generation credits directly on this website."
                    if IS_PROD
                    else ""
                )
            )
            return await throw_error(error_message)
        except openai.NotFoundError as e:
            print("[GENERATE_CODE] Model not found", e)
            error_message = (
                e.message
                + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md"
                + (
                    " Alternatively, you can purchase code generation credits directly on this website."
                    if IS_PROD
                    else ""
                )
            )
            return await throw_error(error_message)
        except openai.RateLimitError as e:
            print("[GENERATE_CODE] Rate limit exceeded", e)
            error_message = (
                "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'"
                + (
                    " Alternatively, you can purchase code generation credits directly on this website."
                    if IS_PROD
                    else ""
                )
            )
            return await throw_error(error_message)

    # Strip the completion of everything except the HTML content
    completions = [extract_html_content(completion) for completion in completions]

    # Write the messages dict into a log so that we can debug later
    write_logs(prompt_messages, completions[0])

    try:
        image_generation_tasks = [
            process_completion(
                completion,
                index,
                should_generate_images,
                openai_api_key,
                openai_base_url,
                image_cache,
                send_message,
            )
            for index, completion in enumerate(completions)
        ]
        await asyncio.gather(*image_generation_tasks)
    except Exception as e:
        traceback.print_exc()
        print("An error occurred during image generation and processing", e)
        await send_message(
            "status", "An error occurred during image generation and processing.", 0
        )

    await websocket.close()