clean up some code
This commit is contained in:
parent
ce637dbc20
commit
0e0b68abe0
23
backend/api_types.py
Normal file
23
backend/api_types.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Union, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ApiProviderInfoBase(BaseModel):
|
||||||
|
name: Literal["openai", "azure"]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAiProviderInfo(ApiProviderInfoBase):
|
||||||
|
name: Literal["openai"] = "openai" # type: ignore
|
||||||
|
api_key: str
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AzureProviderInfo(ApiProviderInfoBase):
|
||||||
|
name: Literal["azure"] = "azure" # type: ignore
|
||||||
|
api_version: str
|
||||||
|
api_key: str
|
||||||
|
deployment_name: str
|
||||||
|
resource_name: str
|
||||||
|
|
||||||
|
|
||||||
|
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]
|
||||||
@ -2,54 +2,29 @@ from typing import Awaitable, Callable, List
|
|||||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||||
|
|
||||||
|
from api_types import ApiProviderInfo
|
||||||
|
|
||||||
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||||
|
|
||||||
|
|
||||||
async def stream_openai_response(
|
async def stream_openai_response(
|
||||||
messages: List[ChatCompletionMessageParam],
|
messages: List[ChatCompletionMessageParam],
|
||||||
api_key: str,
|
api_provider_info: ApiProviderInfo,
|
||||||
base_url: str | None,
|
|
||||||
callback: Callable[[str], Awaitable[None]],
|
callback: Callable[[str], Awaitable[None]],
|
||||||
) -> str:
|
) -> str:
|
||||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
if api_provider_info.name == "openai":
|
||||||
|
client = AsyncOpenAI(
|
||||||
model = MODEL_GPT_4_VISION
|
api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
|
||||||
|
|
||||||
# Base parameters
|
|
||||||
params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
|
|
||||||
|
|
||||||
# Add 'max_tokens' only if the model is a GPT4 vision model
|
|
||||||
if model == MODEL_GPT_4_VISION:
|
|
||||||
params["max_tokens"] = 4096
|
|
||||||
params["temperature"] = 0
|
|
||||||
|
|
||||||
stream = await client.chat.completions.create(**params) # type: ignore
|
|
||||||
full_response = ""
|
|
||||||
async for chunk in stream: # type: ignore
|
|
||||||
assert isinstance(chunk, ChatCompletionChunk)
|
|
||||||
content = chunk.choices[0].delta.content or ""
|
|
||||||
full_response += content
|
|
||||||
await callback(content)
|
|
||||||
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
return full_response
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_azure_openai_response(
|
|
||||||
messages: List[ChatCompletionMessageParam],
|
|
||||||
azure_openai_api_key: str | None,
|
|
||||||
azure_openai_api_version: str | None,
|
|
||||||
azure_openai_resource_name: str | None,
|
|
||||||
azure_openai_deployment_name: str | None,
|
|
||||||
callback: Callable[[str], Awaitable[None]],
|
|
||||||
) -> str:
|
|
||||||
client = AsyncAzureOpenAI(
|
|
||||||
api_version=azure_openai_api_version,
|
|
||||||
api_key=azure_openai_api_key,
|
|
||||||
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
|
|
||||||
azure_deployment=azure_openai_deployment_name,
|
|
||||||
)
|
)
|
||||||
|
elif api_provider_info.name == "azure":
|
||||||
|
client = AsyncAzureOpenAI(
|
||||||
|
api_version=api_provider_info.api_version,
|
||||||
|
api_key=api_provider_info.api_key,
|
||||||
|
azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
|
||||||
|
azure_deployment=api_provider_info.deployment_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid api_provider_info")
|
||||||
|
|
||||||
model = MODEL_GPT_4_VISION
|
model = MODEL_GPT_4_VISION
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,9 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
|
from api_types import AzureProviderInfo, OpenAiProviderInfo
|
||||||
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||||
from llm import stream_openai_response, stream_azure_openai_response
|
from llm import stream_openai_response
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from mock_llm import mock_completion
|
from mock_llm import mock_completion
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@ -210,22 +211,44 @@ async def stream_code(websocket: WebSocket):
|
|||||||
completion = await mock_completion(process_chunk)
|
completion = await mock_completion(process_chunk)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
api_provider_info = None
|
||||||
if openai_api_key is not None:
|
if openai_api_key is not None:
|
||||||
|
api_provider_info = {
|
||||||
|
"name": "openai",
|
||||||
|
"api_key": openai_api_key,
|
||||||
|
"base_url": openai_base_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
api_provider_info = OpenAiProviderInfo(
|
||||||
|
api_key=openai_api_key, base_url=openai_base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
if azure_openai_api_key is not None:
|
||||||
|
if (
|
||||||
|
not azure_openai_api_version
|
||||||
|
or not azure_openai_resource_name
|
||||||
|
or not azure_openai_deployment_name
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
"Missing Azure OpenAI API version, resource name, or deployment name"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_provider_info = AzureProviderInfo(
|
||||||
|
api_key=azure_openai_api_key,
|
||||||
|
api_version=azure_openai_api_version,
|
||||||
|
deployment_name=azure_openai_deployment_name,
|
||||||
|
resource_name=azure_openai_resource_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if api_provider_info is None:
|
||||||
|
raise Exception("Invalid api_provider_info")
|
||||||
|
|
||||||
completion = await stream_openai_response(
|
completion = await stream_openai_response(
|
||||||
prompt_messages,
|
prompt_messages,
|
||||||
api_key=openai_api_key,
|
api_provider_info=api_provider_info,
|
||||||
base_url=openai_base_url,
|
|
||||||
callback=lambda x: process_chunk(x),
|
|
||||||
)
|
|
||||||
if azure_openai_api_key is not None:
|
|
||||||
completion = await stream_azure_openai_response(
|
|
||||||
prompt_messages,
|
|
||||||
azure_openai_api_key=azure_openai_api_key,
|
|
||||||
azure_openai_api_version=azure_openai_api_version,
|
|
||||||
azure_openai_resource_name=azure_openai_resource_name,
|
|
||||||
azure_openai_deployment_name=azure_openai_deployment_name,
|
|
||||||
callback=lambda x: process_chunk(x),
|
callback=lambda x: process_chunk(x),
|
||||||
)
|
)
|
||||||
|
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
print("[GENERATE_CODE] Authentication failed", e)
|
print("[GENERATE_CODE] Authentication failed", e)
|
||||||
error_message = (
|
error_message = (
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user