diff --git a/backend/api_types.py b/backend/api_types.py
new file mode 100644
index 0000000..542ec6a
--- /dev/null
+++ b/backend/api_types.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+from typing import Union, Literal, Optional
+
+
+class ApiProviderInfoBase(BaseModel):
+    name: Literal["openai", "azure"]
+
+
+class OpenAiProviderInfo(ApiProviderInfoBase):
+    name: Literal["openai"] = "openai"  # type: ignore
+    api_key: str
+    base_url: Optional[str] = None
+
+
+class AzureProviderInfo(ApiProviderInfoBase):
+    name: Literal["azure"] = "azure"  # type: ignore
+    api_version: str
+    api_key: str
+    deployment_name: str
+    resource_name: str
+
+
+ApiProviderInfo = Union[OpenAiProviderInfo, AzureProviderInfo]
diff --git a/backend/llm.py b/backend/llm.py
index 1b04477..eca0965 100644
--- a/backend/llm.py
+++ b/backend/llm.py
@@ -2,54 +2,29 @@ from typing import Awaitable, Callable, List
 from openai import AsyncOpenAI, AsyncAzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
+from api_types import ApiProviderInfo
+
 MODEL_GPT_4_VISION = "gpt-4-vision-preview"
 
 
 async def stream_openai_response(
     messages: List[ChatCompletionMessageParam],
-    api_key: str,
-    base_url: str | None,
+    api_provider_info: ApiProviderInfo,
     callback: Callable[[str], Awaitable[None]],
 ) -> str:
-    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
-
-    model = MODEL_GPT_4_VISION
-
-    # Base parameters
-    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}
-
-    # Add 'max_tokens' only if the model is a GPT4 vision model
-    if model == MODEL_GPT_4_VISION:
-        params["max_tokens"] = 4096
-        params["temperature"] = 0
-
-    stream = await client.chat.completions.create(**params)  # type: ignore
-    full_response = ""
-    async for chunk in stream:  # type: ignore
-        assert isinstance(chunk, ChatCompletionChunk)
-        content = chunk.choices[0].delta.content or ""
-        full_response += content
-        await callback(content)
-
-    await client.close()
-
-    return full_response
-
-
-async def stream_azure_openai_response(
-    messages: List[ChatCompletionMessageParam],
-    azure_openai_api_key: str | None,
-    azure_openai_api_version: str | None,
-    azure_openai_resource_name: str | None,
-    azure_openai_deployment_name: str | None,
-    callback: Callable[[str], Awaitable[None]],
-) -> str:
-    client = AsyncAzureOpenAI(
-        api_version=azure_openai_api_version,
-        api_key=azure_openai_api_key,
-        azure_endpoint=f"https://{azure_openai_resource_name}.openai.azure.com/",
-        azure_deployment=azure_openai_deployment_name,
-    )
+    if api_provider_info.name == "openai":
+        client = AsyncOpenAI(
+            api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
+        )
+    elif api_provider_info.name == "azure":
+        client = AsyncAzureOpenAI(
+            api_version=api_provider_info.api_version,
+            api_key=api_provider_info.api_key,
+            azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
+            azure_deployment=api_provider_info.deployment_name,
+        )
+    else:
+        raise Exception("Invalid api_provider_info")
 
     model = MODEL_GPT_4_VISION
 
     # Base parameters
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index 401302f..37be602 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -2,8 +2,9 @@ import os
 import traceback
 from fastapi import APIRouter, WebSocket
 import openai
+from api_types import AzureProviderInfo, OpenAiProviderInfo
 from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE
-from llm import stream_openai_response, stream_azure_openai_response
+from llm import stream_openai_response
 from openai.types.chat import ChatCompletionMessageParam
 from mock_llm import mock_completion
 from typing import Dict, List
@@ -210,22 +211,38 @@ async def stream_code(websocket: WebSocket):
         completion = await mock_completion(process_chunk)
     else:
         try:
+            api_provider_info = None
             if openai_api_key is not None:
-                completion = await stream_openai_response(
-                    prompt_messages,
-                    api_key=openai_api_key,
-                    base_url=openai_base_url,
-                    callback=lambda x: process_chunk(x),
+                api_provider_info = OpenAiProviderInfo(
+                    api_key=openai_api_key, base_url=openai_base_url
                 )
+
             if azure_openai_api_key is not None:
-                completion = await stream_azure_openai_response(
-                    prompt_messages,
-                    azure_openai_api_key=azure_openai_api_key,
-                    azure_openai_api_version=azure_openai_api_version,
-                    azure_openai_resource_name=azure_openai_resource_name,
-                    azure_openai_deployment_name=azure_openai_deployment_name,
-                    callback=lambda x: process_chunk(x),
+                if (
+                    not azure_openai_api_version
+                    or not azure_openai_resource_name
+                    or not azure_openai_deployment_name
+                ):
+                    raise Exception(
+                        "Missing Azure OpenAI API version, resource name, or deployment name"
+                    )
+
+                api_provider_info = AzureProviderInfo(
+                    api_key=azure_openai_api_key,
+                    api_version=azure_openai_api_version,
+                    deployment_name=azure_openai_deployment_name,
+                    resource_name=azure_openai_resource_name,
                 )
+
+            if api_provider_info is None:
+                raise Exception("No OpenAI or Azure OpenAI API key provided")
+
+            completion = await stream_openai_response(
+                prompt_messages,
+                api_provider_info=api_provider_info,
+                callback=lambda x: process_chunk(x),
+            )
+
         except openai.AuthenticationError as e:
             print("[GENERATE_CODE] Authentication failed", e)
             error_message = (
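
After this change there is a single streaming entry point: callers describe the provider with a tagged-union value instead of choosing between two near-duplicate functions. A minimal usage sketch of the new signature (the API key, message, and print_chunk callback are illustrative, not part of the diff):

    import asyncio

    from api_types import OpenAiProviderInfo
    from llm import stream_openai_response


    async def main() -> None:
        # Illustrative credentials; the real route builds these from the
        # parameters the client sends over the WebSocket.
        provider = OpenAiProviderInfo(api_key="sk-illustrative", base_url=None)

        async def print_chunk(chunk: str) -> None:
            print(chunk, end="", flush=True)  # surface tokens as they stream in

        completion = await stream_openai_response(
            [{"role": "user", "content": "Say hello."}],
            api_provider_info=provider,
            callback=print_chunk,
        )
        print("\ntotal characters:", len(completion))


    asyncio.run(main())

Passing an AzureProviderInfo to the same call switches the client to AsyncAzureOpenAI; no call site needs to know which branch runs.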
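
The name field doubles as a discriminator: it is a Literal on each subclass, so a type checker narrows api_provider_info to OpenAiProviderInfo or AzureProviderInfo inside the if/elif in stream_openai_response (the # type: ignore comments are needed because overriding name with a narrower Literal is technically an incompatible override). The same tag would also let Pydantic validate untyped input into the right subclass. A sketch assuming Pydantic v2's TypeAdapter and Field(discriminator=...); the diff itself never parses the union from a dict, so this is only to illustrate why the tag exists:

    from typing import Annotated, Union

    from pydantic import Field, TypeAdapter

    from api_types import AzureProviderInfo, OpenAiProviderInfo

    # Hypothetical discriminated union: Pydantic selects the subclass by "name".
    Provider = Annotated[
        Union[OpenAiProviderInfo, AzureProviderInfo],
        Field(discriminator="name"),
    ]

    info = TypeAdapter(Provider).validate_python(
        {"name": "openai", "api_key": "sk-illustrative"}
    )
    assert isinstance(info, OpenAiProviderInfo)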