improve evals code
This commit is contained in:
parent
9f732c4f5d
commit
9d11866143
@ -3,6 +3,7 @@
|
||||
# TODO: Should only be set to true when value is 'True', not any arbitrary truthy value
|
||||
import os
|
||||
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
|
||||
|
||||
# Debugging-related
|
||||
|
||||
@ -1,38 +1,40 @@
|
||||
import os
|
||||
from config import ANTHROPIC_API_KEY
|
||||
|
||||
from config import ANTHROPIC_API_KEY, OPENAI_API_KEY
|
||||
from llm import Llm, stream_claude_response, stream_openai_response
|
||||
from prompts import assemble_prompt
|
||||
from prompts.types import Stack
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
|
||||
async def generate_code_for_image(image_url: str, stack: Stack, model: Llm) -> str:
|
||||
prompt_messages = assemble_prompt(image_url, stack)
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
anthropic_api_key = ANTHROPIC_API_KEY
|
||||
openai_base_url = None
|
||||
return await generate_code_core(prompt_messages, model)
|
||||
|
||||
async def process_chunk(content: str):
|
||||
|
||||
async def generate_code_core(
|
||||
prompt_messages: list[ChatCompletionMessageParam], model: Llm
|
||||
) -> str:
|
||||
|
||||
async def process_chunk(_: str):
|
||||
pass
|
||||
|
||||
if model == Llm.CLAUDE_3_SONNET or model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
|
||||
if not anthropic_api_key:
|
||||
if not ANTHROPIC_API_KEY:
|
||||
raise Exception("Anthropic API key not found")
|
||||
|
||||
completion = await stream_claude_response(
|
||||
prompt_messages,
|
||||
api_key=anthropic_api_key,
|
||||
api_key=ANTHROPIC_API_KEY,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
if not openai_api_key:
|
||||
if not OPENAI_API_KEY:
|
||||
raise Exception("OpenAI API key not found")
|
||||
|
||||
completion = await stream_openai_response(
|
||||
prompt_messages,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_base_url,
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=None,
|
||||
callback=lambda x: process_chunk(x),
|
||||
model=model,
|
||||
)
|
||||
|
||||
@ -10,12 +10,12 @@ from typing import Any, Coroutine
|
||||
import asyncio
|
||||
|
||||
from evals.config import EVALS_DIR
|
||||
from evals.core import generate_code_core
|
||||
from evals.core import generate_code_for_image
|
||||
from evals.utils import image_to_data_url
|
||||
|
||||
STACK = "html_tailwind"
|
||||
MODEL = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
N = 1 # Number of outputs to generate
|
||||
# MODEL = Llm.CLAUDE_3_5_SONNET_2024_06_20
|
||||
N = 2 # Number of outputs to generate
|
||||
|
||||
|
||||
async def main():
|
||||
@ -29,10 +29,21 @@ async def main():
|
||||
for filename in evals:
|
||||
filepath = os.path.join(INPUT_DIR, filename)
|
||||
data_url = await image_to_data_url(filepath)
|
||||
for _ in range(N): # Generate N tasks for each input
|
||||
task = generate_code_core(image_url=data_url, stack=STACK, model=MODEL)
|
||||
for n in range(N): # Generate N tasks for each input
|
||||
if n == 0:
|
||||
task = generate_code_for_image(
|
||||
image_url=data_url,
|
||||
stack=STACK,
|
||||
model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
|
||||
)
|
||||
else:
|
||||
task = generate_code_for_image(
|
||||
image_url=data_url, stack=STACK, model=Llm.GPT_4O_2024_05_13
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
print(f"Generating {len(tasks)} codes")
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user