screenshot-to-code/backend/llm.py
2023-12-16 22:02:33 -05:00

50 lines
1.7 KiB
Python

from typing import Awaitable, Callable, List
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from api_types import ApiProviderInfo
MODEL_GPT_4_VISION = "gpt-4-vision-preview"


async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_provider_info: ApiProviderInfo,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion from OpenAI or Azure OpenAI.

    Each content delta is forwarded to ``callback`` as it arrives, and the
    concatenation of all deltas is returned once the stream is exhausted.

    Args:
        messages: Chat messages to send to the model.
        api_provider_info: Provider selection plus credentials/endpoint config;
            ``name`` must be ``"openai"`` or ``"azure"``.
        callback: Awaited once per streamed content chunk.

    Returns:
        The full response text assembled from the streamed chunks.

    Raises:
        ValueError: If ``api_provider_info.name`` is not a known provider.
    """
    if api_provider_info.name == "openai":
        client = AsyncOpenAI(
            api_key=api_provider_info.api_key, base_url=api_provider_info.base_url
        )
    elif api_provider_info.name == "azure":
        client = AsyncAzureOpenAI(
            api_version=api_provider_info.api_version,
            api_key=api_provider_info.api_key,
            azure_endpoint=f"https://{api_provider_info.resource_name}.openai.azure.com/",
            azure_deployment=api_provider_info.deployment_name,
        )
    else:
        # ValueError (a subclass of Exception) is the idiomatic error for a
        # bad argument value; existing callers catching Exception still work.
        raise ValueError("Invalid api_provider_info")

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    full_response = ""
    try:
        stream = await client.chat.completions.create(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            assert isinstance(chunk, ChatCompletionChunk)
            # Some providers (notably Azure content-filter annotation chunks)
            # can send a chunk with an empty choices list; skip those instead
            # of raising IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content or ""
            full_response += content
            await callback(content)
    finally:
        # Always close the client — even when streaming or the callback
        # raises — so the underlying HTTP connection is not leaked.
        await client.close()
    return full_response