clean up some code
This commit is contained in:
parent
ce637dbc20
commit
0e0b68abe0
23
backend/api_types.py
Normal file
23
backend/api_types.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Literal, Optional
|
||||
|
||||
|
||||
class ApiProviderInfoBase(BaseModel):
    """Common shape shared by every API-provider configuration.

    ``name`` acts as the discriminator telling callers which concrete
    provider variant they are holding.
    """

    # Restricted to the providers this backend knows how to talk to.
    name: Literal["openai", "azure"]
|
||||
|
||||
|
||||
class OpenAiProviderInfo(ApiProviderInfoBase):
    """Credentials for calling the OpenAI API directly."""

    # Pin the discriminator for this variant; overriding the base
    # annotation narrows the Literal, hence the type-ignore.
    name: Literal["openai"] = "openai"  # type: ignore
    api_key: str
    # Optional endpoint override (e.g. proxies or OpenAI-compatible APIs).
    base_url: Optional[str] = None
|
||||
|
||||
|
||||
class AzureProviderInfo(ApiProviderInfoBase):
    """Credentials and routing information for an Azure OpenAI deployment."""

    # Discriminator for this variant; the override narrows the base
    # annotation, hence the type-ignore.
    name: Literal["azure"] = "azure"  # type: ignore
    # Azure REST API version string, e.g. "2023-12-01-preview".
    api_version: str
    api_key: str
    # Name of the model deployment within the Azure resource.
    deployment_name: str
    # Azure resource name; used to build the endpoint URL.
    resource_name: str
|
||||
|
||||
|
||||
# Union of all supported provider configs, discriminated by the `name` field.
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]
|
||||
@ -2,54 +2,29 @@ from typing import Awaitable, Callable, List
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
|
||||
|
||||
from api_types import ApiProviderInfo
|
||||
|
||||
# Model identifier used for all completions in this module.
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
|
||||
|
||||
|
||||
async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_provider_info: ApiProviderInfo,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion and return the full response text.

    Builds an OpenAI or Azure OpenAI async client from *api_provider_info*
    (replacing the former separate ``stream_azure_openai_response``),
    streams the completion, forwards each content chunk to *callback*,
    and returns the concatenated response.

    Args:
        messages: Chat messages to send to the model.
        api_provider_info: Discriminated provider config; its ``name``
            field selects which client is constructed.
        callback: Awaited once per streamed content chunk.

    Returns:
        The full concatenated completion text.

    Raises:
        Exception: If ``api_provider_info.name`` is not a known provider.
    """
    # Dispatch on the discriminator to build the right client.
    if api_provider_info.name == "openai":
        client = AsyncOpenAI(
            api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
        )
    elif api_provider_info.name == "azure":
        client = AsyncAzureOpenAI(
            api_version=api_provider_info.api_version,
            api_key=api_provider_info.api_key,
            azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
            azure_deployment=api_provider_info.deployment_name,
        )
    else:
        raise Exception("Invalid api_provider_info")

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    stream = await client.chat.completions.create(**params)  # type: ignore
    full_response = ""
    async for chunk in stream:  # type: ignore
        assert isinstance(chunk, ChatCompletionChunk)
        content = chunk.choices[0].delta.content or ""
        full_response += content
        await callback(content)

    await client.close()

    return full_response
||||
@ -2,8 +2,9 @@ import os
|
||||
import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
from api_types import AzureProviderInfo, OpenAiProviderInfo
|
||||
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
|
||||
from llm import stream_openai_response, stream_azure_openai_response
|
||||
from llm import stream_openai_response
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from mock_llm import mock_completion
|
||||
from typing import Dict, List
|
||||
@ -210,22 +211,44 @@ async def stream_code(websocket: WebSocket):
|
||||
completion = await mock_completion(process_chunk)
|
||||
else:
|
||||
try:
|
||||
api_provider_info = None
|
||||
if openai_api_key is not None:
|
||||
api_provider_info = {
|
||||
"name": "openai",
|
||||
"api_key": openai_api_key,
|
||||
"base_url": openai_base_url,
|
||||
}
|
||||
|
||||
api_provider_info = OpenAiProviderInfo(
|
||||
api_key=openai_api_key, base_url=openai_base_url
|
||||
)
|
||||
|
||||
if azure_openai_api_key is not None:
|
||||
if (
|
||||
not azure_openai_api_version
|
||||
or not azure_openai_resource_name
|
||||
or not azure_openai_deployment_name
|
||||
):
|
||||
raise Exception(
|
||||
"Missing Azure OpenAI API version, resource name, or deployment name"
|
||||
)
|
||||
|
||||
api_provider_info = AzureProviderInfo(
|
||||
api_key=azure_openai_api_key,
|
||||
api_version=azure_openai_api_version,
|
||||
deployment_name=azure_openai_deployment_name,
|
||||
resource_name=azure_openai_resource_name,
|
||||
)
|
||||
|
||||
if api_provider_info is None:
|
||||
raise Exception("Invalid api_provider_info")
|
||||
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
if azure_openai_api_key is not None:
|
||||
completion = await stream_azure_openai_response(
|
||||
prompt_messages,
|
||||
azure_openai_api_key=azure_openai_api_key,
|
||||
azure_openai_api_version=azure_openai_api_version,
|
||||
azure_openai_resource_name=azure_openai_resource_name,
|
||||
azure_openai_deployment_name=azure_openai_deployment_name,
|
||||
api_provider_info=api_provider_info,
|
||||
callback=lambda x: process_chunk(x),
|
||||
)
|
||||
|
||||
except openai.AuthenticationError as e:
|
||||
print("[GENERATE_CODE] Authentication failed", e)
|
||||
error_message = (
|
||||
|
||||
Loading…
Reference in New Issue
Block a user