diff --git a/backend/config.py b/backend/config.py
index 05592b0..5199cc6 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -3,6 +3,7 @@
 # TODO: Should only be set to true when value is 'True', not any abitrary truthy value
 import os
 
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
 ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
 
 # Debugging-related
diff --git a/backend/evals/core.py b/backend/evals/core.py
index 2fc0352..ec35d21 100644
--- a/backend/evals/core.py
+++ b/backend/evals/core.py
@@ -1,38 +1,40 @@
-import os
-from config import ANTHROPIC_API_KEY
-
+from config import ANTHROPIC_API_KEY, OPENAI_API_KEY
 from llm import Llm, stream_claude_response, stream_openai_response
 from prompts import assemble_prompt
 from prompts.types import Stack
+from openai.types.chat import ChatCompletionMessageParam
 
 
-async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
+async def generate_code_for_image(image_url: str, stack: Stack, model: Llm) -> str:
     prompt_messages = assemble_prompt(image_url, stack)
-    openai_api_key = os.environ.get("OPENAI_API_KEY")
-    anthropic_api_key = ANTHROPIC_API_KEY
-    openai_base_url = None
+    return await generate_code_core(prompt_messages, model)
 
-    async def process_chunk(content: str):
+
+async def generate_code_core(
+    prompt_messages: list[ChatCompletionMessageParam], model: Llm
+) -> str:
+
+    async def process_chunk(_: str):
         pass
 
     if model == Llm.CLAUDE_3_SONNET or model == Llm.CLAUDE_3_5_SONNET_2024_06_20:
-        if not anthropic_api_key:
+        if not ANTHROPIC_API_KEY:
             raise Exception("Anthropic API key not found")
 
         completion = await stream_claude_response(
             prompt_messages,
-            api_key=anthropic_api_key,
+            api_key=ANTHROPIC_API_KEY,
             callback=lambda x: process_chunk(x),
             model=model,
         )
     else:
-        if not openai_api_key:
+        if not OPENAI_API_KEY:
             raise Exception("OpenAI API key not found")
 
         completion = await stream_openai_response(
             prompt_messages,
-            api_key=openai_api_key,
-            base_url=openai_base_url,
+            api_key=OPENAI_API_KEY,
+            base_url=None,
             callback=lambda x: process_chunk(x),
             model=model,
         )
diff --git a/backend/run_evals.py b/backend/run_evals.py
index ff5dc9f..2324c51 100644
--- a/backend/run_evals.py
+++ b/backend/run_evals.py
@@ -10,12 +10,12 @@ from typing import Any, Coroutine
 import asyncio
 
 from evals.config import EVALS_DIR
-from evals.core import generate_code_core
+from evals.core import generate_code_for_image
 from evals.utils import image_to_data_url
 
 STACK = "html_tailwind"
-MODEL = Llm.CLAUDE_3_5_SONNET_2024_06_20
-N = 1  # Number of outputs to generate
+# MODEL = Llm.CLAUDE_3_5_SONNET_2024_06_20
+N = 2  # Number of outputs to generate
 
 
 async def main():
@@ -29,10 +29,21 @@
     for filename in evals:
         filepath = os.path.join(INPUT_DIR, filename)
         data_url = await image_to_data_url(filepath)
-        for _ in range(N):  # Generate N tasks for each input
-            task = generate_code_core(image_url=data_url, stack=STACK, model=MODEL)
+        for n in range(N):  # Generate N tasks for each input
+            if n == 0:
+                task = generate_code_for_image(
+                    image_url=data_url,
+                    stack=STACK,
+                    model=Llm.CLAUDE_3_5_SONNET_2024_06_20,
+                )
+            else:
+                task = generate_code_for_image(
+                    image_url=data_url, stack=STACK, model=Llm.GPT_4O_2024_05_13
+                )
             tasks.append(task)
 
+    print(f"Generating {len(tasks)} codes")
+
     results = await asyncio.gather(*tasks)
 
     os.makedirs(OUTPUT_DIR, exist_ok=True)