clean up some code

This commit is contained in:
Abi Raja 2023-12-16 22:02:33 -05:00
parent ce637dbc20
commit 0e0b68abe0
3 changed files with 75 additions and 54 deletions

23
backend/api_types.py Normal file
View File

@ -0,0 +1,23 @@
from pydantic import BaseModel
from typing import Union, Literal, Optional
class ApiProviderInfoBase(BaseModel):
    """Common base for LLM API provider configuration models.

    Each concrete subclass narrows ``name`` to a single literal value so the
    provider kind can be dispatched on at runtime (see the ``ApiProviderInfo``
    union below).
    """

    # Discriminator field: which backend this configuration targets.
    name: Literal["openai", "azure"]
class OpenAiProviderInfo(ApiProviderInfoBase):
    """Configuration for calling the OpenAI API directly."""

    # Narrow the base discriminator to this provider; the ignore silences the
    # type checker's complaint about overriding with a narrower Literal.
    name: Literal["openai"] = "openai"  # type: ignore
    # Secret API key for authenticating with OpenAI.
    api_key: str
    # Optional override of the API endpoint (e.g. for a proxy); None means
    # the client library's default base URL is used.
    base_url: Optional[str] = None
class AzureProviderInfo(ApiProviderInfoBase):
    """Configuration for calling an Azure OpenAI deployment.

    All four fields are required: Azure addresses a model via the resource
    name (host) plus deployment name, and every request must pin an API
    version.
    """

    # Narrow the base discriminator to this provider; the ignore silences the
    # type checker's complaint about overriding with a narrower Literal.
    name: Literal["azure"] = "azure"  # type: ignore
    # Azure OpenAI REST API version string (e.g. "2023-12-01-preview").
    api_version: str
    # Secret key for the Azure OpenAI resource.
    api_key: str
    # Name of the model deployment within the resource.
    deployment_name: str
    # Azure resource name; used to build the endpoint host
    # (https://{resource_name}.openai.azure.com/ in the consuming code).
    resource_name: str
ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]

View File

@ -2,54 +2,29 @@ from typing import Awaitable, Callable, List
from openai import AsyncOpenAI, AsyncAzureOpenAI from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from api_types import ApiProviderInfo
MODEL_GPT_4_VISION = "gpt-4-vision-preview" MODEL_GPT_4_VISION = "gpt-4-vision-preview"
async def stream_openai_response( async def stream_openai_response(
messages: List[ChatCompletionMessageParam], messages: List[ChatCompletionMessageParam],
api_key: str, api_provider_info: ApiProviderInfo,
base_url: str | None,
callback: Callable[[str], Awaitable[None]], callback: Callable[[str], Awaitable[None]],
) -> str: ) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url) if api_provider_info.name == "openai":
client = AsyncOpenAI(
model = MODEL_GPT_4_VISION api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
)
# Base parameters elif api_provider_info.name == "azure":
params = {"model": model, "messages": messages, "stream": True, "timeout": 600} client = AsyncAzureOpenAI(
api_version=api_provider_info.api_version,
# Add 'max_tokens' only if the model is a GPT4 vision model api_key=api_provider_info.api_key,
if model == MODEL_GPT_4_VISION: azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
params["max_tokens"] = 4096 azure_deployment=api_provider_info.deployment_name,
params["temperature"] = 0 )
else:
stream = await client.chat.completions.create(**params) # type: ignore raise Exception("Invalid api_provider_info")
full_response = ""
async for chunk in stream: # type: ignore
assert isinstance(chunk, ChatCompletionChunk)
content = chunk.choices[0].delta.content or ""
full_response += content
await callback(content)
await client.close()
return full_response
async def stream_azure_openai_response(
messages: List[ChatCompletionMessageParam],
azure_openai_api_key: str | None,
azure_openai_api_version: str | None,
azure_openai_resource_name: str | None,
azure_openai_deployment_name: str | None,
callback: Callable[[str], Awaitable[None]],
) -> str:
client = AsyncAzureOpenAI(
api_version=azure_openai_api_version,
api_key=azure_openai_api_key,
azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
azure_deployment=azure_openai_deployment_name,
)
model = MODEL_GPT_4_VISION model = MODEL_GPT_4_VISION

View File

@ -2,8 +2,9 @@ import os
import traceback import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from api_types import AzureProviderInfo, OpenAiProviderInfo
from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
from llm import stream_openai_response, stream_azure_openai_response from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion from mock_llm import mock_completion
from typing import Dict, List from typing import Dict, List
@ -210,22 +211,44 @@ async def stream_code(websocket: WebSocket):
completion = await mock_completion(process_chunk) completion = await mock_completion(process_chunk)
else: else:
try: try:
api_provider_info = None
if openai_api_key is not None: if openai_api_key is not None:
completion = await stream_openai_response( api_provider_info = {
prompt_messages, "name": "openai",
api_key=openai_api_key, "api_key": openai_api_key,
base_url=openai_base_url, "base_url": openai_base_url,
callback=lambda x: process_chunk(x), }
api_provider_info = OpenAiProviderInfo(
api_key=openai_api_key, base_url=openai_base_url
) )
if azure_openai_api_key is not None: if azure_openai_api_key is not None:
completion = await stream_azure_openai_response( if (
prompt_messages, not azure_openai_api_version
azure_openai_api_key=azure_openai_api_key, or not azure_openai_resource_name
azure_openai_api_version=azure_openai_api_version, or not azure_openai_deployment_name
azure_openai_resource_name=azure_openai_resource_name, ):
azure_openai_deployment_name=azure_openai_deployment_name, raise Exception(
callback=lambda x: process_chunk(x), "Missing Azure OpenAI API version, resource name, or deployment name"
)
api_provider_info = AzureProviderInfo(
api_key=azure_openai_api_key,
api_version=azure_openai_api_version,
deployment_name=azure_openai_deployment_name,
resource_name=azure_openai_resource_name,
) )
if api_provider_info is None:
raise Exception("Invalid api_provider_info")
completion = await stream_openai_response(
prompt_messages,
api_provider_info=api_provider_info,
callback=lambda x: process_chunk(x),
)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e) print("[GENERATE_CODE] Authentication failed", e)
error_message = ( error_message = (