diff --git a/backend/main.py b/backend/main.py
index 93d2e9a..1476d1b 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -14,7 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from llm import stream_openai_response
 from mock import mock_completion
 from image_generation import create_alt_url_mapping, generate_images
-from prompts import assemble_prompt
+from prompts import assemble_prompt, assemble_instruction_generation_prompt
 from routes import screenshot
 
 app = FastAPI(openapi_url=None, docs_url=None, redoc_url=None)
@@ -149,3 +149,86 @@ async def stream_code(websocket: WebSocket):
         )
     finally:
         await websocket.close()
+
+
+@app.websocket("/generate-instruction")
+async def stream_code(websocket: WebSocket):
+    """Stream a generated instruction to the client over a websocket.
+
+    Reads one JSON message from the client and uses its ``image`` and
+    ``resultImage`` entries to build the prompt. ``openAiApiKey`` is optional
+    and falls back to the ``OPENAI_API_KEY`` environment variable. Chunks are
+    sent as ``chunk`` messages, followed by one ``setInstruction`` message with
+    the full completion. The socket is closed on every exit path.
+    """
+    # NOTE(review): this reuses the name ``stream_code`` of the handler defined
+    # just above and rebinds it at module level. Presumably both routes are
+    # still registered by their decorators, but a distinct name would be clearer.
+    await websocket.accept()
+
+    try:
+        params = await websocket.receive_json()
+
+        # Prefer the key from the client-side settings dialog and fall back to
+        # the environment variable. A missing key is treated like an empty one.
+        openai_api_key = params.get("openAiApiKey")
+        if openai_api_key:
+            print("Using OpenAI API key from client-side settings dialog")
+        else:
+            openai_api_key = os.environ.get("OPENAI_API_KEY")
+            if openai_api_key:
+                print("Using OpenAI API key from environment variable")
+
+        if not openai_api_key:
+            print("OpenAI API key not found")
+            await websocket.send_json(
+                {
+                    "type": "error",
+                    "value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
+                }
+            )
+            # The ``finally`` clause below closes the socket.
+            return
+
+        print("generating instruction...")
+        await websocket.send_json(
+            {"type": "status", "value": "Generating instruction..."}
+        )
+
+        async def process_chunk(content):
+            await websocket.send_json({"type": "chunk", "value": content})
+
+        prompt_messages = assemble_instruction_generation_prompt(
+            params["image"], params["resultImage"]
+        )
+
+        if SHOULD_MOCK_AI_RESPONSE:
+            completion = await mock_completion(process_chunk)
+        else:
+            completion = await stream_openai_response(
+                prompt_messages,
+                api_key=openai_api_key,
+                callback=process_chunk,
+            )
+
+        # Write the messages dict into a log so that we can debug later
+        write_logs(prompt_messages, completion)
+
+        await websocket.send_json({"type": "setInstruction", "value": completion})
+        await websocket.send_json(
+            {"type": "status", "value": "Instruction generation complete."}
+        )
+    except Exception:
+        # Top-level boundary for this connection: log the failure and try to
+        # tell the client before the socket is closed.
+        traceback.print_exc()
+        try:
+            await websocket.send_json(
+                {"type": "status", "value": "Instruction generation failed."}
+            )
+        except Exception:
+            # The client may already be gone; there is nobody left to notify.
+            traceback.print_exc()
+    finally:
+        await websocket.close()
+