Merge main into sweep/add-reset-button-settings-dialog

This commit is contained in:
sweep-ai[bot] 2023-12-04 21:32:24 +00:00 committed by GitHub
commit 2520b72bf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,5 @@
# Load environment variables first # Load environment variables first
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi.responses import HTMLResponse
load_dotenv() load_dotenv()
@ -11,6 +10,8 @@ import traceback
from datetime import datetime from datetime import datetime
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
import openai
from llm import stream_openai_response from llm import stream_openai_response
from mock import mock_completion from mock import mock_completion
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
@ -35,6 +36,10 @@ app.add_middleware(
# TODO: Should only be set to true when value is 'True', not any abitrary truthy value # TODO: Should only be set to true when value is 'True', not any abitrary truthy value
SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False)) SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False))
# Set to True when running in production (on the hosted version)
# Used as a feature flag to enable or disable certain features
IS_PROD = os.environ.get("IS_PROD", False)
app.include_router(screenshot.router) app.include_router(screenshot.router)
@ -71,6 +76,12 @@ async def stream_code(websocket: WebSocket):
print("Incoming websocket connection...") print("Incoming websocket connection...")
async def throw_error(
message: str,
):
await websocket.send_json({"type": "error", "value": message})
await websocket.close()
params = await websocket.receive_json() params = await websocket.receive_json()
print("Received params") print("Received params")
@ -177,12 +188,24 @@ async def stream_code(websocket: WebSocket):
if SHOULD_MOCK_AI_RESPONSE: if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(process_chunk) completion = await mock_completion(process_chunk)
else: else:
try:
completion = await stream_openai_response( completion = await stream_openai_response(
prompt_messages, prompt_messages,
api_key=openai_api_key, api_key=openai_api_key,
base_url=openai_base_url, base_url=openai_base_url,
callback=lambda x: process_chunk(x), callback=lambda x: process_chunk(x),
) )
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
"Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard."
+ (
" Alternatively, you can purchase code generation credits directly on this website."
if IS_PROD
else ""
)
)
return await throw_error(error_message)
# Write the messages dict into a log so that we can debug later # Write the messages dict into a log so that we can debug later
write_logs(prompt_messages, completion) write_logs(prompt_messages, completion)