This commit is contained in:
Cristiano Revil 2024-01-09 14:22:55 +08:00 committed by GitHub
commit 25d6ac7e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 225 additions and 26 deletions

View File

@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
## 🛠 Getting Started
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API.
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API/Azure key with access to the GPT-4 Vision API.
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it):
Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
For OpenAI Version:
```bash
cd backend
@ -38,6 +40,21 @@ poetry shell
poetry run uvicorn main:app --reload --port 7001
```
For Azure version, you need to add some additional environment keys (vision and dalle3 deployment must be int the same resource on Azure):
```bash
cd backend
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
poetry install
poetry shell
poetry run uvicorn main:app --reload --port 7001
```
Run the frontend:
```bash
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
## Configuration
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
## Docker
If you have Docker installed on your system, in the root directory, run:
For OpenAI Version:
```bash
echo "OPENAI_API_KEY=sk-your-key" > .env
docker-compose up -d --build
```
For Azure version:
```bash
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
docker-compose up -d --build
```
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
## 🙋‍♂️ FAQs

23
backend/api_types.py Normal file
View File

@ -0,0 +1,23 @@
from pydantic import BaseModel
from typing import Union, Literal, Optional
class ApiProviderInfoBase(BaseModel):
name: Literal["openai", "azure"]
class OpenAiProviderInfo(ApiProviderInfoBase):
name: Literal["openai"] = "openai" # type: ignore
api_key: str
base_url: Optional[str] = None
class AzureProviderInfo(ApiProviderInfoBase):
name: Literal["azure"] = "azure" # type: ignore
api_version: str
api_key: str
deployment_name: str
resource_name: str
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]

View File

@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
load_dotenv()
import os
from llm import stream_openai_response
from llm import stream_openai_response, stream_azure_openai_response
from prompts import assemble_prompt
import asyncio
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_base_url = None
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
pprint_prompt(prompt_messages)
async def process_chunk(content: str):
pass
if not openai_api_key:
raise Exception("OpenAI API key not found")
if not openai_api_key and not azure_openai_api_key:
raise Exception("OpenAI API or Azure key not found")
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
if not openai_api_key:
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
if not azure_openai_api_key:
completion = await stream_azure_openai_response(
prompt_messages,
azure_openai_api_key=azure_openai_api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_deployment_name=azure_openai_deployment_name,
callback=lambda x: process_chunk(x),
)
return completion

View File

@ -1,12 +1,32 @@
import asyncio
import re
from typing import Dict, List, Union
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncAzureOpenAI
from bs4 import BeautifulSoup
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
async def process_tasks(
prompts: List[str],
api_key: str | None,
base_url: str | None,
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
):
if api_key is not None:
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
if azure_openai_api_key is not None:
tasks = [
generate_image_azure(
prompt,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
for prompt in prompts
]
results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results: List[Union[str, None]] = []
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
return res.data[0].url
async def generate_image_azure(
prompt: str,
azure_openai_api_key: str,
azure_openai_api_version: str,
azure_openai_resource_name: str,
azure_openai_dalle3_deployment_name: str,
):
client = AsyncAzureOpenAI(
api_version=azure_openai_api_version,
api_key=azure_openai_api_key,
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
azure_deployment=azure_openai_dalle3_deployment_name,
)
image_params: Dict[str, Union[str, int]] = {
"model": "dall-e-3",
"quality": "standard",
"style": "natural",
"n": 1,
"size": "1024x1024",
"prompt": prompt,
}
res = await client.images.generate(**image_params)
await client.close()
return res.data[0].url
def extract_dimensions(url: str):
# Regular expression to match numbers in the format '300x200'
matches = re.findall(r"(\d+)x(\d+)", url)
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
code: str,
api_key: str | None,
base_url: Union[str, None] | None,
image_cache: Dict[str, str],
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
):
# Find all images
soup = BeautifulSoup(code, "html.parser")
@ -90,7 +143,15 @@ async def generate_images(
return code
# Generate images
results = await process_tasks(prompts, api_key, base_url)
results = await process_tasks(
prompts,
api_key,
base_url,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
# Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results))

View File

@ -1,17 +1,30 @@
from typing import Awaitable, Callable, List
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from api_types import ApiProviderInfo
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
async def stream_openai_response(
messages: List[ChatCompletionMessageParam],
api_key: str,
base_url: str | None,
api_provider_info: ApiProviderInfo,
callback: Callable[[str], Awaitable[None]],
) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
if api_provider_info.name == "openai":
client = AsyncOpenAI(
api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
)
elif api_provider_info.name == "azure":
client = AsyncAzureOpenAI(
api_version=api_provider_info.api_version,
api_key=api_provider_info.api_key,
azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
azure_deployment=api_provider_info.deployment_name,
)
else:
raise Exception("Invalid api_provider_info")
model = MODEL_GPT_4_VISION

View File

@ -2,6 +2,7 @@ import os
import traceback
from fastapi import APIRouter, WebSocket
import openai
from api_types import AzureProviderInfo, OpenAiProviderInfo
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam
@ -64,6 +65,12 @@ async def stream_code(websocket: WebSocket):
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error.
openai_api_key = None
azure_openai_api_key = None
azure_openai_resource_name = None
azure_openai_deployment_name = None
azure_openai_api_version = None
azure_openai_dalle3_deployment_name = None
azure_openai_dalle3_api_version = None
if "accessCode" in params and params["accessCode"]:
print("Access code - using platform API key")
res = await validate_access_token(params["accessCode"])
@ -83,15 +90,29 @@ async def stream_code(websocket: WebSocket):
print("Using OpenAI API key from client-side settings dialog")
else:
openai_api_key = os.environ.get("OPENAI_API_KEY")
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get(
"AZURE_OPENAI_DEPLOYMENT_NAME"
)
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
azure_openai_dalle3_deployment_name = os.environ.get(
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
)
azure_openai_dalle3_api_version = os.environ.get(
"AZURE_OPENAI_DALLE3_API_VERSION"
)
if openai_api_key:
print("Using OpenAI API key from environment variable")
if azure_openai_api_key:
print("Using Azure OpenAI API key from environment variable")
if not openai_api_key:
print("OpenAI API key not found")
if not openai_api_key and not azure_openai_api_key:
print("OpenAI API or Azure key not found")
await websocket.send_json(
{
"type": "error",
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
"value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
}
)
return
@ -190,12 +211,44 @@ async def stream_code(websocket: WebSocket):
completion = await mock_completion(process_chunk)
else:
try:
api_provider_info = None
if openai_api_key is not None:
api_provider_info = {
"name": "openai",
"api_key": openai_api_key,
"base_url": openai_base_url,
}
api_provider_info = OpenAiProviderInfo(
api_key=openai_api_key, base_url=openai_base_url
)
if azure_openai_api_key is not None:
if (
not azure_openai_api_version
or not azure_openai_resource_name
or not azure_openai_deployment_name
):
raise Exception(
"Missing Azure OpenAI API version, resource name, or deployment name"
)
api_provider_info = AzureProviderInfo(
api_key=azure_openai_api_key,
api_version=azure_openai_api_version,
deployment_name=azure_openai_deployment_name,
resource_name=azure_openai_resource_name,
)
if api_provider_info is None:
raise Exception("Invalid api_provider_info")
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
api_provider_info=api_provider_info,
callback=lambda x: process_chunk(x),
)
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@ -244,6 +297,10 @@ async def stream_code(websocket: WebSocket):
api_key=openai_api_key,
base_url=openai_base_url,
image_cache=image_cache,
azure_openai_api_key=azure_openai_api_key,
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
)
else:
updated_html = completion