from typing import Awaitable, Callable, List, cast

from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk

MODEL_GPT_4_VISION = "gpt-4-vision-preview"
MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"

# Keep in sync with frontend (lib/models.ts)
CODE_GENERATION_MODELS = [
    "gpt_4_vision",
    "claude_3_sonnet",
]


async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    base_url: str | None,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion from OpenAI, invoking ``callback`` per text chunk.

    Args:
        messages: OpenAI-format chat messages (first may be a system prompt).
        api_key: OpenAI API key.
        base_url: Optional alternate API endpoint (e.g. a proxy); None for default.
        callback: Awaited with each incremental content string as it arrives.

    Returns:
        The full concatenated response text.
    """
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model
    # (it defaults to a very small completion budget otherwise).
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    full_response = ""
    try:
        stream = await client.chat.completions.create(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            assert isinstance(chunk, ChatCompletionChunk)
            content = chunk.choices[0].delta.content or ""
            full_response += content
            await callback(content)
    finally:
        # Release the underlying HTTP connection even if streaming errors out
        # (the original only closed on the happy path).
        await client.close()

    return full_response


def _split_data_url(data_url: str) -> tuple[str, str]:
    """Split a ``data:<media_type>;base64,<payload>`` URL into (media_type, payload).

    Resolves the previous TODO of hard-coding ``image/png``: the media type is
    read from the URL header itself.
    """
    header, base64_data = data_url.split(",", 1)
    # header looks like "data:image/png;base64"
    media_type = header.split(";")[0].split(":", 1)[1]
    return media_type, base64_data


async def stream_claude_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a completion from Anthropic Claude, invoking ``callback`` per text chunk.

    Accepts OpenAI-format messages and translates them to Claude's format:
    the first message becomes the separate ``system`` argument, and
    ``image_url`` content parts become base64 ``image`` source blocks.

    Args:
        messages: OpenAI-format chat messages; messages[0] must be the system prompt.
        api_key: Anthropic API key.
        callback: Awaited with each incremental text string as it arrives.

    Returns:
        The full response text.
    """
    client = AsyncAnthropic(api_key=api_key)

    # Base parameters
    model = MODEL_CLAUDE_SONNET
    max_tokens = 4096
    temperature = 0.0

    # Translate OpenAI messages to Claude messages: Claude takes the system
    # prompt as a dedicated argument rather than a message entry.
    system_prompt = cast(str, messages[0]["content"])
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                # Claude expects inline base64 image blocks instead of URLs.
                media_type, base64_data = _split_data_url(
                    content["image_url"]["url"]  # type: ignore
                )
                content["type"] = "image"
                content["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                }
                del content["image_url"]

    try:
        async with client.messages.stream(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=claude_messages,  # type: ignore
        ) as stream:
            async for text in stream.text_stream:
                await callback(text)
            response = await stream.get_final_message()
    finally:
        # The original leaked the client; close it like the OpenAI path does.
        await client.close()

    return response.content[0].text