From ee9b40d99040ba6a8f59e1ebc80e0c4ecb812c3d Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Thu, 16 Nov 2023 18:12:07 -0500 Subject: [PATCH] support setting openai api key on the client side --- backend/image_generation.py | 12 ++--- backend/llm.py | 6 ++- backend/main.py | 26 ++++++++- frontend/src/App.tsx | 1 + frontend/src/components/SettingsDialog.tsx | 62 +++++++++++++++------- frontend/src/components/ui/input.tsx | 25 +++++++++ frontend/src/generateCode.ts | 3 ++ frontend/src/types.ts | 1 + 8 files changed, 108 insertions(+), 28 deletions(-) create mode 100644 frontend/src/components/ui/input.tsx diff --git a/backend/image_generation.py b/backend/image_generation.py index e1230a8..080334f 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -5,8 +5,8 @@ from openai import AsyncOpenAI from bs4 import BeautifulSoup -async def process_tasks(prompts): - tasks = [generate_image(prompt) for prompt in prompts] +async def process_tasks(prompts, api_key): + tasks = [generate_image(prompt, api_key) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) processed_results = [] @@ -20,8 +20,8 @@ async def process_tasks(prompts): return processed_results -async def generate_image(prompt): - client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +async def generate_image(prompt, api_key): + client = AsyncOpenAI(api_key=api_key) image_params = { "model": "dall-e-3", "quality": "standard", @@ -60,7 +60,7 @@ def create_alt_url_mapping(code): return mapping -async def generate_images(code, image_cache): +async def generate_images(code, api_key, image_cache): # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") @@ -87,7 +87,7 @@ async def generate_images(code, image_cache): return code # Generate images - results = await process_tasks(prompts) + results = await process_tasks(prompts, api_key) # Create a dict mapping alt text to image URL mapped_image_urls = dict(zip(prompts, results)) diff --git a/backend/llm.py b/backend/llm.py index 686b008..b52c3c9 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -4,10 +4,12 @@ from openai import AsyncOpenAI MODEL_GPT_4_VISION = "gpt-4-vision-preview" -client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +async def stream_openai_response( + messages, api_key: str, callback: Callable[[str], Awaitable[None]] +): + client = AsyncOpenAI(api_key=api_key) -async def stream_openai_response(messages, callback: Callable[[str], Awaitable[None]]): model = MODEL_GPT_4_VISION # Base parameters diff --git a/backend/main.py b/backend/main.py index 42b3f2f..a3a74ad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -41,6 +41,25 @@ async def stream_code_test(websocket: WebSocket): params = await websocket.receive_json() + # Get the OpenAI API key from the request. Fall back to environment variable if not provided. + # If neither is provided, we throw an error. + if params["openAiApiKey"]: + openai_api_key = params["openAiApiKey"] + print("Using OpenAI API key from client-side settings dialog") + else: + openai_api_key = os.environ.get("OPENAI_API_KEY") + print("Using OpenAI API key from environment variable") + + if not openai_api_key: + print("OpenAI API key not found") + await websocket.send_json( + { + "type": "error", + "value": "OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.", + } + ) + return + should_generate_images = ( params["isImageGenerationEnabled"] if "isImageGenerationEnabled" in params @@ -73,7 +92,8 @@ async def stream_code_test(websocket: WebSocket): else: completion = await stream_openai_response( prompt_messages, - lambda x: process_chunk(x), + api_key=openai_api_key, + callback=lambda x: process_chunk(x), ) # Write the messages dict into a log so that we can debug later @@ -84,7 +104,9 @@ async def stream_code_test(websocket: WebSocket): await websocket.send_json( {"type": "status", "value": "Generating images..."} ) - updated_html = await generate_images(completion, image_cache=image_cache) + updated_html = await generate_images( + completion, api_key=openai_api_key, image_cache=image_cache + ) else: updated_html = completion await websocket.send_json({"type": "setCode", "value": updated_html}) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index d5560be..634c800 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -29,6 +29,7 @@ function App() { const [updateInstruction, setUpdateInstruction] = useState(""); const [history, setHistory] = useState([]); const [settings, setSettings] = useState({ + openAiApiKey: null, isImageGenerationEnabled: true, }); diff --git a/frontend/src/components/SettingsDialog.tsx b/frontend/src/components/SettingsDialog.tsx index 294ec59..512894f 100644 --- a/frontend/src/components/SettingsDialog.tsx +++ b/frontend/src/components/SettingsDialog.tsx @@ -1,6 +1,8 @@ import { Dialog, + DialogClose, DialogContent, + DialogFooter, DialogHeader, DialogTitle, DialogTrigger, @@ -9,6 +11,7 @@ import { FaCog } from "react-icons/fa"; import { Settings } from "../types"; import { Switch } from "./ui/switch"; import { Label } from "./ui/label"; +import { Input } from "./ui/input"; interface Props { settings: Settings; @@ -24,25 +27,48 @@ function SettingsDialog({ settings, setSettings }: Props) { Settings -
- - - setSettings((s) => ({ - ...s, - isImageGenerationEnabled: !s.isImageGenerationEnabled, - })) - } - /> -
+
+ + + setSettings((s) => ({ + ...s, + isImageGenerationEnabled: !s.isImageGenerationEnabled, + })) + } + /> +
+
+ + + + setSettings((s) => ({ + ...s, + openAiApiKey: e.target.value, + })) + } + /> +
+ + Save +
); diff --git a/frontend/src/components/ui/input.tsx b/frontend/src/components/ui/input.tsx new file mode 100644 index 0000000..a92b8e0 --- /dev/null +++ b/frontend/src/components/ui/input.tsx @@ -0,0 +1,25 @@ +import * as React from "react" + +import { cn } from "@/lib/utils" + +export interface InputProps + extends React.InputHTMLAttributes {} + +const Input = React.forwardRef( + ({ className, type, ...props }, ref) => { + return ( + + ) + } +) +Input.displayName = "Input" + +export { Input } diff --git a/frontend/src/generateCode.ts b/frontend/src/generateCode.ts index 06ecc5b..c3e9218 100644 --- a/frontend/src/generateCode.ts +++ b/frontend/src/generateCode.ts @@ -36,6 +36,9 @@ export function generateCode( onStatusUpdate(response.value); } else if (response.type === "setCode") { onSetCode(response.value); + } else if (response.type === "error") { + console.error("Error generating code", response.value); + toast.error(response.value); } }); diff --git a/frontend/src/types.ts b/frontend/src/types.ts index d9fe60b..b2c6305 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -1,3 +1,4 @@ export interface Settings { + openAiApiKey: string | null; isImageGenerationEnabled: boolean; }