screenshot-to-code/backend/main.py
2023-11-30 18:47:54 +08:00

181 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Load environment variables first
from dotenv import load_dotenv
# Reads environment variables from the .env file
load_dotenv()
import json
import os
import traceback
from datetime import datetime
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from llm import stream_openai_response
from mock import mock_completion
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_prompt
from routes import screenshot
from access_token import validate_access_token
# Application instance. Passing None for the three URL options is not the
# FastAPI default: per FastAPI's documentation it turns off the generated
# OpenAPI schema endpoint and the Swagger UI / ReDoc pages.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)

# Configure CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Useful for debugging purposes when you don't want to waste GPT4-Vision credits.
# When the MOCK environment variable is enabled, a mock response is streamed
# instead of calling the OpenAI API.
def is_env_flag_enabled(name):
    """Return True only when environment variable *name* is explicitly enabled.

    Accepted values (case-insensitive, surrounding whitespace ignored) are
    "1", "true", "yes" and "on". Anything else - including an unset variable,
    "", "0" and "false" - yields False. A plain ``bool()`` on the raw string
    would treat every non-empty value, even "false", as enabled.
    """
    raw_value = os.environ.get(name, "")
    return raw_value.strip().lower() in {"1", "true", "yes", "on"}


SHOULD_MOCK_AI_RESPONSE = is_env_flag_enabled("MOCK")
# Register the routes defined on the `screenshot` module's router with the app.
app.include_router(screenshot.router)
def write_logs(prompt_messages, completion):
    """Write one prompt/completion pair to a timestamped JSON file.

    The file is created as ``<LOGS_PATH>/run_logs/messages_<timestamp>.json``
    where ``LOGS_PATH`` comes from the environment and defaults to the current
    working directory. The ``run_logs`` directory is created when missing.

    Args:
        prompt_messages: The prompt messages sent to the model; must be
            JSON-serializable.
        completion: The completion returned for those messages; must be
            JSON-serializable.
    """
    # Get the logs path from environment, default to the current working directory
    logs_path = os.environ.get("LOGS_PATH", os.getcwd())

    # exist_ok avoids the check-then-create race between concurrent requests.
    logs_directory = os.path.join(logs_path, "run_logs")
    os.makedirs(logs_directory, exist_ok=True)

    print("Writing to logs directory:", logs_directory)

    # Format only the timestamp with strftime. Feeding the directory through
    # strftime as well would expand any "%" sequence contained in the path.
    # The timestamp has one-second resolution, so two runs finishing within
    # the same second share a filename.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(logs_directory, f"messages_{timestamp}.json")

    # Write the messages dict into a new file for each run
    with open(filename, "w", encoding="utf-8") as f:
        f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
def _read_output_settings(params):
    """Return the CSS/JS output settings for a request, falling back to defaults.

    Starts from ``{"css": "tailwind", "js": "vanilla"}`` and overrides an entry
    only when the request's ``outputSettings`` object carries a truthy value
    for it, so a missing or partial ``outputSettings`` does not raise.
    """
    output_settings = {"css": "tailwind", "js": "vanilla"}
    requested = params.get("outputSettings") or {}
    for key in ("css", "js"):
        if requested.get(key):
            output_settings[key] = requested[key]
    return output_settings


@app.websocket("/generate-code")
async def stream_code(websocket: WebSocket):
    """Handle one code-generation request over a websocket.

    Accepts the connection, reads a single JSON ``params`` message and then
    sends JSON messages of the form ``{"type": ..., "value": ...}`` back, with
    ``type`` being one of ``"error"``, ``"status"``, ``"chunk"`` or
    ``"setCode"``.

    Keys read from ``params``: ``outputSettings``, ``accessCode``,
    ``openAiApiKey``, ``isImageGenerationEnabled``, ``image`` (required),
    ``resultImage``, ``generationType`` and ``history``.

    Args:
        websocket: The incoming websocket connection.
    """
    await websocket.accept()
    print("Incoming websocket connection...")

    params = await websocket.receive_json()
    print("Received params")

    # Read the output settings from the request. Fall back to default if not provided.
    output_settings = _read_output_settings(params)
    print("Using output settings:", output_settings)

    # Get the OpenAI API key from the request. Fall back to environment variable if not provided.
    # If neither is provided, we report an error to the client.
    openai_api_key = None
    access_code = params.get("accessCode")
    if access_code:
        print("Access code - using platform API key")
        if await validate_access_token(access_code):
            openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY")
        else:
            await websocket.send_json(
                {
                    "type": "error",
                    "value": "Invalid access code or you're out of credits. Please try again.",
                }
            )
            return
    elif params.get("openAiApiKey"):
        openai_api_key = params["openAiApiKey"]
        print("Using OpenAI API key from client-side settings dialog")
    else:
        openai_api_key = os.environ.get("OPENAI_API_KEY")
        if openai_api_key:
            print("Using OpenAI API key from environment variable")

    if not openai_api_key:
        print("OpenAI API key not found")
        await websocket.send_json(
            {
                "type": "error",
                "value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.",
            }
        )
        return

    # Image generation is on unless the client explicitly sends a value for it.
    should_generate_images = params.get("isImageGenerationEnabled", True)

    print("generating code...")
    await websocket.send_json({"type": "status", "value": "Generating code..."})

    async def process_chunk(content):
        # Forward each streamed piece of the completion to the client.
        await websocket.send_json({"type": "chunk", "value": content})

    result_image = params.get("resultImage")
    if result_image:
        prompt_messages = assemble_prompt(
            params["image"], output_settings, result_image
        )
    else:
        prompt_messages = assemble_prompt(params["image"], output_settings)

    # Image cache for updates so that we don't have to regenerate images
    image_cache = {}

    if params.get("generationType") == "update":
        # Transform into message format: entries alternate between assistant
        # and user, starting with the assistant.
        # TODO: Move this to frontend
        history = params.get("history") or []
        prompt_messages += [
            {"role": "assistant" if index % 2 == 0 else "user", "content": text}
            for index, text in enumerate(history)
        ]
        # For an even-length history the second-to-last entry is the most
        # recent assistant message. Skip the lookup when the history is too
        # short to have one instead of raising IndexError.
        if len(history) >= 2:
            image_cache = create_alt_url_mapping(history[-2])

    if SHOULD_MOCK_AI_RESPONSE:
        completion = await mock_completion(process_chunk)
    else:
        completion = await stream_openai_response(
            prompt_messages,
            api_key=openai_api_key,
            callback=process_chunk,
        )

    # Write the messages dict into a log so that we can debug later
    write_logs(prompt_messages, completion)

    try:
        if should_generate_images:
            await websocket.send_json(
                {"type": "status", "value": "Generating images..."}
            )
            updated_html = await generate_images(
                completion, api_key=openai_api_key, image_cache=image_cache
            )
        else:
            updated_html = completion
        await websocket.send_json({"type": "setCode", "value": updated_html})
        await websocket.send_json(
            {"type": "status", "value": "Code generation complete."}
        )
    except Exception as e:
        # Best effort: the code was already streamed as chunks, so only report
        # the failure instead of propagating it.
        traceback.print_exc()
        print("Image generation failed", e)
        await websocket.send_json(
            {"type": "status", "value": "Image generation failed but code is complete."}
        )
    finally:
        await websocket.close()