1. Load configuration from environment 2. Support Anthropic base URL

This commit is contained in:
chen 2024-05-25 10:54:01 +08:00
parent 23e631765e
commit c013bdc4ad
5 changed files with 32 additions and 25 deletions

View File

@@ -3,10 +3,13 @@
# TODO: Should only be set to true when value is 'True', not any abitrary truthy value # TODO: Should only be set to true when value is 'True', not any abitrary truthy value
import os import os
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None) CFG_ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
CFG_ANTHROPIC_BASE_URL = os.environ.get("ANTHROPIC_BASE_URL", None)
CFG_OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY",None)
CFG_OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL",None)
# Debugging-related # Debugging-related
SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False)) SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False))
IS_DEBUG_ENABLED = bool(os.environ.get("IS_DEBUG_ENABLED", False)) IS_DEBUG_ENABLED = bool(os.environ.get("IS_DEBUG_ENABLED", False))
DEBUG_DIR = os.environ.get("DEBUG_DIR", "") DEBUG_DIR = os.environ.get("DEBUG_DIR", "")

View File

@@ -1,5 +1,5 @@
import os import os
from config import ANTHROPIC_API_KEY from config import CFG_ANTHROPIC_API_KEY,CFG_ANTHROPIC_BASE_URL,CFG_OPENAI_API_KEY,CFG_OPENAI_BASE_URL
from llm import Llm, stream_claude_response, stream_openai_response from llm import Llm, stream_claude_response, stream_openai_response
from prompts import assemble_prompt from prompts import assemble_prompt
@@ -8,30 +8,28 @@ from prompts.types import Stack
async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str: async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
prompt_messages = assemble_prompt(image_url, stack) prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
anthropic_api_key = ANTHROPIC_API_KEY
openai_base_url = None
async def process_chunk(content: str): async def process_chunk(content: str):
pass pass
if model == Llm.CLAUDE_3_SONNET: if model == Llm.CLAUDE_3_SONNET:
if not anthropic_api_key: if not CFG_ANTHROPIC_API_KEY:
raise Exception("Anthropic API key not found") raise Exception("Anthropic API key not found")
completion = await stream_claude_response( completion = await stream_claude_response(
prompt_messages, prompt_messages,
api_key=anthropic_api_key, api_key=CFG_ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
base_url=CFG_ANTHROPIC_BASE_URL
) )
else: else:
if not openai_api_key: if not CFG_OPENAI_API_KEY:
raise Exception("OpenAI API key not found") raise Exception("OpenAI API key not found")
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, prompt_messages,
api_key=openai_api_key, api_key=CFG_OPENAI_API_KEY,
base_url=openai_base_url, base_url=CFG_OPENAI_BASE_URL,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
model=model, model=model,
) )

View File

@@ -73,9 +73,10 @@ async def stream_claude_response(
messages: List[ChatCompletionMessageParam], messages: List[ChatCompletionMessageParam],
api_key: str, api_key: str,
callback: Callable[[str], Awaitable[None]], callback: Callable[[str], Awaitable[None]],
base_url:str|None = None
) -> str: ) -> str:
client = AsyncAnthropic(api_key=api_key) client = AsyncAnthropic(api_key=api_key,base_url=base_url)
# Base parameters # Base parameters
model = Llm.CLAUDE_3_SONNET model = Llm.CLAUDE_3_SONNET
@@ -135,9 +136,10 @@ async def stream_claude_response_native(
callback: Callable[[str], Awaitable[None]], callback: Callable[[str], Awaitable[None]],
include_thinking: bool = False, include_thinking: bool = False,
model: Llm = Llm.CLAUDE_3_OPUS, model: Llm = Llm.CLAUDE_3_OPUS,
base_url: str | None = None,
) -> str: ) -> str:
client = AsyncAnthropic(api_key=api_key) client = AsyncAnthropic(api_key=api_key,base_url=base_url)
# Base model parameters # Base model parameters
max_tokens = 4096 max_tokens = 4096

View File

@@ -2,7 +2,7 @@ import os
import traceback import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE from config import CFG_ANTHROPIC_API_KEY,CFG_ANTHROPIC_BASE_URL,CFG_OPENAI_API_KEY,CFG_OPENAI_BASE_URL, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from custom_types import InputMode from custom_types import InputMode
from llm import ( from llm import (
Llm, Llm,
@@ -105,7 +105,7 @@ async def stream_code(websocket: WebSocket):
openai_api_key = params["openAiApiKey"] openai_api_key = params["openAiApiKey"]
print("Using OpenAI API key from client-side settings dialog") print("Using OpenAI API key from client-side settings dialog")
else: else:
openai_api_key = os.environ.get("OPENAI_API_KEY") openai_api_key = CFG_OPENAI_API_KEY
if openai_api_key: if openai_api_key:
print("Using OpenAI API key from environment variable") print("Using OpenAI API key from environment variable")
@@ -119,7 +119,7 @@ async def stream_code(websocket: WebSocket):
"No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server." "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file. If you add it to .env, make sure to restart the backend server."
) )
return return
openai_api_key:str = openai_api_key
# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None openai_base_url = None
# Disable user-specified OpenAI Base URL in prod # Disable user-specified OpenAI Base URL in prod
@@ -219,7 +219,7 @@ async def stream_code(websocket: WebSocket):
else: else:
try: try:
if validated_input_mode == "video": if validated_input_mode == "video":
if not ANTHROPIC_API_KEY: if not CFG_ANTHROPIC_API_KEY:
await throw_error( await throw_error(
"Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env" "Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
) )
@@ -228,14 +228,15 @@ async def stream_code(websocket: WebSocket):
completion = await stream_claude_response_native( completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT, system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore messages=prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY, api_key=CFG_ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS, model=Llm.CLAUDE_3_OPUS,
include_thinking=True, include_thinking=True,
base_url=CFG_ANTHROPIC_BASE_URL
) )
exact_llm_version = Llm.CLAUDE_3_OPUS exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == Llm.CLAUDE_3_SONNET: elif code_generation_model == Llm.CLAUDE_3_SONNET:
if not ANTHROPIC_API_KEY: if not CFG_ANTHROPIC_API_KEY:
await throw_error( await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env" "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env"
) )
@@ -243,17 +244,19 @@ async def stream_code(websocket: WebSocket):
completion = await stream_claude_response( completion = await stream_claude_response(
prompt_messages, # type: ignore prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY, api_key=CFG_ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
base_url=CFG_ANTHROPIC_BASE_URL
) )
exact_llm_version = code_generation_model exact_llm_version = code_generation_model
else: else:
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, # type: ignore prompt_messages, # type: ignore
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=CFG_OPENAI_BASE_URL,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
model=code_generation_model, model=code_generation_model,
) )
exact_llm_version = code_generation_model exact_llm_version = code_generation_model
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
@@ -307,7 +310,7 @@ async def stream_code(websocket: WebSocket):
updated_html = await generate_images( updated_html = await generate_images(
completion, completion,
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=CFG_OPENAI_BASE_URL,
image_cache=image_cache, image_cache=image_cache,
) )
else: else:

View File

@@ -14,7 +14,7 @@ import asyncio
from datetime import datetime from datetime import datetime
from prompts.claude_prompts import VIDEO_PROMPT from prompts.claude_prompts import VIDEO_PROMPT
from utils import pprint_prompt from utils import pprint_prompt
from config import ANTHROPIC_API_KEY from config import CFG_ANTHROPIC_API_KEY,CFG_ANTHROPIC_BASE_URL,CFG_OPENAI_API_KEY,CFG_OPENAI_BASE_URL
from video.utils import extract_tag_content, assemble_claude_prompt_video from video.utils import extract_tag_content, assemble_claude_prompt_video
from llm import ( from llm import (
Llm, Llm,
@@ -32,7 +32,7 @@ async def main():
video_filename = "shortest.mov" video_filename = "shortest.mov"
is_followup = False is_followup = False
if not ANTHROPIC_API_KEY: if not CFG_ANTHROPIC_API_KEY:
raise ValueError("ANTHROPIC_API_KEY is not set") raise ValueError("ANTHROPIC_API_KEY is not set")
# Get previous HTML # Get previous HTML
@@ -84,10 +84,11 @@ async def main():
completion = await stream_claude_response_native( completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT, system_prompt=VIDEO_PROMPT,
messages=prompt_messages, messages=prompt_messages,
api_key=ANTHROPIC_API_KEY, api_key=CFG_ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS, model=Llm.CLAUDE_3_OPUS,
include_thinking=True, include_thinking=True,
base_url=CFG_ANTHROPIC_BASE_URL
) )
end_time = time.time() end_time = time.time()