This commit is contained in:
ADou 2023-11-30 16:59:35 +08:00
parent 1f08d71d4d
commit 46202a516e
6 changed files with 40 additions and 9 deletions

View File

@ -5,8 +5,8 @@ from openai import AsyncOpenAI
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
async def process_tasks(prompts, api_key): async def process_tasks(prompts, api_key, base_url):
tasks = [generate_image(prompt, api_key) for prompt in prompts] tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results = [] processed_results = []
@ -20,8 +20,8 @@ async def process_tasks(prompts, api_key):
return processed_results return processed_results
async def generate_image(prompt, api_key): async def generate_image(prompt, api_key, base_url):
client = AsyncOpenAI(api_key=api_key) client = AsyncOpenAI(api_key=api_key, base_url=base_url)
image_params = { image_params = {
"model": "dall-e-3", "model": "dall-e-3",
"quality": "standard", "quality": "standard",
@ -60,7 +60,7 @@ def create_alt_url_mapping(code):
return mapping return mapping
async def generate_images(code, api_key, image_cache): async def generate_images(code, api_key, base_url, image_cache):
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img") images = soup.find_all("img")
@ -87,7 +87,7 @@ async def generate_images(code, api_key, image_cache):
return code return code
# Generate images # Generate images
results = await process_tasks(prompts, api_key) results = await process_tasks(prompts, api_key, base_url)
# Create a dict mapping alt text to image URL # Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results)) mapped_image_urls = dict(zip(prompts, results))

View File

@ -6,9 +6,9 @@ MODEL_GPT_4_VISION = "gpt-4-vision-preview"
async def stream_openai_response( async def stream_openai_response(
messages, api_key: str, callback: Callable[[str], Awaitable[None]] messages, api_key: str, base_url : str, callback: Callable[[str], Awaitable[None]]
): ):
client = AsyncOpenAI(api_key=api_key) client = AsyncOpenAI(api_key=api_key,base_url=base_url)
model = MODEL_GPT_4_VISION model = MODEL_GPT_4_VISION

View File

@ -80,6 +80,7 @@ async def stream_code(websocket: WebSocket):
# Get the OpenAI API key from the request. Fall back to environment variable if not provided. # Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error. # If neither is provided, we throw an error.
openai_api_key = None openai_api_key = None
openai_api_base_url = None
if "accessCode" in params and params["accessCode"]: if "accessCode" in params and params["accessCode"]:
print("Access code - using platform API key") print("Access code - using platform API key")
if await validate_access_token(params["accessCode"]): if await validate_access_token(params["accessCode"]):
@ -101,6 +102,14 @@ async def stream_code(websocket: WebSocket):
if openai_api_key: if openai_api_key:
print("Using OpenAI API key from environment variable") print("Using OpenAI API key from environment variable")
if params["openAiApiBaseUrl"]:
openai_api_base_url = params["openAiApiBaseUrl"].strip()
print("Using OpenAI API Base Url from client-side settings dialog")
else:
openai_api_base_url = os.environ.get("OPENAI_API_BASE_URL")
if openai_api_base_url:
print("Using OpenAI API Base Url from environment variable")
if not openai_api_key: if not openai_api_key:
print("OpenAI API key not found") print("OpenAI API key not found")
await websocket.send_json( await websocket.send_json(
@ -149,6 +158,7 @@ async def stream_code(websocket: WebSocket):
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, prompt_messages,
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_api_base_url,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
) )
@ -161,7 +171,7 @@ async def stream_code(websocket: WebSocket):
{"type": "status", "value": "Generating images..."} {"type": "status", "value": "Generating images..."}
) )
updated_html = await generate_images( updated_html = await generate_images(
completion, api_key=openai_api_key, image_cache=image_cache completion, api_key=openai_api_key, base_url=openai_api_base_url, image_cache=image_cache
) )
else: else:
updated_html = completion updated_html = completion

View File

@ -47,6 +47,7 @@ function App() {
const [settings, setSettings] = usePersistedState<Settings>( const [settings, setSettings] = usePersistedState<Settings>(
{ {
openAiApiKey: null, openAiApiKey: null,
openAiApiBaseUrl: null,
screenshotOneApiKey: null, screenshotOneApiKey: null,
isImageGenerationEnabled: true, isImageGenerationEnabled: true,
editorTheme: EditorTheme.COBALT, editorTheme: EditorTheme.COBALT,

View File

@ -103,6 +103,25 @@ function SettingsDialog({ settings, setSettings }: Props) {
} }
/> />
<Label htmlFor="openai-api-base-url">
<div>OpenAI API Base Url</div>
<div className="font-light mt-2 leading-relaxed">
The default is "https://api.openai.com/v1"
</div>
</Label>
<Input
id="openai-api-base-url"
placeholder="openAI API BASE URL"
value={settings.openAiApiBaseUrl || ""}
onChange={(e) =>
setSettings((s) => ({
...s,
openAiApiBaseUrl: e.target.value,
}))
}
/>
<Label htmlFor="screenshot-one-api-key"> <Label htmlFor="screenshot-one-api-key">
<div> <div>
ScreenshotOne API key (optional - only needed if you want to use ScreenshotOne API key (optional - only needed if you want to use

View File

@ -21,6 +21,7 @@ export interface OutputSettings {
export interface Settings { export interface Settings {
openAiApiKey: string | null; openAiApiKey: string | null;
openAiApiBaseUrl: string | null;
screenshotOneApiKey: string | null; screenshotOneApiKey: string | null;
isImageGenerationEnabled: boolean; isImageGenerationEnabled: boolean;
editorTheme: EditorTheme; editorTheme: EditorTheme;