This commit is contained in:
Cristiano Revil 2024-01-09 14:22:55 +08:00 committed by GitHub
commit 25d6ac7e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 225 additions and 26 deletions

View File

@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
## 🛠 Getting Started ## 🛠 Getting Started
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API. The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API/Azure key with access to the GPT-4 Vision API.
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it): Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
For OpenAI Version:
```bash ```bash
cd backend cd backend
@ -38,6 +40,21 @@ poetry shell
poetry run uvicorn main:app --reload --port 7001 poetry run uvicorn main:app --reload --port 7001
``` ```
For Azure version, you need to add some additional environment keys (vision and dalle3 deployment must be in the same resource on Azure):
```bash
cd backend
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
poetry install
poetry shell
poetry run uvicorn main:app --reload --port 7001
```
Run the frontend: Run the frontend:
```bash ```bash
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
## Configuration ## Configuration
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog - You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
## Docker ## Docker
If you have Docker installed on your system, in the root directory, run: If you have Docker installed on your system, in the root directory, run:
For OpenAI Version:
```bash ```bash
echo "OPENAI_API_KEY=sk-your-key" > .env echo "OPENAI_API_KEY=sk-your-key" > .env
docker-compose up -d --build docker-compose up -d --build
``` ```
For Azure version:
```bash
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
docker-compose up -d --build
```
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild. The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
## 🙋‍♂️ FAQs ## 🙋‍♂️ FAQs

23
backend/api_types.py Normal file
View File

@ -0,0 +1,23 @@
from pydantic import BaseModel
from typing import Union, Literal, Optional
class ApiProviderInfoBase(BaseModel):
    """Base configuration for an AI API provider.

    The ``name`` field acts as the discriminator that identifies which
    concrete provider a given configuration describes.
    """

    # Discriminator: one of the supported provider identifiers.
    name: Literal["openai", "azure"]
class OpenAiProviderInfo(ApiProviderInfoBase):
    """Connection settings for the OpenAI API."""

    # Narrowed discriminator; the ignore silences the override of the
    # broader Literal type declared on the base class.
    name: Literal["openai"] = "openai"  # type: ignore
    # OpenAI API key used to authenticate requests.
    api_key: str
    # Optional base-URL override (e.g. for a proxy); None means the default endpoint.
    base_url: Optional[str] = None
class AzureProviderInfo(ApiProviderInfoBase):
    """Connection settings for an Azure OpenAI deployment."""

    # Narrowed discriminator; the ignore silences the override of the
    # broader Literal type declared on the base class.
    name: Literal["azure"] = "azure"  # type: ignore
    # Azure OpenAI API version string for the deployment.
    api_version: str
    # Azure OpenAI API key used to authenticate requests.
    api_key: str
    # Name of the model deployment inside the Azure resource.
    deployment_name: str
    # Azure resource name; used to build the endpoint host
    # https://{resource_name}.openai.azure.com/.
    resource_name: str
# Discriminated union of all supported provider configurations; the
# `name` field selects the concrete variant at runtime.
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]

View File

@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
load_dotenv() load_dotenv()
import os import os
from llm import stream_openai_response from llm import stream_openai_response, stream_azure_openai_response
from prompts import assemble_prompt from prompts import assemble_prompt
import asyncio import asyncio
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
prompt_messages = assemble_prompt(image_url, stack) prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY") openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_base_url = None openai_base_url = None
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
pprint_prompt(prompt_messages) pprint_prompt(prompt_messages)
async def process_chunk(content: str): async def process_chunk(content: str):
pass pass
if not openai_api_key: if not openai_api_key and not azure_openai_api_key:
raise Exception("OpenAI API key not found") raise Exception("OpenAI API or Azure key not found")
completion = await stream_openai_response( if not openai_api_key:
prompt_messages, completion = await stream_openai_response(
api_key=openai_api_key, prompt_messages,
base_url=openai_base_url, api_key=openai_api_key,
callback=lambda x: process_chunk(x), base_url=openai_base_url,
) callback=lambda x: process_chunk(x),
)
if not azure_openai_api_key:
completion = await stream_azure_openai_response(
prompt_messages,
azure_openai_api_key=azure_openai_api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_deployment_name=azure_openai_deployment_name,
callback=lambda x: process_chunk(x),
)
return completion return completion

View File

@ -1,12 +1,32 @@
import asyncio import asyncio
import re import re
from typing import Dict, List, Union from typing import Dict, List, Union
from openai import AsyncOpenAI from openai import AsyncOpenAI, AsyncAzureOpenAI
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
async def process_tasks(prompts: List[str], api_key: str, base_url: str): async def process_tasks(
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] prompts: List[str],
api_key: str | None,
base_url: str | None,
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
):
if api_key is not None:
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
if azure_openai_api_key is not None:
tasks = [
generate_image_azure(
prompt,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
for prompt in prompts
]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results: List[Union[str, None]] = [] processed_results: List[Union[str, None]] = []
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
return res.data[0].url return res.data[0].url
async def generate_image_azure(
    prompt: str,
    azure_openai_api_key: str,
    azure_openai_api_version: str,
    azure_openai_resource_name: str,
    azure_openai_dalle3_deployment_name: str,
):
    """Generate a single image with an Azure OpenAI DALL-E 3 deployment.

    Args:
        prompt: Text prompt passed to the image model.
        azure_openai_api_key: Azure OpenAI API key.
        azure_openai_api_version: API version string for the deployment.
        azure_openai_resource_name: Azure resource name; forms the endpoint
            host ``https://{resource_name}.openai.azure.com/``.
        azure_openai_dalle3_deployment_name: Name of the DALL-E 3 deployment
            inside the resource.

    Returns:
        The URL of the generated image.
    """
    client = AsyncAzureOpenAI(
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
        azure_deployment=azure_openai_dalle3_deployment_name,
    )
    image_params: Dict[str, Union[str, int]] = {
        "model": "dall-e-3",
        "quality": "standard",
        "style": "natural",
        "n": 1,
        "size": "1024x1024",
        "prompt": prompt,
    }
    # Close the client even when generation raises, so the underlying HTTP
    # connection pool is not leaked (the original only closed on success).
    try:
        res = await client.images.generate(**image_params)
    finally:
        await client.close()
    return res.data[0].url
def extract_dimensions(url: str): def extract_dimensions(url: str):
# Regular expression to match numbers in the format '300x200' # Regular expression to match numbers in the format '300x200'
matches = re.findall(r"(\d+)x(\d+)", url) matches = re.findall(r"(\d+)x(\d+)", url)
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
async def generate_images( async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str] code: str,
api_key: str | None,
base_url: Union[str, None] | None,
image_cache: Dict[str, str],
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
): ):
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")
@ -90,7 +143,15 @@ async def generate_images(
return code return code
# Generate images # Generate images
results = await process_tasks(prompts, api_key, base_url) results = await process_tasks(
prompts,
api_key,
base_url,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
# Create a dict mapping alt text to image URL # Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results)) mapped_image_urls = dict(zip(prompts, results))

View File

@ -1,17 +1,30 @@
from typing import Awaitable, Callable, List from typing import Awaitable, Callable, List
from openai import AsyncOpenAI from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from api_types import ApiProviderInfo
MODEL_GPT_4_VISION = "gpt-4-vision-preview" MODEL_GPT_4_VISION = "gpt-4-vision-preview"
async def stream_openai_response( async def stream_openai_response(
messages: List[ChatCompletionMessageParam], messages: List[ChatCompletionMessageParam],
api_key: str, api_provider_info: ApiProviderInfo,
base_url: str | None,
callback: Callable[[str], Awaitable[None]], callback: Callable[[str], Awaitable[None]],
) -> str: ) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url) if api_provider_info.name == "openai":
client = AsyncOpenAI(
api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
)
elif api_provider_info.name == "azure":
client = AsyncAzureOpenAI(
api_version=api_provider_info.api_version,
api_key=api_provider_info.api_key,
azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
azure_deployment=api_provider_info.deployment_name,
)
else:
raise Exception("Invalid api_provider_info")
model = MODEL_GPT_4_VISION model = MODEL_GPT_4_VISION

View File

@ -2,6 +2,7 @@ import os
import traceback import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from api_types import AzureProviderInfo, OpenAiProviderInfo
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
@ -64,6 +65,12 @@ async def stream_code(websocket: WebSocket):
# Get the OpenAI API key from the request. Fall back to environment variable if not provided. # Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error. # If neither is provided, we throw an error.
openai_api_key = None openai_api_key = None
azure_openai_api_key = None
azure_openai_resource_name = None
azure_openai_deployment_name = None
azure_openai_api_version = None
azure_openai_dalle3_deployment_name = None
azure_openai_dalle3_api_version = None
if "accessCode" in params and params["accessCode"]: if "accessCode" in params and params["accessCode"]:
print("Access code - using platform API key") print("Access code - using platform API key")
res = await validate_access_token(params["accessCode"]) res = await validate_access_token(params["accessCode"])
@ -83,15 +90,29 @@ async def stream_code(websocket: WebSocket):
print("Using OpenAI API key from client-side settings dialog") print("Using OpenAI API key from client-side settings dialog")
else: else:
openai_api_key = os.environ.get("OPENAI_API_KEY") openai_api_key = os.environ.get("OPENAI_API_KEY")
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get(
"AZURE_OPENAI_DEPLOYMENT_NAME"
)
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
azure_openai_dalle3_deployment_name = os.environ.get(
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
)
azure_openai_dalle3_api_version = os.environ.get(
"AZURE_OPENAI_DALLE3_API_VERSION"
)
if openai_api_key: if openai_api_key:
print("Using OpenAI API key from environment variable") print("Using OpenAI API key from environment variable")
if azure_openai_api_key:
print("Using Azure OpenAI API key from environment variable")
if not openai_api_key: if not openai_api_key and not azure_openai_api_key:
print("OpenAI API key not found") print("OpenAI API or Azure key not found")
await websocket.send_json( await websocket.send_json(
{ {
"type": "error", "type": "error",
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.", "value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
} }
) )
return return
@ -190,12 +211,44 @@ async def stream_code(websocket: WebSocket):
completion = await mock_completion(process_chunk) completion = await mock_completion(process_chunk)
else: else:
try: try:
api_provider_info = None
if openai_api_key is not None:
api_provider_info = {
"name": "openai",
"api_key": openai_api_key,
"base_url": openai_base_url,
}
api_provider_info = OpenAiProviderInfo(
api_key=openai_api_key, base_url=openai_base_url
)
if azure_openai_api_key is not None:
if (
not azure_openai_api_version
or not azure_openai_resource_name
or not azure_openai_deployment_name
):
raise Exception(
"Missing Azure OpenAI API version, resource name, or deployment name"
)
api_provider_info = AzureProviderInfo(
api_key=azure_openai_api_key,
api_version=azure_openai_api_version,
deployment_name=azure_openai_deployment_name,
resource_name=azure_openai_resource_name,
)
if api_provider_info is None:
raise Exception("Invalid api_provider_info")
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, prompt_messages,
api_key=openai_api_key, api_provider_info=api_provider_info,
base_url=openai_base_url,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
) )
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e) print("[GENERATE_CODE] Authentication failed", e)
error_message = ( error_message = (
@ -244,6 +297,10 @@ async def stream_code(websocket: WebSocket):
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=openai_base_url,
image_cache=image_cache, image_cache=image_cache,
azure_openai_api_key=azure_openai_api_key,
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
) )
else: else:
updated_html = completion updated_html = completion