Merge 0e0b68abe0 into 5912957514
This commit is contained in:
commit
25d6ac7e01
37
README.md
37
README.md
@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
|
||||
|
||||
## 🛠 Getting Started
|
||||
|
||||
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API.
|
||||
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API/Azure key with access to the GPT-4 Vision API.
|
||||
|
||||
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it):
|
||||
Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
|
||||
|
||||
For OpenAI Version:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
@ -38,6 +40,21 @@ poetry shell
|
||||
poetry run uvicorn main:app --reload --port 7001
|
||||
```
|
||||
|
||||
For Azure version, you need to add some additional environment keys (vision and dalle3 deployment must be int the same resource on Azure):
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
|
||||
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
|
||||
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
|
||||
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
|
||||
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
|
||||
poetry install
|
||||
poetry shell
|
||||
poetry run uvicorn main:app --reload --port 7001
|
||||
```
|
||||
|
||||
Run the frontend:
|
||||
|
||||
```bash
|
||||
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
|
||||
|
||||
## Configuration
|
||||
|
||||
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
||||
- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
||||
|
||||
## Docker
|
||||
|
||||
If you have Docker installed on your system, in the root directory, run:
|
||||
|
||||
For OpenAI Version:
|
||||
|
||||
```bash
|
||||
echo "OPENAI_API_KEY=sk-your-key" > .env
|
||||
docker-compose up -d --build
|
||||
```
|
||||
|
||||
For Azure version:
|
||||
|
||||
```bash
|
||||
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" > .env
|
||||
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" > .env
|
||||
echo "AZURE_OPENAI_API_VERSION=azure_api_version" > .env
|
||||
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name"> .env
|
||||
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" > .env
|
||||
docker-compose up -d --build
|
||||
```
|
||||
|
||||
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
|
||||
|
||||
## 🙋♂️ FAQs
|
||||
|
||||
23
backend/api_types.py
Normal file
23
backend/api_types.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Literal, Optional
|
||||
|
||||
|
||||
class ApiProviderInfoBase(BaseModel):
|
||||
name: Literal["openai", "azure"]
|
||||
|
||||
|
||||
class OpenAiProviderInfo(ApiProviderInfoBase):
|
||||
name: Literal["openai"] = "openai" # type: ignore
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
class AzureProviderInfo(ApiProviderInfoBase):
|
||||
name: Literal["azure"] = "azure" # type: ignore
|
||||
api_version: str
|
||||
api_key: str
|
||||
deployment_name: str
|
||||
resource_name: str
|
||||
|
||||
|
||||
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]
|
||||
@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
from llm import stream_openai_response
|
||||
from llm import stream_openai_response, stream_azure_openai_response
|
||||
from prompts import assemble_prompt
|
||||
import asyncio
|
||||
|
||||
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
|
||||
prompt_messages = assemble_prompt(image_url, stack)
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
openai_base_url = None
|
||||
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
|
||||
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
|
||||
pprint_prompt(prompt_messages)
|
||||
|
||||
async def process_chunk(content: str):
|
||||
pass
|
||||
|
||||
if not openai_api_key:
|
||||
raise Exception("OpenAI API key not found")
|
||||
if not openai_api_key and not azure_openai_api_key:
|
||||
raise Exception("OpenAI API or Azure key not found")
|
||||
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
if not openai_api_key:
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
if not azure_openai_api_key:
|
||||
completion = await stream_azure_openai_response(
|
||||
prompt_messages,
|
||||
azure_openai_api_key=azure_openai_api_key,
|
||||
azure_openai_api_version=azure_openai_api_version,
|
||||
azure_openai_resource_name=azure_openai_resource_name,
|
||||
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
@ -1,12 +1,32 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
from openai import AsyncOpenAI
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||
async def process_tasks(
|
||||
prompts: List[str],
|
||||
api_key: str | None,
|
||||
base_url: str | None,
|
||||
azure_openai_api_key: str | None,
|
||||
azure_openai_dalle3_api_version: str | None,
|
||||
azure_openai_resource_name: str | None,
|
||||
azure_openai_dalle3_deployment_name: str | None,
|
||||
):
|
||||
if api_key is not None:
|
||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||
if azure_openai_api_key is not None:
|
||||
tasks = [
|
||||
generate_image_azure(
|
||||
prompt,
|
||||
azure_openai_api_key,
|
||||
azure_openai_dalle3_api_version,
|
||||
azure_openai_resource_name,
|
||||
azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
processed_results: List[Union[str, None]] = []
|
||||
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
|
||||
return res.data[0].url
|
||||
|
||||
|
||||
async def generate_image_azure(
|
||||
prompt: str,
|
||||
azure_openai_api_key: str,
|
||||
azure_openai_api_version: str,
|
||||
azure_openai_resource_name: str,
|
||||
azure_openai_dalle3_deployment_name: str,
|
||||
):
|
||||
client = AsyncAzureOpenAI(
|
||||
api_version=azure_openai_api_version,
|
||||
api_key=azure_openai_api_key,
|
||||
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
|
||||
azure_deployment=azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
image_params: Dict[str, Union[str, int]] = {
|
||||
"model": "dall-e-3",
|
||||
"quality": "standard",
|
||||
"style": "natural",
|
||||
"n": 1,
|
||||
"size": "1024x1024",
|
||||
"prompt": prompt,
|
||||
}
|
||||
res = await client.images.generate(**image_params)
|
||||
await client.close()
|
||||
return res.data[0].url
|
||||
|
||||
|
||||
def extract_dimensions(url: str):
|
||||
# Regular expression to match numbers in the format '300x200'
|
||||
matches = re.findall(r"(\d+)x(\d+)", url)
|
||||
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
|
||||
|
||||
|
||||
async def generate_images(
|
||||
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
|
||||
code: str,
|
||||
api_key: str | None,
|
||||
base_url: Union[str, None] | None,
|
||||
image_cache: Dict[str, str],
|
||||
azure_openai_api_key: str | None,
|
||||
azure_openai_dalle3_api_version: str | None,
|
||||
azure_openai_resource_name: str | None,
|
||||
azure_openai_dalle3_deployment_name: str | None,
|
||||
):
|
||||
# Find all images
|
||||
soup = BeautifulSoup(code, "html.parser")
|
||||
@ -90,7 +143,15 @@ async def generate_images(
|
||||
return code
|
||||
|
||||
# Generate images
|
||||
results = await process_tasks(prompts, api_key, base_url)
|
||||
results = await process_tasks(
|
||||
prompts,
|
||||
api_key,
|
||||
base_url,
|
||||
azure_openai_api_key,
|
||||
azure_openai_dalle3_api_version,
|
||||
azure_openai_resource_name,
|
||||
azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
|
||||
# Create a dict mapping alt text to image URL
|
||||
mapped_image_urls = dict(zip(prompts, results))
|
||||
|
||||
@ -1,17 +1,30 @@
|
||||
from typing import Awaitable, Callable, List
|
||||
from openai import AsyncOpenAI
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||
|
||||
from api_types import ApiProviderInfo
|
||||
|
||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||
|
||||
|
||||
async def stream_openai_response(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
api_key: str,
|
||||
base_url: str | None,
|
||||
api_provider_info: ApiProviderInfo,
|
||||
callback: Callable[[str], Awaitable[None]],
|
||||
) -> str:
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
if api_provider_info.name == "openai":
|
||||
client = AsyncOpenAI(
|
||||
api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
|
||||
)
|
||||
elif api_provider_info.name == "azure":
|
||||
client = AsyncAzureOpenAI(
|
||||
api_version=api_provider_info.api_version,
|
||||
api_key=api_provider_info.api_key,
|
||||
azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
|
||||
azure_deployment=api_provider_info.deployment_name,
|
||||
)
|
||||
else:
|
||||
raise Exception("Invalid api_provider_info")
|
||||
|
||||
model = MODEL_GPT_4_VISION
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
from api_types import AzureProviderInfo, OpenAiProviderInfo
|
||||
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||
from llm import stream_openai_response
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
@ -64,6 +65,12 @@ async def stream_code(websocket: WebSocket):
|
||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||
# If neither is provided, we throw an error.
|
||||
openai_api_key = None
|
||||
azure_openai_api_key = None
|
||||
azure_openai_resource_name = None
|
||||
azure_openai_deployment_name = None
|
||||
azure_openai_api_version = None
|
||||
azure_openai_dalle3_deployment_name = None
|
||||
azure_openai_dalle3_api_version = None
|
||||
if "accessCode" in params and params["accessCode"]:
|
||||
print("Access code - using platform API key")
|
||||
res = await validate_access_token(params["accessCode"])
|
||||
@ -83,15 +90,29 @@ async def stream_code(websocket: WebSocket):
|
||||
print("Using OpenAI API key from client-side settings dialog")
|
||||
else:
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||
azure_openai_deployment_name = os.environ.get(
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME"
|
||||
)
|
||||
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
azure_openai_dalle3_deployment_name = os.environ.get(
|
||||
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
|
||||
)
|
||||
azure_openai_dalle3_api_version = os.environ.get(
|
||||
"AZURE_OPENAI_DALLE3_API_VERSION"
|
||||
)
|
||||
if openai_api_key:
|
||||
print("Using OpenAI API key from environment variable")
|
||||
if azure_openai_api_key:
|
||||
print("Using Azure OpenAI API key from environment variable")
|
||||
|
||||
if not openai_api_key:
|
||||
print("OpenAI API key not found")
|
||||
if not openai_api_key and not azure_openai_api_key:
|
||||
print("OpenAI API or Azure key not found")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
||||
"value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
||||
}
|
||||
)
|
||||
return
|
||||
@ -190,12 +211,44 @@ async def stream_code(websocket: WebSocket):
|
||||
completion = await mock_completion(process_chunk)
|
||||
else:
|
||||
try:
|
||||
api_provider_info = None
|
||||
if openai_api_key is not None:
|
||||
api_provider_info = {
|
||||
"name": "openai",
|
||||
"api_key": openai_api_key,
|
||||
"base_url": openai_base_url,
|
||||
}
|
||||
|
||||
api_provider_info = OpenAiProviderInfo(
|
||||
api_key=openai_api_key, base_url=openai_base_url
|
||||
)
|
||||
|
||||
if azure_openai_api_key is not None:
|
||||
if (
|
||||
not azure_openai_api_version
|
||||
or not azure_openai_resource_name
|
||||
or not azure_openai_deployment_name
|
||||
):
|
||||
raise Exception(
|
||||
"Missing Azure OpenAI API version, resource name, or deployment name"
|
||||
)
|
||||
|
||||
api_provider_info = AzureProviderInfo(
|
||||
api_key=azure_openai_api_key,
|
||||
api_version=azure_openai_api_version,
|
||||
deployment_name=azure_openai_deployment_name,
|
||||
resource_name=azure_openai_resource_name,
|
||||
)
|
||||
|
||||
if api_provider_info is None:
|
||||
raise Exception("Invalid api_provider_info")
|
||||
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
api_provider_info=api_provider_info,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
|
||||
except openai.AuthenticationError as e:
|
||||
print("[GENERATE_CODE] Authentication failed", e)
|
||||
error_message = (
|
||||
@ -244,6 +297,10 @@ async def stream_code(websocket: WebSocket):
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
image_cache=image_cache,
|
||||
azure_openai_api_key=azure_openai_api_key,
|
||||
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
|
||||
azure_openai_resource_name=azure_openai_resource_name,
|
||||
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
|
||||
)
|
||||
else:
|
||||
updated_html = completion
|
||||
|
||||
Loading…
Reference in New Issue
Block a user