Add azure support
This commit is contained in:
parent
d23cec9bc0
commit
88e383cbd6
37
README.md
37
README.md
@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
|
||||
|
||||
## 🛠 Getting Started
|
||||
|
||||
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API.
|
||||
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API/Azure key with access to the GPT-4 Vision API.
|
||||
|
||||
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it):
|
||||
Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
|
||||
|
||||
For OpenAI Version:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
@ -38,6 +40,21 @@ poetry shell
|
||||
poetry run uvicorn main:app --reload --port 7001
|
||||
```
|
||||
|
||||
For the Azure version, you need to add some additional environment keys (the vision and dalle3 deployments must be in the same resource on Azure):
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
|
||||
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
|
||||
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
|
||||
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
|
||||
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
|
||||
poetry install
|
||||
poetry shell
|
||||
poetry run uvicorn main:app --reload --port 7001
|
||||
```
|
||||
|
||||
Run the frontend:
|
||||
|
||||
```bash
|
||||
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
|
||||
|
||||
## Configuration
|
||||
|
||||
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
||||
- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
||||
|
||||
## Docker
|
||||
|
||||
If you have Docker installed on your system, in the root directory, run:
|
||||
|
||||
For OpenAI Version:
|
||||
|
||||
```bash
|
||||
echo "OPENAI_API_KEY=sk-your-key" > .env
|
||||
docker-compose up -d --build
|
||||
```
|
||||
|
||||
For Azure version:
|
||||
|
||||
```bash
|
||||
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
|
||||
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
|
||||
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
|
||||
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
|
||||
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
|
||||
docker-compose up -d --build
|
||||
```
|
||||
|
||||
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
|
||||
|
||||
## 🙋‍♂️ FAQs
|
||||
|
||||
@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
from llm import stream_openai_response
|
||||
from llm import stream_openai_response, stream_azure_openai_response
|
||||
from prompts import assemble_prompt
|
||||
import asyncio
|
||||
|
||||
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
|
||||
prompt_messages = assemble_prompt(image_url, stack)
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
openai_base_url = None
|
||||
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
|
||||
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
|
||||
pprint_prompt(prompt_messages)
|
||||
|
||||
async def process_chunk(content: str):
|
||||
pass
|
||||
|
||||
if not openai_api_key:
|
||||
raise Exception("OpenAI API key not found")
|
||||
if not openai_api_key and not azure_openai_api_key:
|
||||
raise Exception("OpenAI API or Azure key not found")
|
||||
|
||||
if not openai_api_key:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
if not azure_openai_api_key:
|
||||
completion = await stream_azure_openai_response(
|
||||
prompt_messages,
|
||||
azure_openai_api_key=azure_openai_api_key,
|
||||
azure_openai_api_version=azure_openai_api_version,
|
||||
azure_openai_resource_name=azure_openai_resource_name,
|
||||
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
@ -1,12 +1,32 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
from openai import AsyncOpenAI
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
||||
async def process_tasks(
|
||||
prompts: List[str],
|
||||
api_key: str | None,
|
||||
base_url: str | None,
|
||||
azure_openai_api_key: str | None,
|
||||
azure_openai_dalle3_api_version: str | None,
|
||||
azure_openai_resource_name: str | None,
|
||||
azure_openai_dalle3_deployment_name: str | None,
|
||||
):
|
||||
if api_key is not None:
|
||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||
if azure_openai_api_key is not None:
|
||||
tasks = [
|
||||
generate_image_azure(
|
||||
prompt,
|
||||
azure_openai_api_key,
|
||||
azure_openai_dalle3_api_version,
|
||||
azure_openai_resource_name,
|
||||
azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
processed_results: List[Union[str, None]] = []
|
||||
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
|
||||
return res.data[0].url
|
||||
|
||||
|
||||
async def generate_image_azure(
    prompt: str,
    azure_openai_api_key: str,
    azure_openai_api_version: str,
    azure_openai_resource_name: str,
    azure_openai_dalle3_deployment_name: str,
):
    """Generate one image with DALL-E 3 hosted on an Azure OpenAI deployment.

    Args:
        prompt: Text prompt for the image.
        azure_openai_api_key: Azure OpenAI API key.
        azure_openai_api_version: API version string for the DALL-E deployment.
        azure_openai_resource_name: Azure resource name; the endpoint is
            derived as https://<resource>.openai.azure.com/.
        azure_openai_dalle3_deployment_name: Name of the deployment that
            hosts the dall-e-3 model.

    Returns:
        URL of the generated 1024x1024 image.
    """
    client = AsyncAzureOpenAI(
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
        azure_deployment=azure_openai_dalle3_deployment_name,
    )
    image_params: Dict[str, Union[str, int]] = {
        "model": "dall-e-3",
        "quality": "standard",
        "style": "natural",
        "n": 1,
        "size": "1024x1024",
        "prompt": prompt,
    }
    try:
        res = await client.images.generate(**image_params)
    finally:
        # Close the client even when the generate call raises, so the
        # underlying HTTP connection is not leaked.
        await client.close()
    return res.data[0].url
|
||||
|
||||
|
||||
def extract_dimensions(url: str):
|
||||
# Regular expression to match numbers in the format '300x200'
|
||||
matches = re.findall(r"(\d+)x(\d+)", url)
|
||||
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
|
||||
|
||||
|
||||
async def generate_images(
|
||||
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
|
||||
code: str,
|
||||
api_key: str | None,
|
||||
base_url: Union[str, None] | None,
|
||||
image_cache: Dict[str, str],
|
||||
azure_openai_api_key: str | None,
|
||||
azure_openai_dalle3_api_version: str | None,
|
||||
azure_openai_resource_name: str | None,
|
||||
azure_openai_dalle3_deployment_name: str | None,
|
||||
):
|
||||
# Find all images
|
||||
soup = BeautifulSoup(code, "html.parser")
|
||||
@ -90,7 +143,15 @@ async def generate_images(
|
||||
return code
|
||||
|
||||
# Generate images
|
||||
results = await process_tasks(prompts, api_key, base_url)
|
||||
results = await process_tasks(
|
||||
prompts,
|
||||
api_key,
|
||||
base_url,
|
||||
azure_openai_api_key,
|
||||
azure_openai_dalle3_api_version,
|
||||
azure_openai_resource_name,
|
||||
azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
|
||||
# Create a dict mapping alt text to image URL
|
||||
mapped_image_urls = dict(zip(prompts, results))
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from typing import Awaitable, Callable, List
|
||||
from openai import AsyncOpenAI
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||
|
||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||
@ -34,3 +34,41 @@ async def stream_openai_response(
|
||||
await client.close()
|
||||
|
||||
return full_response
|
||||
|
||||
|
||||
async def stream_azure_openai_response(
    messages: List[ChatCompletionMessageParam],
    azure_openai_api_key: str | None,
    azure_openai_api_version: str | None,
    azure_openai_resource_name: str | None,
    azure_openai_deployment_name: str | None,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion from an Azure OpenAI deployment.

    Mirrors stream_openai_response: each content delta is forwarded to
    *callback* as it arrives and the concatenated text is returned.

    Args:
        messages: Chat messages to send to the model.
        azure_openai_api_key: Azure OpenAI API key.
        azure_openai_api_version: API version of the deployment.
        azure_openai_resource_name: Azure resource name; the endpoint is
            derived as https://<resource>.openai.azure.com/.
        azure_openai_deployment_name: Deployment hosting the vision model.
        callback: Awaitable callback invoked with each streamed text chunk.

    Returns:
        The full concatenated response text.
    """
    client = AsyncAzureOpenAI(
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
        azure_deployment=azure_openai_deployment_name,
    )

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    full_response = ""
    try:
        stream = await client.chat.completions.create(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            assert isinstance(chunk, ChatCompletionChunk)
            # Azure streaming can emit housekeeping chunks (e.g. content
            # filter annotations) with an empty choices list; skip them
            # instead of raising IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content or ""
            full_response += content
            await callback(content)
    finally:
        # Close the client even when the stream errors out mid-way.
        await client.close()

    return full_response
|
||||
|
||||
@ -3,7 +3,7 @@ import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||
from llm import stream_openai_response
|
||||
from llm import stream_openai_response, stream_azure_openai_response
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from mock_llm import mock_completion
|
||||
from typing import Dict, List
|
||||
@ -64,6 +64,12 @@ async def stream_code(websocket: WebSocket):
|
||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||
# If neither is provided, we throw an error.
|
||||
openai_api_key = None
|
||||
azure_openai_api_key = None
|
||||
azure_openai_resource_name = None
|
||||
azure_openai_deployment_name = None
|
||||
azure_openai_api_version = None
|
||||
azure_openai_dalle3_deployment_name = None
|
||||
azure_openai_dalle3_api_version = None
|
||||
if "accessCode" in params and params["accessCode"]:
|
||||
print("Access code - using platform API key")
|
||||
res = await validate_access_token(params["accessCode"])
|
||||
@ -83,15 +89,29 @@ async def stream_code(websocket: WebSocket):
|
||||
print("Using OpenAI API key from client-side settings dialog")
|
||||
else:
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||
azure_openai_deployment_name = os.environ.get(
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME"
|
||||
)
|
||||
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
azure_openai_dalle3_deployment_name = os.environ.get(
|
||||
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
|
||||
)
|
||||
azure_openai_dalle3_api_version = os.environ.get(
|
||||
"AZURE_OPENAI_DALLE3_API_VERSION"
|
||||
)
|
||||
if openai_api_key:
|
||||
print("Using OpenAI API key from environment variable")
|
||||
if azure_openai_api_key:
|
||||
print("Using Azure OpenAI API key from environment variable")
|
||||
|
||||
if not openai_api_key:
|
||||
print("OpenAI API key not found")
|
||||
if not openai_api_key and not azure_openai_api_key:
|
||||
print("OpenAI API or Azure key not found")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
||||
"value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
||||
}
|
||||
)
|
||||
return
|
||||
@ -190,12 +210,22 @@ async def stream_code(websocket: WebSocket):
|
||||
completion = await mock_completion(process_chunk)
|
||||
else:
|
||||
try:
|
||||
if openai_api_key is not None:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
if azure_openai_api_key is not None:
|
||||
completion = await stream_azure_openai_response(
|
||||
prompt_messages,
|
||||
azure_openai_api_key=azure_openai_api_key,
|
||||
azure_openai_api_version=azure_openai_api_version,
|
||||
azure_openai_resource_name=azure_openai_resource_name,
|
||||
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
except openai.AuthenticationError as e:
|
||||
print("[GENERATE_CODE] Authentication failed", e)
|
||||
error_message = (
|
||||
@ -244,6 +274,10 @@ async def stream_code(websocket: WebSocket):
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
azure_openai_api_key=azure_openai_api_key,
|
||||
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
|
||||
azure_openai_resource_name=azure_openai_resource_name,
|
||||
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
else:
|
||||
updated_html = completion
|
||||
|
||||
Loading…
Reference in New Issue
Block a user