Merge 0e0b68abe0 into 5912957514
This commit is contained in:
commit
25d6ac7e01
37
README.md
37
README.md
@ -26,9 +26,11 @@ See the [Examples](#-examples) section below for more demos.
|
|||||||
|
|
||||||
## 🛠 Getting Started
|
## 🛠 Getting Started
|
||||||
|
|
||||||
The app has a React/Vite frontend and a FastAPI backend. You will need an OpenAI API key with access to the GPT-4 Vision API.
|
The app has a React/Vite frontend and a FastAPI backend. You will need either an OpenAI API key or an Azure OpenAI key with access to the GPT-4 Vision API.
|
||||||
|
|
||||||
Run the backend (I use Poetry for package management - `pip install poetry` if you don't have it):
|
Run the backend based on the AI provider you want to use (I use Poetry for package management - `pip install poetry` if you don't have it):
|
||||||
|
|
||||||
|
For OpenAI Version:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd backend
|
cd backend
|
||||||
@ -38,6 +40,21 @@ poetry shell
|
|||||||
poetry run uvicorn main:app --reload --port 7001
|
poetry run uvicorn main:app --reload --port 7001
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For the Azure version, you need to add some additional environment keys (the vision and DALL-E 3 deployments must be in the same resource on Azure):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||||
|
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
|
||||||
|
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
|
||||||
|
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
|
||||||
|
poetry install
|
||||||
|
poetry shell
|
||||||
|
poetry run uvicorn main:app --reload --port 7001
|
||||||
|
```
|
||||||
|
|
||||||
Run the frontend:
|
Run the frontend:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@ -58,17 +75,31 @@ MOCK=true poetry run uvicorn main:app --reload --port 7001
|
|||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
* You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
- You can configure the OpenAI base URL if you need to use a proxy: Set OPENAI_BASE_URL in the `backend/.env` or directly in the UI in the settings dialog
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
If you have Docker installed on your system, in the root directory, run:
|
If you have Docker installed on your system, in the root directory, run:
|
||||||
|
|
||||||
|
For OpenAI Version:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
echo "OPENAI_API_KEY=sk-your-key" > .env
|
echo "OPENAI_API_KEY=sk-your-key" > .env
|
||||||
docker-compose up -d --build
|
docker-compose up -d --build
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For Azure version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo "AZURE_OPENAI_API_KEY=sk-your-key" > .env
|
||||||
|
echo "AZURE_OPENAI_RESOURCE_NAME=azure_resource_name" >> .env
|
||||||
|
echo "AZURE_OPENAI_DEPLOYMENT_NAME=azure_deployment_name" >> .env
|
||||||
|
echo "AZURE_OPENAI_API_VERSION=azure_api_version" >> .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME=azure_dalle3_deployment_name" >> .env
|
||||||
|
echo "AZURE_OPENAI_DALLE3_API_VERSION=azure_dalle3_api_version" >> .env
|
||||||
|
docker-compose up -d --build
|
||||||
|
```
|
||||||
|
|
||||||
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
|
The app will be up and running at http://localhost:5173. Note that you can't develop the application with this setup as the file changes won't trigger a rebuild.
|
||||||
|
|
||||||
## 🙋♂️ FAQs
|
## 🙋♂️ FAQs
|
||||||
|
|||||||
23
backend/api_types.py
Normal file
23
backend/api_types.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Union, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ApiProviderInfoBase(BaseModel):
    """Base model for provider settings, tagged by the ``name`` field."""

    # Provider tag; restricted to the two literal values below.
    name: Literal["openai", "azure"]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAiProviderInfo(ApiProviderInfoBase):
    """Settings for the OpenAI provider: an API key and an optional base URL."""

    # Narrowed to the single "openai" tag and defaulted, so callers can omit it.
    name: Literal["openai"] = "openai"  # type: ignore
    # Required API key.
    api_key: str
    # Optional base URL override; defaults to None.
    base_url: Union[str, None] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AzureProviderInfo(ApiProviderInfoBase):
    """Settings for the Azure OpenAI provider; every field is required."""

    # Narrowed to the single "azure" tag and defaulted, so callers can omit it.
    name: Literal["azure"] = "azure"  # type: ignore
    # API version string handed to the Azure client.
    api_version: str
    # Azure OpenAI API key.
    api_key: str
    # Name of the Azure deployment to target.
    deployment_name: str
    # Azure resource name.
    resource_name: str
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the two provider settings models (OpenAI or Azure).
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]
|
||||||
@ -8,7 +8,7 @@ from eval_utils import image_to_data_url
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from llm import stream_openai_response
|
from llm import stream_openai_response, stream_azure_openai_response
|
||||||
from prompts import assemble_prompt
|
from prompts import assemble_prompt
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@ -19,21 +19,35 @@ async def generate_code_core(image_url: str, stack: str) -> str:
|
|||||||
prompt_messages = assemble_prompt(image_url, stack)
|
prompt_messages = assemble_prompt(image_url, stack)
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
openai_base_url = None
|
openai_base_url = None
|
||||||
|
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||||
|
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||||
|
azure_openai_deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
|
||||||
|
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||||
|
|
||||||
pprint_prompt(prompt_messages)
|
pprint_prompt(prompt_messages)
|
||||||
|
|
||||||
async def process_chunk(content: str):
|
async def process_chunk(content: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not openai_api_key:
|
if not openai_api_key and not azure_openai_api_key:
|
||||||
raise Exception("OpenAI API key not found")
|
raise Exception("OpenAI API or Azure key not found")
|
||||||
|
|
||||||
|
if not openai_api_key:
|
||||||
completion = await stream_openai_response(
|
completion = await stream_openai_response(
|
||||||
prompt_messages,
|
prompt_messages,
|
||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
if not azure_openai_api_key:
|
||||||
|
completion = await stream_azure_openai_response(
|
||||||
|
prompt_messages,
|
||||||
|
azure_openai_api_key=azure_openai_api_key,
|
||||||
|
azure_openai_api_version=azure_openai_api_version,
|
||||||
|
azure_openai_resource_name=azure_openai_resource_name,
|
||||||
|
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||||
|
callback=lambda x: process_chunk(x),
|
||||||
|
)
|
||||||
|
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,32 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
||||||
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
|
async def process_tasks(
|
||||||
|
prompts: List[str],
|
||||||
|
api_key: str | None,
|
||||||
|
base_url: str | None,
|
||||||
|
azure_openai_api_key: str | None,
|
||||||
|
azure_openai_dalle3_api_version: str | None,
|
||||||
|
azure_openai_resource_name: str | None,
|
||||||
|
azure_openai_dalle3_deployment_name: str | None,
|
||||||
|
):
|
||||||
|
if api_key is not None:
|
||||||
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
|
||||||
|
if azure_openai_api_key is not None:
|
||||||
|
tasks = [
|
||||||
|
generate_image_azure(
|
||||||
|
prompt,
|
||||||
|
azure_openai_api_key,
|
||||||
|
azure_openai_dalle3_api_version,
|
||||||
|
azure_openai_resource_name,
|
||||||
|
azure_openai_dalle3_deployment_name,
|
||||||
|
)
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
processed_results: List[Union[str, None]] = []
|
processed_results: List[Union[str, None]] = []
|
||||||
@ -35,6 +55,32 @@ async def generate_image(prompt: str, api_key: str, base_url: str):
|
|||||||
return res.data[0].url
|
return res.data[0].url
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_image_azure(
    prompt: str,
    azure_openai_api_key: str,
    azure_openai_api_version: str,
    azure_openai_resource_name: str,
    azure_openai_dalle3_deployment_name: str,
):
    """Generate one image through an Azure OpenAI deployment and return its URL.

    Args:
        prompt: Text prompt describing the image to generate.
        azure_openai_api_key: Azure OpenAI API key.
        azure_openai_api_version: API version passed to the Azure client.
        azure_openai_resource_name: Azure resource name; used to build the
            ``https://<resource>.openai.azure.com/`` endpoint.
        azure_openai_dalle3_deployment_name: Name of the image deployment.

    Returns:
        The ``url`` of the first item in the image response.

    Raises:
        Whatever the client raises for a failed request. The client is closed
        before the exception propagates.
    """
    client = AsyncAzureOpenAI(
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
        azure_deployment=azure_openai_dalle3_deployment_name,
    )
    image_params: Dict[str, Union[str, int]] = {
        "model": "dall-e-3",
        "quality": "standard",
        "style": "natural",
        "n": 1,
        "size": "1024x1024",
        "prompt": prompt,
    }
    try:
        res = await client.images.generate(**image_params)
    finally:
        # Always release the underlying HTTP client, even when the request
        # fails; otherwise each failed generation leaks a connection pool.
        await client.close()
    return res.data[0].url
|
||||||
|
|
||||||
|
|
||||||
def extract_dimensions(url: str):
|
def extract_dimensions(url: str):
|
||||||
# Regular expression to match numbers in the format '300x200'
|
# Regular expression to match numbers in the format '300x200'
|
||||||
matches = re.findall(r"(\d+)x(\d+)", url)
|
matches = re.findall(r"(\d+)x(\d+)", url)
|
||||||
@ -62,7 +108,14 @@ def create_alt_url_mapping(code: str) -> Dict[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
async def generate_images(
|
async def generate_images(
|
||||||
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
|
code: str,
|
||||||
|
api_key: str | None,
|
||||||
|
base_url: Union[str, None] | None,
|
||||||
|
image_cache: Dict[str, str],
|
||||||
|
azure_openai_api_key: str | None,
|
||||||
|
azure_openai_dalle3_api_version: str | None,
|
||||||
|
azure_openai_resource_name: str | None,
|
||||||
|
azure_openai_dalle3_deployment_name: str | None,
|
||||||
):
|
):
|
||||||
# Find all images
|
# Find all images
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
@ -90,7 +143,15 @@ async def generate_images(
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
# Generate images
|
# Generate images
|
||||||
results = await process_tasks(prompts, api_key, base_url)
|
results = await process_tasks(
|
||||||
|
prompts,
|
||||||
|
api_key,
|
||||||
|
base_url,
|
||||||
|
azure_openai_api_key,
|
||||||
|
azure_openai_dalle3_api_version,
|
||||||
|
azure_openai_resource_name,
|
||||||
|
azure_openai_dalle3_deployment_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Create a dict mapping alt text to image URL
|
# Create a dict mapping alt text to image URL
|
||||||
mapped_image_urls = dict(zip(prompts, results))
|
mapped_image_urls = dict(zip(prompts, results))
|
||||||
|
|||||||
@ -1,17 +1,30 @@
|
|||||||
from typing import Awaitable, Callable, List
|
from typing import Awaitable, Callable, List
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||||
|
|
||||||
|
from api_types import ApiProviderInfo
|
||||||
|
|
||||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||||
|
|
||||||
|
|
||||||
async def stream_openai_response(
|
async def stream_openai_response(
|
||||||
messages: List[ChatCompletionMessageParam],
|
messages: List[ChatCompletionMessageParam],
|
||||||
api_key: str,
|
api_provider_info: ApiProviderInfo,
|
||||||
base_url: str | None,
|
|
||||||
callback: Callable[[str], Awaitable[None]],
|
callback: Callable[[str], Awaitable[None]],
|
||||||
) -> str:
|
) -> str:
|
||||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
if api_provider_info.name == "openai":
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
|
||||||
|
)
|
||||||
|
elif api_provider_info.name == "azure":
|
||||||
|
client = AsyncAzureOpenAI(
|
||||||
|
api_version=api_provider_info.api_version,
|
||||||
|
api_key=api_provider_info.api_key,
|
||||||
|
azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
|
||||||
|
azure_deployment=api_provider_info.deployment_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid api_provider_info")
|
||||||
|
|
||||||
model = MODEL_GPT_4_VISION
|
model = MODEL_GPT_4_VISION
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
|
from api_types import AzureProviderInfo, OpenAiProviderInfo
|
||||||
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||||
from llm import stream_openai_response
|
from llm import stream_openai_response
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
@ -64,6 +65,12 @@ async def stream_code(websocket: WebSocket):
|
|||||||
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
# Get the OpenAI API key from the request. Fall back to environment variable if not provided.
|
||||||
# If neither is provided, we throw an error.
|
# If neither is provided, we throw an error.
|
||||||
openai_api_key = None
|
openai_api_key = None
|
||||||
|
azure_openai_api_key = None
|
||||||
|
azure_openai_resource_name = None
|
||||||
|
azure_openai_deployment_name = None
|
||||||
|
azure_openai_api_version = None
|
||||||
|
azure_openai_dalle3_deployment_name = None
|
||||||
|
azure_openai_dalle3_api_version = None
|
||||||
if "accessCode" in params and params["accessCode"]:
|
if "accessCode" in params and params["accessCode"]:
|
||||||
print("Access code - using platform API key")
|
print("Access code - using platform API key")
|
||||||
res = await validate_access_token(params["accessCode"])
|
res = await validate_access_token(params["accessCode"])
|
||||||
@ -83,15 +90,29 @@ async def stream_code(websocket: WebSocket):
|
|||||||
print("Using OpenAI API key from client-side settings dialog")
|
print("Using OpenAI API key from client-side settings dialog")
|
||||||
else:
|
else:
|
||||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
azure_openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||||
|
azure_openai_resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
|
||||||
|
azure_openai_deployment_name = os.environ.get(
|
||||||
|
"AZURE_OPENAI_DEPLOYMENT_NAME"
|
||||||
|
)
|
||||||
|
azure_openai_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||||
|
azure_openai_dalle3_deployment_name = os.environ.get(
|
||||||
|
"AZURE_OPENAI_DALLE3_DEPLOYMENT_NAME"
|
||||||
|
)
|
||||||
|
azure_openai_dalle3_api_version = os.environ.get(
|
||||||
|
"AZURE_OPENAI_DALLE3_API_VERSION"
|
||||||
|
)
|
||||||
if openai_api_key:
|
if openai_api_key:
|
||||||
print("Using OpenAI API key from environment variable")
|
print("Using OpenAI API key from environment variable")
|
||||||
|
if azure_openai_api_key:
|
||||||
|
print("Using Azure OpenAI API key from environment variable")
|
||||||
|
|
||||||
if not openai_api_key:
|
if not openai_api_key and not azure_openai_api_key:
|
||||||
print("OpenAI API key not found")
|
print("OpenAI API or Azure key not found")
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
"value": "No OpenAI API or Azure key found. Please add your API key in the settings dialog or add it to backend/.env file.",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -190,12 +211,44 @@ async def stream_code(websocket: WebSocket):
|
|||||||
completion = await mock_completion(process_chunk)
|
completion = await mock_completion(process_chunk)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
api_provider_info = None
|
||||||
|
if openai_api_key is not None:
|
||||||
|
api_provider_info = {
|
||||||
|
"name": "openai",
|
||||||
|
"api_key": openai_api_key,
|
||||||
|
"base_url": openai_base_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
api_provider_info = OpenAiProviderInfo(
|
||||||
|
api_key=openai_api_key, base_url=openai_base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
if azure_openai_api_key is not None:
|
||||||
|
if (
|
||||||
|
not azure_openai_api_version
|
||||||
|
or not azure_openai_resource_name
|
||||||
|
or not azure_openai_deployment_name
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
"Missing Azure OpenAI API version, resource name, or deployment name"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_provider_info = AzureProviderInfo(
|
||||||
|
api_key=azure_openai_api_key,
|
||||||
|
api_version=azure_openai_api_version,
|
||||||
|
deployment_name=azure_openai_deployment_name,
|
||||||
|
resource_name=azure_openai_resource_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if api_provider_info is None:
|
||||||
|
raise Exception("Invalid api_provider_info")
|
||||||
|
|
||||||
completion = await stream_openai_response(
|
completion = await stream_openai_response(
|
||||||
prompt_messages,
|
prompt_messages,
|
||||||
api_key=openai_api_key,
|
api_provider_info=api_provider_info,
|
||||||
base_url=openai_base_url,
|
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
print("[GENERATE_CODE] Authentication failed", e)
|
print("[GENERATE_CODE] Authentication failed", e)
|
||||||
error_message = (
|
error_message = (
|
||||||
@ -244,6 +297,10 @@ async def stream_code(websocket: WebSocket):
|
|||||||
api_key=openai_api_key,
|
api_key=openai_api_key,
|
||||||
base_url=openai_base_url,
|
base_url=openai_base_url,
|
||||||
image_cache=image_cache,
|
image_cache=image_cache,
|
||||||
|
azure_openai_api_key=azure_openai_api_key,
|
||||||
|
azure_openai_dalle3_api_version=azure_openai_dalle3_api_version,
|
||||||
|
azure_openai_resource_name=azure_openai_resource_name,
|
||||||
|
azure_openai_dalle3_deployment_name=azure_openai_dalle3_deployment_name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
updated_html = completion
|
updated_html = completion
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user