From 46202a516e8cf708da502e036b31d1b635f5eaa3 Mon Sep 17 00:00:00 2001 From: ADou <13738201957@163.com> Date: Thu, 30 Nov 2023 16:59:35 +0800 Subject: [PATCH] base_url --- backend/image_generation.py | 12 ++++++------ backend/llm.py | 4 ++-- backend/main.py | 12 +++++++++++- frontend/src/App.tsx | 1 + frontend/src/components/SettingsDialog.tsx | 19 +++++++++++++++++++ frontend/src/types.ts | 1 + 6 files changed, 40 insertions(+), 9 deletions(-) diff --git a/backend/image_generation.py b/backend/image_generation.py index 080334f..ad21772 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -5,8 +5,8 @@ from openai import AsyncOpenAI from bs4 import BeautifulSoup -async def process_tasks(prompts, api_key): - tasks = [generate_image(prompt, api_key) for prompt in prompts] +async def process_tasks(prompts, api_key, base_url): + tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) processed_results = [] @@ -20,8 +20,8 @@ async def process_tasks(prompts, api_key): return processed_results -async def generate_image(prompt, api_key): - client = AsyncOpenAI(api_key=api_key) +async def generate_image(prompt, api_key, base_url): + client = AsyncOpenAI(api_key=api_key, base_url=base_url) image_params = { "model": "dall-e-3", "quality": "standard", @@ -60,7 +60,7 @@ def create_alt_url_mapping(code): return mapping -async def generate_images(code, api_key, image_cache): +async def generate_images(code, api_key, base_url, image_cache): # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") @@ -87,7 +87,7 @@ async def generate_images(code, api_key, image_cache): return code # Generate images - results = await process_tasks(prompts, api_key) + results = await process_tasks(prompts, api_key, base_url) # Create a dict mapping alt text to image URL mapped_image_urls = dict(zip(prompts, results)) diff --git a/backend/llm.py b/backend/llm.py index b52c3c9..6972f02 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -6,9 +6,9 @@ MODEL_GPT_4_VISION = "gpt-4-vision-preview" async def stream_openai_response( - messages, api_key: str, callback: Callable[[str], Awaitable[None]] + messages, api_key: str, base_url : str, callback: Callable[[str], Awaitable[None]] ): - client = AsyncOpenAI(api_key=api_key) + client = AsyncOpenAI(api_key=api_key,base_url=base_url) model = MODEL_GPT_4_VISION diff --git a/backend/main.py b/backend/main.py index 4c5823b..b500a52 100644 --- a/backend/main.py +++ b/backend/main.py @@ -80,6 +80,7 @@ async def stream_code(websocket: WebSocket): # Get the OpenAI API key from the request. Fall back to environment variable if not provided. # If neither is provided, we throw an error. openai_api_key = None + openai_api_base_url = None if "accessCode" in params and params["accessCode"]: print("Access code - using platform API key") if await validate_access_token(params["accessCode"]): @@ -101,6 +102,14 @@ async def stream_code(websocket: WebSocket): if openai_api_key: print("Using OpenAI API key from environment variable") + if params["openAiApiBaseUrl"]: + openai_api_base_url = params["openAiApiBaseUrl"].strip() + print("Using OpenAI API Base Url from client-side settings dialog") + else: + openai_api_base_url = os.environ.get("OPENAI_API_BASE_URL") + if openai_api_base_url: + print("Using OpenAI API Base Url from environment variable") + if not openai_api_key: print("OpenAI API key not found") await websocket.send_json( @@ -149,6 +158,7 @@ async def stream_code(websocket: WebSocket): completion = await stream_openai_response( prompt_messages, api_key=openai_api_key, + base_url=openai_api_base_url, callback=lambda x: process_chunk(x), ) @@ -161,7 +171,7 @@ async def stream_code(websocket: WebSocket): {"type": "status", "value": "Generating images..."} ) updated_html = await generate_images( - completion, api_key=openai_api_key, image_cache=image_cache + completion, api_key=openai_api_key, base_url=openai_api_base_url, image_cache=image_cache ) else: updated_html = completion diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index bde3cae..d109754 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -47,6 +47,7 @@ function App() { const [settings, setSettings] = usePersistedState( { openAiApiKey: null, + openAiApiBaseUrl: null, screenshotOneApiKey: null, isImageGenerationEnabled: true, editorTheme: EditorTheme.COBALT, diff --git a/frontend/src/components/SettingsDialog.tsx b/frontend/src/components/SettingsDialog.tsx index 8a8be03..45c5545 100644 --- a/frontend/src/components/SettingsDialog.tsx +++ b/frontend/src/components/SettingsDialog.tsx @@ -103,6 +103,25 @@ function SettingsDialog({ settings, setSettings }: Props) { } /> + + + + setSettings((s) => ({ + ...s, + openAiApiBaseUrl: e.target.value, + })) + } + /> +