screenshot-to-code/screenshottocode/llm.py
2023-12-26 19:03:45 +00:00

104 lines
3.1 KiB
Python

import io
from typing import Awaitable, Callable, List, Any
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
import google.generativeai as genai
from PIL import Image
import base64
import os
# Name of the OpenAI model requested by stream_openai_response.
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
# Fully-qualified Gemini vision model name. NOTE(review): not referenced by the
# functions visible in this file; presumably used by callers - confirm.
MODEL_GEMINI_PRO_VISION = "models/gemini-pro-vision"
async def format_response(response: str) -> str:
    """Strip a surrounding Markdown code fence from a model response.

    The response is trimmed of leading/trailing whitespace. If it then starts
    with a fence (three backticks, optionally followed by a language tag), the
    opening fence line is removed, and the closing fence line is removed only
    when it is actually present. A response that was cut off before its
    closing fence therefore keeps all of its content lines.

    Args:
        response: Raw text returned by the model.

    Returns:
        The response body without the enclosing code fence, or the trimmed
        response unchanged when it is not fenced.
    """
    response = response.strip()
    if not response.startswith('```'):
        return response

    # Drop the opening fence line (it may carry a language tag such as "html").
    lines = response.split('\n')[1:]
    # Only drop the last line when it really is the closing fence; slicing a
    # fixed number of lines off the end would discard real content.
    if lines and lines[-1].strip() == '```':
        lines.pop()
    return '\n'.join(lines)
async def format_messages_for_gemnini(messages: List[Any]):
    """Convert chat-style messages into the content list passed to Gemini.

    Reads the text of ``messages[0]['content']``, and from
    ``messages[1]['content']`` the image URL of the first part and the text of
    the second part. The URL is split on "," and the portion after the first
    comma is base64-decoded and opened as a PIL image.

    Returns:
        A list ``[system text, image, user text]``.
    """
    system_text = messages[0]['content']
    user_parts = messages[1]['content']
    data_url = user_parts[0]['image_url']['url']
    user_text = user_parts[1]['text']

    # Everything after the first comma is the base64 payload of the URL.
    encoded_payload = data_url.split(",")[1]
    picture = Image.open(io.BytesIO(base64.b64decode(encoded_payload)))

    return [system_text, picture, user_text]
async def stream_openai_response(
    messages: List[ChatCompletionMessageParam],
    api_key: str,
    base_url: str | None,
    callback: Callable[[str], Awaitable[None]],
) -> str:
    """Stream a chat completion from OpenAI and return the full text.

    Args:
        messages: Chat messages forwarded as the ``messages`` request field.
        api_key: API key used to construct the client.
        base_url: Optional base URL passed to the client (``None`` lets the
            client use its default).
        callback: Awaited once per received chunk with that chunk's text
            (an empty string when the chunk carries no text).

    Returns:
        The concatenation of all chunk texts.
    """
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    model = MODEL_GPT_4_VISION

    # Base parameters
    params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

    # Add 'max_tokens' only if the model is a GPT4 vision model
    if model == MODEL_GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0

    pieces: List[str] = []
    try:
        stream = await client.chat.completions.create(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            assert isinstance(chunk, ChatCompletionChunk)
            # A chunk may arrive with an empty `choices` list (seen with some
            # OpenAI-compatible endpoints); treat it as carrying no text
            # instead of failing with IndexError.
            content = (chunk.choices[0].delta.content or "") if chunk.choices else ""
            pieces.append(content)
            await callback(content)
    finally:
        # Release the client even when the request, the stream, or the
        # callback raises.
        await client.close()

    return "".join(pieces)
async def stream_gemini_response(
    messages: List[Any] = [],
    api_key: str = os.getenv('GOOGLE_API_KEY', ''),
    callback: Callable[[str], Awaitable[None]] | None = None,
    max_continuations: int = 5,
) -> str:
    """Generate code with Gemini, asking for continuations until it is complete.

    The first request sends the formatted messages. While the accumulated
    output does not contain ``</html>``, a follow-up request is sent that
    asks the model to continue from the text produced so far.

    Args:
        messages: Chat-style messages; converted with
            ``format_messages_for_gemnini``. The default list is never
            mutated here.
        api_key: Gemini API key. The default is read from ``GOOGLE_API_KEY``
            once, when this function is defined.
        callback: Optional coroutine function awaited with each generated
            text segment (one call per request to the model).
        max_continuations: Upper bound on follow-up requests, so that a
            model that never emits ``</html>`` cannot cause an endless
            series of API calls. When the bound is hit the partial output
            is returned.

    Returns:
        The accumulated, whitespace-trimmed output.
    """
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-pro-vision')
    general_config = {
        "max_output_tokens": 2048,
        "temperature": 0.4,
        "top_p": 1,
        "top_k": 32,
    }

    base_messages = await format_messages_for_gemnini(messages)

    async def _generate(contents: List[Any]) -> str:
        # One request/response round trip; returns the unfenced text.
        # NOTE(review): generate_content/resolve are synchronous calls and
        # are kept as in the original; they run on the event loop thread.
        response = model.generate_content(
            contents,
            stream=True,
            generation_config=general_config,  # type: ignore
        )
        response.resolve()
        text = await format_response(response.text)
        if callback is not None:
            await callback(text)
        return text

    result = await _generate(base_messages)

    continue_prompt = "Generate the rest of the code below"
    continuations = 0
    while "</html>" not in result and continuations < max_continuations:
        continuations += 1
        # Rebuild the prompt from the base messages each time so earlier
        # partial outputs are not sent to the model more than once.
        extra = await _generate([*base_messages, continue_prompt, result])
        if not extra:
            # No progress was made; further requests would not help.
            break
        result += extra

    return result.strip()