"""LLM streaming helpers: stream code-generation responses from OpenAI and Anthropic."""
from typing import Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
# OpenAI model identifier for GPT-4 with vision (image) input support.
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
# Anthropic Claude 3 model identifiers (Sonnet is the default used below).
MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"
# Model ids exposed to the client. Keep in sync with frontend (lib/models.ts)
CODE_GENERATION_MODELS = [
    "gpt_4_vision",
    "claude_3_sonnet",
]
async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    base_url: str | None,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion from OpenAI, invoking *callback* on each text delta.

    Args:
        messages: OpenAI-format chat messages (system prompt first).
        api_key: OpenAI API key.
        base_url: Optional API base URL override (e.g. a proxy), or None for default.
        callback: Awaited with each streamed text fragment as it arrives.

    Returns:
        The full concatenated response text.
    """
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model;
    # temperature 0 keeps code generation deterministic.
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    full_response = ""
    try:
        stream = await client.chat.completions.create(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            assert isinstance(chunk, ChatCompletionChunk)
            # Delta content can be None on role/stop chunks; treat as empty.
            content = chunk.choices[0].delta.content or ""
            full_response += content
            await callback(content)
    finally:
        # Close the underlying HTTP client even if streaming fails midway
        # (the original only closed it on the success path, leaking on error).
        await client.close()

    return full_response
def _openai_to_claude_messages(
    messages: List[ChatCompletionMessageParam],
) -> tuple[str, list[dict]]:
    """Split OpenAI-format messages into (system_prompt, Claude-format messages).

    The first message is taken as the system prompt (Claude receives it as a
    separate ``system`` parameter). Image parts are converted in place from
    OpenAI's ``image_url`` data-URL form to Claude's base64 ``source`` form.
    """
    system_prompt = cast(str, messages[0]["content"])
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                content["type"] = "image"
                content["source"] = {
                    "type": "base64",
                    "media_type": "image/png",  # TODO: Automatically detect media type
                    # Strip the "data:image/png;base64," prefix from the data URL.
                    "data": content["image_url"]["url"].split(",")[1],  # type: ignore
                }
                del content["image_url"]
    return system_prompt, claude_messages


async def stream_claude_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a completion from Anthropic Claude, invoking *callback* per text delta.

    Args:
        messages: OpenAI-format chat messages (system prompt first); translated
            to Claude's message format internally.
        api_key: Anthropic API key.
        callback: Awaited with each streamed text fragment as it arrives.

    Returns:
        The text of the final assembled message.
    """
    client = AsyncAnthropic(api_key=api_key)

    # Base parameters
    model = MODEL_CLAUDE_SONNET
    max_tokens = 4096
    temperature = 0.0

    # Translate OpenAI messages to Claude messages
    system_prompt, claude_messages = _openai_to_claude_messages(messages)

    try:
        async with client.messages.stream(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=claude_messages,  # type: ignore
        ) as stream:
            async for text in stream.text_stream:
                await callback(text)
            response = await stream.get_final_message()
    finally:
        # The original never closed the client, leaking the HTTP connection pool.
        await client.close()

    return response.content[0].text