diff --git a/backend/image_generation.py b/backend/image_generation.py index 080334f..ad21772 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -5,8 +5,8 @@ from openai import AsyncOpenAI from bs4 import BeautifulSoup -async def process_tasks(prompts, api_key): - tasks = [generate_image(prompt, api_key) for prompt in prompts] +async def process_tasks(prompts, api_key, base_url): + tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) processed_results = [] @@ -20,8 +20,8 @@ async def process_tasks(prompts, api_key): return processed_results -async def generate_image(prompt, api_key): - client = AsyncOpenAI(api_key=api_key) +async def generate_image(prompt, api_key, base_url): + client = AsyncOpenAI(api_key=api_key, base_url=base_url) image_params = { "model": "dall-e-3", "quality": "standard", @@ -60,7 +60,7 @@ def create_alt_url_mapping(code): return mapping -async def generate_images(code, api_key, image_cache): +async def generate_images(code, api_key, base_url, image_cache): # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") @@ -87,7 +87,7 @@ async def generate_images(code, api_key, image_cache): return code # Generate images - results = await process_tasks(prompts, api_key) + results = await process_tasks(prompts, api_key, base_url) # Create a dict mapping alt text to image URL mapped_image_urls = dict(zip(prompts, results)) diff --git a/backend/llm.py b/backend/llm.py index b52c3c9..fdb1ba0 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -6,9 +6,12 @@ MODEL_GPT_4_VISION = "gpt-4-vision-preview" async def stream_openai_response( - messages, api_key: str, callback: Callable[[str], Awaitable[None]] + messages, + api_key: str, + base_url: str | None, + callback: Callable[[str], Awaitable[None]], ): - client = AsyncOpenAI(api_key=api_key) + client = AsyncOpenAI(api_key=api_key, base_url=base_url) model = MODEL_GPT_4_VISION diff --git a/backend/main.py b/backend/main.py index 4c5823b..65fc15e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -111,6 +111,22 @@ async def stream_code(websocket: WebSocket): ) return + # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. + openai_base_url = None + # Disable user-specified OpenAI Base URL in prod + if not os.environ.get("IS_PROD"): + if "openAiBaseURL" in params and params["openAiBaseURL"]: + openai_base_url = params["openAiBaseURL"] + print("Using OpenAI Base URL from client-side settings dialog") + else: + openai_base_url = os.environ.get("OPENAI_BASE_URL") + if openai_base_url: + print("Using OpenAI Base URL from environment variable") + + if not openai_base_url: + print("Using official OpenAI URL") + + # Get the image generation flag from the request. Fall back to True if not provided. should_generate_images = ( params["isImageGenerationEnabled"] if "isImageGenerationEnabled" in params @@ -149,6 +165,7 @@ async def stream_code(websocket: WebSocket): completion = await stream_openai_response( prompt_messages, api_key=openai_api_key, + base_url=openai_base_url, callback=lambda x: process_chunk(x), ) @@ -161,7 +178,10 @@ async def stream_code(websocket: WebSocket): {"type": "status", "value": "Generating images..."} ) updated_html = await generate_images( - completion, api_key=openai_api_key, image_cache=image_cache + completion, + api_key=openai_api_key, + base_url=openai_base_url, + image_cache=image_cache, ) else: updated_html = completion diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 8911165..ed416da 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -47,6 +47,7 @@ function App() { const [settings, setSettings] = usePersistedState( { openAiApiKey: null, + openAiBaseURL: null, screenshotOneApiKey: null, isImageGenerationEnabled: true, editorTheme: EditorTheme.COBALT, diff --git a/frontend/src/components/SettingsDialog.tsx b/frontend/src/components/SettingsDialog.tsx index 361d194..d0d0a54 100644 --- a/frontend/src/components/SettingsDialog.tsx +++ b/frontend/src/components/SettingsDialog.tsx @@ -109,6 +109,29 @@ function SettingsDialog({ settings, setSettings }: Props) { } /> + {!IS_RUNNING_ON_CLOUD && ( + <> + + + + setSettings((s) => ({ + ...s, + openAiBaseURL: e.target.value, + })) + } + /> + + )} + Screenshot by URL Config diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 42b06f9..137f770 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -21,6 +21,7 @@ export interface OutputSettings { export interface Settings { openAiApiKey: string | null; + openAiBaseURL: string | null; screenshotOneApiKey: string | null; isImageGenerationEnabled: boolean; editorTheme: EditorTheme;