improve evals code

This commit is contained in:
Abi Raja 2024-07-19 07:55:44 -04:00
parent 9f732c4f5d
commit 9d11866143
3 changed files with 32 additions and 18 deletions

View File

@ -3,6 +3,7 @@
# TODO: Should only be set to true when value is 'True', not any arbitrary truthy value
import os
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
# Debugging-related

View File

@ -1,38 +1,40 @@
import os
from config import ANTHROPIC_API_KEY
from config import ANTHROPIC_API_KEY, OPENAI_API_KEY
from llm import Llm, stream_claude_response, stream_openai_response
from prompts import assemble_prompt
from prompts.types import Stack
from openai.types.chat import ChatCompletionMessageParam
async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
async def generate_code_for_image(image_url: str, stack: Stack, model: Llm) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
anthropic_api_key = ANTHROPIC_API_KEY
openai_base_url = None
return await generate_code_core(prompt_messages, model)
async def process_chunk(content: str):
async def generate_code_core(
prompt_messages: list[ChatCompletionMessageParam], model: Llm
) -> str:
async def process_chunk(_: str):
pass
if model == Llm.CLAUDE_3_SONNET or model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
if not anthropic_api_key:
if not ANTHROPIC_API_KEY:
raise Exception("Anthropic API key not found")
completion = await stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
model=model,
)
else:
if not openai_api_key:
if not OPENAI_API_KEY:
raise Exception("OpenAI API key not found")
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
base_url=openai_base_url,
api_key=OPENAI_API_KEY,
base_url=None,
callback=lambda x: process_chunk(x),
model=model,
)

View File

@ -10,12 +10,12 @@ from typing import Any, Coroutine
import asyncio
from evals.config import EVALS_DIR
from evals.core import generate_code_core
from evals.core import generate_code_for_image
from evals.utils import image_to_data_url
STACK = "html_tailwind"
MODEL = Llm.CLAUDE_3_5_SONNET_2024_06_20
N = 1 # Number of outputs to generate
# MODEL = Llm.CLAUDE_3_5_SONNET_2024_06_20
N = 2 # Number of outputs to generate
async def main():
@ -29,10 +29,21 @@ async def main():
for filename in evals:
filepath = os.path.join(INPUT_DIR, filename)
data_url = await image_to_data_url(filepath)
for _ in range(N): # Generate N tasks for each input
task = generate_code_core(image_url=data_url, stack=STACK, model=MODEL)
for n in range(N): # Generate N tasks for each input
if n == 0:
task = generate_code_for_image(
image_url=data_url,
stack=STACK,
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
)
else:
task = generate_code_for_image(
image_url=data_url, stack=STACK, model=Llm.GPT_4O_2024_05_13
)
tasks.append(task)
print(f"Generating {len(tasks)} codes")
results = await asyncio.gather(*tasks)
os.makedirs(OUTPUT_DIR, exist_ok=True)