Add azure support

This commit is contained in:
Cristiano Revil 2023-12-14 11:57:55 +01:00
parent d23cec9bc0
commit 88e383cbd6
5 changed files with 206 additions and 28 deletions

View File

@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
## 🛠 Getting Started ## 🛠 Getting Started
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API. The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API/Azure key with access to the GPT-4 Vision API.
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it): Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
For OpenAI Version:
```bash ```bash
cd backend cd backend
@ -38,6 +40,21 @@ poetry shell
poetry run uvicorn main:app --reload --port 7001 poetry run uvicorn main:app --reload --port 7001
``` ```
For the Azure version, you need to add some additional environment keys (the vision and dalle3 deployments must be in the same Azure resource):
```bash
cd backend
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
poetry install
poetry shell
poetry run uvicorn main:app --reload --port 7001
```
Run the frontend: Run the frontend:
```bash ```bash
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
## Configuration ## Configuration
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog - You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
## Docker ## Docker
If you have Docker installed on your system, in the root directory, run: If you have Docker installed on your system, in the root directory, run:
For OpenAI Version:
```bash ```bash
echo "OPENAI_API_KEY=sk-your-key" > .env echo "OPENAI_API_KEY=sk-your-key" > .env
docker-compose up -d --build docker-compose up -d --build
``` ```
For Azure version:
```bash
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
docker-compose up -d --build
```
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild. The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
## 🙋‍♂️ FAQs ## 🙋‍♂️ FAQs

View File

@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
load_dotenv() load_dotenv()
import os import os
from llm import stream_openai_response from llm import stream_openai_response, stream_azure_openai_response
from prompts import assemble_prompt from prompts import assemble_prompt
import asyncio import asyncio
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
prompt_messages = assemble_prompt(image_url, stack) prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY") openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_base_url = None openai_base_url = None
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
pprint_prompt(prompt_messages) pprint_prompt(prompt_messages)
async def process_chunk(content: str): async def process_chunk(content: str):
pass pass
if not openai_api_key: if not openai_api_key and not azure_openai_api_key:
raise Exception("OpenAI API key not found") raise Exception("OpenAI API or Azure key not found")
completion = await stream_openai_response( if not openai_api_key:
prompt_messages, completion = await stream_openai_response(
api_key=openai_api_key, prompt_messages,
base_url=openai_base_url, api_key=openai_api_key,
callback=lambda x: process_chunk(x), base_url=openai_base_url,
) callback=lambda x: process_chunk(x),
)
if not azure_openai_api_key:
completion = await stream_azure_openai_response(
prompt_messages,
azure_openai_api_key=azure_openai_api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_deployment_name=azure_openai_deployment_name,
callback=lambda x: process_chunk(x),
)
return completion return completion

View File

@ -1,12 +1,32 @@
import asyncio import asyncio
import re import re
from typing import Dict, List, Union from typing import Dict, List, Union
from openai import AsyncOpenAI from openai import AsyncOpenAI, AsyncAzureOpenAI
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
async def process_tasks(prompts: List[str], api_key: str, base_url: str): async def process_tasks(
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] prompts: List[str],
api_key: str | None,
base_url: str | None,
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
):
if api_key is not None:
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
if azure_openai_api_key is not None:
tasks = [
generate_image_azure(
prompt,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
for prompt in prompts
]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results: List[Union[str, None]] = [] processed_results: List[Union[str, None]] = []
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
return res.data[0].url return res.data[0].url
async def generate_image_azure(
    prompt: str,
    azure_openai_api_key: str,
    azure_openai_api_version: str,
    azure_openai_resource_name: str,
    azure_openai_dalle3_deployment_name: str,
):
    """Generate one image via an Azure OpenAI DALL-E 3 deployment.

    The Azure endpoint is derived from the resource name; the dalle3
    deployment is assumed to live in that same resource (see README).

    Returns:
        The URL of the generated image.

    Raises:
        Whatever the Azure OpenAI client raises on a failed request
        (e.g. authentication or content-policy errors).
    """
    client = AsyncAzureOpenAI(
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
        azure_deployment=azure_openai_dalle3_deployment_name,
    )
    image_params: Dict[str, Union[str, int]] = {
        "model": "dall-e-3",
        "quality": "standard",
        "style": "natural",
        "n": 1,
        "size": "1024x1024",
        "prompt": prompt,
    }
    try:
        res = await client.images.generate(**image_params)
    finally:
        # Close the client even when the request raises, so a failed
        # generation does not leak the underlying HTTP connection pool.
        await client.close()
    return res.data[0].url
def extract_dimensions(url: str): def extract_dimensions(url: str):
# Regular expression to match numbers in the format '300x200' # Regular expression to match numbers in the format '300x200'
matches = re.findall(r"(\d+)x(\d+)", url) matches = re.findall(r"(\d+)x(\d+)", url)
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
async def generate_images( async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str] code: str,
api_key: str | None,
base_url: Union[str, None] | None,
image_cache: Dict[str, str],
azure_openai_api_key: str | None,
azure_openai_dalle3_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_dalle3_deployment_name: str | None,
): ):
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")
@ -90,7 +143,15 @@ async def generate_images(
return code return code
# Generate images # Generate images
results = await process_tasks(prompts, api_key, base_url) results = await process_tasks(
prompts,
api_key,
base_url,
azure_openai_api_key,
azure_openai_dalle3_api_version,
azure_openai_resource_name,
azure_openai_dalle3_deployment_name,
)
# Create a dict mapping alt text to image URL # Create a dict mapping alt text to image URL
mapped_image_urls = dict(zip(prompts, results)) mapped_image_urls = dict(zip(prompts, results))

View File

@ -1,5 +1,5 @@
from typing import Awaitable, Callable, List from typing import Awaitable, Callable, List
from openai import AsyncOpenAI from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
MODEL_GPT_4_VISION = "gpt-4-vision-preview" MODEL_GPT_4_VISION = "gpt-4-vision-preview"
@ -34,3 +34,41 @@ async def stream_openai_response(
await client.close() await client.close()
return full_response return full_response
async def stream_azure_openai_response(
    messages: List[ChatCompletionMessageParam],
    azure_openai_api_key: str | None,
    azure_openai_api_version: str | None,
    azure_openai_resource_name: str | None,
    azure_openai_deployment_name: str | None,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion from an Azure OpenAI deployment.

    Mirrors stream_openai_response but targets an Azure resource: the
    endpoint URL is derived from the resource name and the model is
    addressed via the deployment name.

    Args:
        messages: Chat messages forming the prompt.
        azure_openai_api_key: Azure OpenAI API key.
        azure_openai_api_version: API version string for the deployment.
        azure_openai_resource_name: Azure resource name (host prefix).
        azure_openai_deployment_name: Name of the GPT-4 Vision deployment.
        callback: Awaited with each streamed content delta.

    Returns:
        The full concatenated response text.
    """
    client = AsyncAzureOpenAI(
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
        azure_deployment=azure_openai_deployment_name,
    )

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    full_response = ""
    try:
        stream = await client.chat.completions.create(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            assert isinstance(chunk, ChatCompletionChunk)
            # NOTE(review): Azure can emit chunks with an empty `choices`
            # list (e.g. the content-filter prologue chunk), which would
            # make `choices[0]` raise IndexError — skip those chunks.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content or ""
            full_response += content
            await callback(content)
    finally:
        # Always release the HTTP connection pool, even if the stream
        # errors out mid-way.
        await client.close()
    return full_response

View File

@ -3,7 +3,7 @@ import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response from llm import stream_openai_response, stream_azure_openai_response
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List from typing import Dict, List
@ -64,6 +64,12 @@ async def stream_code(websocket: WebSocket):
# Get the OpenAI API key from the request. Fall back to environment variable if not provided. # Get the OpenAI API key from the request. Fall back to environment variable if not provided.
# If neither is provided, we throw an error. # If neither is provided, we throw an error.
openai_api_key = None openai_api_key = None
azure_openai_api_key = None
azure_openai_resource_name = None
azure_openai_deployment_name = None
azure_openai_api_version = None
azure_openai_dalle3_deployment_name = None
azure_openai_dalle3_api_version = None
if "accessCode" in params and params["accessCode"]: if "accessCode" in params and params["accessCode"]:
print("Access code - using platform API key") print("Access code - using platform API key")
res = await validate_access_token(params["accessCode"]) res = await validate_access_token(params["accessCode"])
@ -83,15 +89,29 @@ async def stream_code(websocket: WebSocket):
print("Using OpenAI API key from client-side settings dialog") print("Using OpenAI API key from client-side settings dialog")
else: else:
openai_api_key = os.environ.get("OPENAI_API_KEY") openai_api_key = os.environ.get("OPENAI_API_KEY")
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
azure_openai_deployment_name = os.environ.get(
"AZURE_OPENAI_DEPLOYMENT_NAME"
)
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
azure_openai_dalle3_deployment_name = os.environ.get(
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
)
azure_openai_dalle3_api_version = os.environ.get(
"AZURE_OPENAI_DALLE3_API_VERSION"
)
if openai_api_key: if openai_api_key:
print("Using OpenAI API key from environment variable") print("Using OpenAI API key from environment variable")
if azure_openai_api_key:
print("Using Azure OpenAI API key from environment variable")
if not openai_api_key: if not openai_api_key and not azure_openai_api_key:
print("OpenAI API key not found") print("OpenAI API or Azure key not found")
await websocket.send_json( await websocket.send_json(
{ {
"type": "error", "type": "error",
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.", "value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
} }
) )
return return
@ -190,12 +210,22 @@ async def stream_code(websocket: WebSocket):
completion = await mock_completion(process_chunk) completion = await mock_completion(process_chunk)
else: else:
try: try:
completion = await stream_openai_response( if openai_api_key is not None:
prompt_messages, completion = await stream_openai_response(
api_key=openai_api_key, prompt_messages,
base_url=openai_base_url, api_key=openai_api_key,
callback=lambda x: process_chunk(x), base_url=openai_base_url,
) callback=lambda x: process_chunk(x),
)
if azure_openai_api_key is not None:
completion = await stream_azure_openai_response(
prompt_messages,
azure_openai_api_key=azure_openai_api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_deployment_name=azure_openai_deployment_name,
callback=lambda x: process_chunk(x),
)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e) print("[GENERATE_CODE] Authentication failed", e)
error_message = ( error_message = (
@ -244,6 +274,10 @@ async def stream_code(websocket: WebSocket):
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=openai_base_url,
image_cache=image_cache, image_cache=image_cache,
azure_openai_api_key=azure_openai_api_key,
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
azure_openai_resource_name=azure_openai_resource_name,
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
) )
else: else:
updated_html = completion updated_html = completion