from typing import Awaitable, Callable, List, cast

from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk

MODEL_GPT_4_VISION = "gpt-4-vision-preview"
MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"


# Keep in sync with frontend (lib/models.ts)
CODE_GENERATION_MODELS = [
    "gpt_4_vision",
    "claude_3_sonnet",
]


async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    base_url: str | None,
    callback: Callable[[str], Awaitable[None]],
) -> str:
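    """Stream a chat completion from the OpenAI API.

    Forwards each text delta to `callback` as it arrives and returns the
    full concatenated response once the stream ends.
    """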
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # For GPT-4 Vision, set max_tokens explicitly (the preview model otherwise
    # truncates long outputs) and pin temperature to 0 for deterministic output
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

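    # With stream=True the API yields ChatCompletionChunk deltas; forward each
    # piece to the caller's callback while accumulating the full response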
    stream = await client.chat.completions.create(**params)  # type: ignore
    full_response = ""
    async for chunk in stream:  # type: ignore
        assert isinstance(chunk, ChatCompletionChunk)
        content = chunk.choices[0].delta.content or ""
        full_response += content
        await callback(content)

    await client.close()

    return full_response


async def stream_claude_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
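    """Stream a completion from Anthropic's Claude given OpenAI-style messages.

    The system prompt and any image attachments are translated to Claude's
    format, each text delta is forwarded to `callback`, and the final
    message text is returned.
    """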
    client = AsyncAnthropic(api_key=api_key)

    # Base parameters
    model = MODEL_CLAUDE_SONNET
    max_tokens = 4096
    temperature = 0.0

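    # Anthropic's Messages API differs from OpenAI's in two ways that matter
    # here: the system prompt is passed as a separate `system` parameter
    # rather than as the first message, and images are sent as base64
    # `source` blocks instead of `image_url` entries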
    # Translate OpenAI messages to Claude messages
    system_prompt = cast(str, messages[0]["content"])
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                content["type"] = "image"

                # Extract base64 data and media type from data URL
                # Example base64 data URL: data:image/png;base64,iVBOR...
                image_data_url = cast(str, content["image_url"]["url"])
                media_type = image_data_url.split(";")[0].split(":")[1]
                base64_data = image_data_url.split(",")[1]

                # Remove OpenAI parameter
                del content["image_url"]

                content["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                }

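    # client.messages.stream(...) returns an async context manager;
    # `text_stream` yields incremental text deltas and `get_final_message()`
    # returns the fully accumulated Message once the stream is done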
    # Stream Claude response
    async with client.messages.stream(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system=system_prompt,
        messages=claude_messages,  # type: ignore
    ) as stream:
        async for text in stream.text_stream:
            await callback(text)

        # Capture the final message while the stream is still open
        response = await stream.get_final_message()

    await client.close()

    return response.content[0].text
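

# A minimal usage sketch, not part of the original module: stream a completion
# to stdout. Assumes OPENAI_API_KEY is set in the environment; swap in
# stream_claude_response (with ANTHROPIC_API_KEY and a leading system message)
# to exercise the Claude path.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        async def print_chunk(text: str) -> None:
            print(text, end="", flush=True)

        await stream_openai_response(
            [{"role": "user", "content": "Say hello"}],
            api_key=os.environ["OPENAI_API_KEY"],
            base_url=None,
            callback=print_chunk,
        )

    asyncio.run(_demo())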