Add azure support
This commit is contained in:
parent
d23cec9bc0
commit
88e383cbd6
37
README.md
37
README.md
@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
|
|||||||
|
|
||||||
## 🛠 Getting Started
|
## 🛠 Getting Started
|
||||||
|
|
||||||
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API.
|
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API/Azure key with access to the GPT-4 Vision API.
|
||||||
|
|
||||||
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it):
|
Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
|
||||||
|
|
||||||
|
For OpenAI Version:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd backend
|
cd backend
|
||||||
@ -38,6 +40,21 @@ poetry shell
|
|||||||
poetry run uvicorn main:app --reload --port 7001
|
poetry run uvicorn main:app --reload --port 7001
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For Azure version, you need to add some additional environment keys (vision and dalle3 deployment must be int the same resource on Azure):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||||
|
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
|
||||||
|
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
|
||||||
|
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
|
||||||
|
poetry install
|
||||||
|
poetry shell
|
||||||
|
poetry run uvicorn main:app --reload --port 7001
|
||||||
|
```
|
||||||
|
|
||||||
Run the frontend:
|
Run the frontend:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
|
|||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
If you have Docker installed on your system, in the root directory, run:
|
If you have Docker installed on your system, in the root directory, run:
|
||||||
|
|
||||||
|
For OpenAI Version:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
echo "OPENAI_API_KEY=sk-your-key" > .env
|
echo "OPENAI_API_KEY=sk-your-key" > .env
|
||||||
docker-compose up -d --build
|
docker-compose up -d --build
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For Azure version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||||
|
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
|
||||||
|
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
|
||||||
|
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
|
||||||
|
docker-compose up -d --build
|
||||||
|
```
|
||||||
|
|
||||||
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
|
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
|
||||||
|
|
||||||
## 🙋♂️ FAQs
|
## 🙋♂️ FAQs
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from llm import stream_openai_response
|
from llm import stream_openai_response, stream_azure_openai_response
|
||||||
from prompts import assemble_prompt
|
from prompts import assemble_prompt
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
|
|||||||
prompt_messages = assemble_prompt(image_url, stack)
|
prompt_messages = assemble_prompt(image_url, stack)
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
openai_base_url = None
|
openai_base_url = None
|
||||||
|
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||||
|
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||||
|
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
|
||||||
|
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||||
|
|
||||||
pprint_prompt(prompt_messages)
|
pprint_prompt(prompt_messages)
|
||||||
|
|
||||||
async def process_chunk(content: str):
|
async def process_chunk(content: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not openai_api_key:
|
if not openai_api_key and not azure_openai_api_key:
|
||||||
raise Exception("OpenAI API key not found")
|
raise Exception("OpenAI API or Azure key not found")
|
||||||
|
|
||||||
completion = await stream_openai_response(
|
if not openai_api_key:
|
||||||
prompt_messages,
|
completion = await stream_openai_response(
|
||||||
api_key=openai_api_key,
|
prompt_messages,
|
||||||
base_url=openai_base_url,
|
api_key=openai_api_key,
|
||||||
callback=lambda x: process_chunk(x),
|
base_url=openai_base_url,
|
||||||
)
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
|
if not azure_openai_api_key:
|
||||||
|
completion = await stream_azure_openai_response(
|
||||||
|
prompt_messages,
|
||||||
|
azure_openai_api_key=azure_openai_api_key,
|
||||||
|
azure_openai_api_version=azure_openai_api_version,
|
||||||
|
azure_openai_resource_name=azure_openai_resource_name,
|
||||||
|
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||||
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
|
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,32 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
||||||
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
async def process_tasks(
|
||||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
prompts: List[str],
|
||||||
|
api_key: str | None,
|
||||||
|
base_url: str | None,
|
||||||
|
azure_openai_api_key: str | None,
|
||||||
|
azure_openai_dalle3_api_version: str | None,
|
||||||
|
azure_openai_resource_name: str | None,
|
||||||
|
azure_openai_dalle3_deployment_name: str | None,
|
||||||
|
):
|
||||||
|
if api_key is not None:
|
||||||
|
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||||
|
if azure_openai_api_key is not None:
|
||||||
|
tasks = [
|
||||||
|
generate_image_azure(
|
||||||
|
prompt,
|
||||||
|
azure_openai_api_key,
|
||||||
|
azure_openai_dalle3_api_version,
|
||||||
|
azure_openai_resource_name,
|
||||||
|
azure_openai_dalle3_deployment_name,
|
||||||
|
)
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
processed_results: List[Union[str, None]] = []
|
processed_results: List[Union[str, None]] = []
|
||||||
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
|
|||||||
return res.data[0].url
|
return res.data[0].url
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_image_azure(
|
||||||
|
prompt: str,
|
||||||
|
azure_openai_api_key: str,
|
||||||
|
azure_openai_api_version: str,
|
||||||
|
azure_openai_resource_name: str,
|
||||||
|
azure_openai_dalle3_deployment_name: str,
|
||||||
|
):
|
||||||
|
client = AsyncAzureOpenAI(
|
||||||
|
api_version=azure_openai_api_version,
|
||||||
|
api_key=azure_openai_api_key,
|
||||||
|
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
|
||||||
|
azure_deployment=azure_openai_dalle3_deployment_name,
|
||||||
|
)
|
||||||
|
image_params: Dict[str, Union[str, int]] = {
|
||||||
|
"model": "dall-e-3",
|
||||||
|
"quality": "standard",
|
||||||
|
"style": "natural",
|
||||||
|
"n": 1,
|
||||||
|
"size": "1024x1024",
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
res = await client.images.generate(**image_params)
|
||||||
|
await client.close()
|
||||||
|
return res.data[0].url
|
||||||
|
|
||||||
|
|
||||||
def extract_dimensions(url: str):
|
def extract_dimensions(url: str):
|
||||||
# Regular expression to match numbers in the format '300x200'
|
# Regular expression to match numbers in the format '300x200'
|
||||||
matches = re.findall(r"(\d+)x(\d+)", url)
|
matches = re.findall(r"(\d+)x(\d+)", url)
|
||||||
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
async def generate_images(
|
async def generate_images(
|
||||||
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
|
code: str,
|
||||||
|
api_key: str | None,
|
||||||
|
base_url: Union[str, None] | None,
|
||||||
|
image_cache: Dict[str, str],
|
||||||
|
azure_openai_api_key: str | None,
|
||||||
|
azure_openai_dalle3_api_version: str | None,
|
||||||
|
azure_openai_resource_name: str | None,
|
||||||
|
azure_openai_dalle3_deployment_name: str | None,
|
||||||
):
|
):
|
||||||
# Find all images
|
# Find all images
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
@ -90,7 +143,15 @@ async def generate_images(
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
# Generate images
|
# Generate images
|
||||||
results = await process_tasks(prompts, api_key, base_url)
|
results = await process_tasks(
|
||||||
|
prompts,
|
||||||
|
api_key,
|
||||||
|
base_url,
|
||||||
|
azure_openai_api_key,
|
||||||
|
azure_openai_dalle3_api_version,
|
||||||
|
azure_openai_resource_name,
|
||||||
|
azure_openai_dalle3_deployment_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Create a dict mapping alt text to image URL
|
# Create a dict mapping alt text to image URL
|
||||||
mapped_image_urls = dict(zip(prompts, results))
|
mapped_image_urls = dict(zip(prompts, results))
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from typing import Awaitable, Callable, List
|
from typing import Awaitable, Callable, List
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||||
|
|
||||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||||
@ -34,3 +34,41 @@ async def stream_openai_response(
|
|||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_azure_openai_response(
|
||||||
|
messages: List[ChatCompletionMessageParam],
|
||||||
|
azure_openai_api_key: str | None,
|
||||||
|
azure_openai_api_version: str | None,
|
||||||
|
azure_openai_resource_name: str | None,
|
||||||
|
azure_openai_deployment_name: str | None,
|
||||||
|
callback: Callable[[str], Awaitable[None]],
|
||||||
|
) -> str:
|
||||||
|
client = AsyncAzureOpenAI(
|
||||||
|
api_version=azure_openai_api_version,
|
||||||
|
api_key=azure_openai_api_key,
|
||||||
|
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
|
||||||
|
azure_deployment=azure_openai_deployment_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MODEL_GPT_4_VISION
|
||||||
|
|
||||||
|
# Base parameters
|
||||||
|
params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
|
||||||
|
|
||||||
|
# Add 'max_tokens' only if the model is a GPT4 vision model
|
||||||
|
if model == MODEL_GPT_4_VISION:
|
||||||
|
params["max_tokens"] = 4096
|
||||||
|
params["temperature"] = 0
|
||||||
|
|
||||||
|
stream = await client.chat.completions.create(**params) # type: ignore
|
||||||
|
full_response = ""
|
||||||
|
async for chunk in stream: # type: ignore
|
||||||
|
assert isinstance(chunk, ChatCompletionChunk)
|
||||||
|
content = chunk.choices[0].delta.content or ""
|
||||||
|
full_response += content
|
||||||
|
await callback(content)
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
return full_response
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import traceback
|
|||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||||
from llm import stream_openai_response
|
from llm import stream_openai_response, stream_azure_openai_response
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@ -64,6 +64,12 @@ async def stream_code(websocket: WebSocket):
|
|||||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||||
# If neither is provided, we throw an error.
|
# If neither is provided, we throw an error.
|
||||||
openai_api_key = None
|
openai_api_key = None
|
||||||
|
azure_openai_api_key = None
|
||||||
|
azure_openai_resource_name = None
|
||||||
|
azure_openai_deployment_name = None
|
||||||
|
azure_openai_api_version = None
|
||||||
|
azure_openai_dalle3_deployment_name = None
|
||||||
|
azure_openai_dalle3_api_version = None
|
||||||
if "accessCode" in params and params["accessCode"]:
|
if "accessCode" in params and params["accessCode"]:
|
||||||
print("Access code - using platform API key")
|
print("Access code - using platform API key")
|
||||||
res = await validate_access_token(params["accessCode"])
|
res = await validate_access_token(params["accessCode"])
|
||||||
@ -83,15 +89,29 @@ async def stream_code(websocket: WebSocket):
|
|||||||
print("Using OpenAI API key from client-side settings dialog")
|
print("Using OpenAI API key from client-side settings dialog")
|
||||||
else:
|
else:
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||||
|
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||||
|
azure_openai_deployment_name = os.environ.get(
|
||||||
|
"AZURE_OPENAI_DEPLOYMENT_NAME"
|
||||||
|
)
|
||||||
|
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||||
|
azure_openai_dalle3_deployment_name = os.environ.get(
|
||||||
|
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
|
||||||
|
)
|
||||||
|
azure_openai_dalle3_api_version = os.environ.get(
|
||||||
|
"AZURE_OPENAI_DALLE3_API_VERSION"
|
||||||
|
)
|
||||||
if openai_api_key:
|
if openai_api_key:
|
||||||
print("Using OpenAI API key from environment variable")
|
print("Using OpenAI API key from environment variable")
|
||||||
|
if azure_openai_api_key:
|
||||||
|
print("Using Azure OpenAI API key from environment variable")
|
||||||
|
|
||||||
if not openai_api_key:
|
if not openai_api_key and not azure_openai_api_key:
|
||||||
print("OpenAI API key not found")
|
print("OpenAI API or Azure key not found")
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
"value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -190,12 +210,22 @@ async def stream_code(websocket: WebSocket):
|
|||||||
completion = await mock_completion(process_chunk)
|
completion = await mock_completion(process_chunk)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
completion = await stream_openai_response(
|
if openai_api_key is not None:
|
||||||
prompt_messages,
|
completion = await stream_openai_response(
|
||||||
api_key=openai_api_key,
|
prompt_messages,
|
||||||
base_url=openai_base_url,
|
api_key=openai_api_key,
|
||||||
callback=lambda x: process_chunk(x),
|
base_url=openai_base_url,
|
||||||
)
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
|
if azure_openai_api_key is not None:
|
||||||
|
completion = await stream_azure_openai_response(
|
||||||
|
prompt_messages,
|
||||||
|
azure_openai_api_key=azure_openai_api_key,
|
||||||
|
azure_openai_api_version=azure_openai_api_version,
|
||||||
|
azure_openai_resource_name=azure_openai_resource_name,
|
||||||
|
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||||
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
print("[GENERATE_CODE] Authentication failed", e)
|
print("[GENERATE_CODE] Authentication failed", e)
|
||||||
error_message = (
|
error_message = (
|
||||||
@ -244,6 +274,10 @@ async def stream_code(websocket: WebSocket):
|
|||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
image_cache=image_cache,
|
image_cache=image_cache,
|
||||||
|
azure_openai_api_key=azure_openai_api_key,
|
||||||
|
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
|
||||||
|
azure_openai_resource_name=azure_openai_resource_name,
|
||||||
|
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
updated_html = completion
|
updated_html = completion
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user