Update main.py

## Improving Code Structure:
1-Group related imports together .
2-separate the code into functions to enhance readability and maintainability .
This commit is contained in:
DaEpic 2023-11-17 19:43:05 +00:00 committed by GitHub
parent dbf89928ec
commit d70dd41252
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,64 +1,95 @@
# Load environment variables first from datetime import datetime
from dotenv import load_dotenv
load_dotenv()
import json import json
import os import os
import traceback import traceback
from datetime import datetime
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket
from dotenv import load_dotenv
from llm import stream_openai_response from llm import stream_openai_response
from mock import mock_completion from mock import mock_completion
from image_generation import create_alt_url_mapping, generate_images from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_prompt from prompts import assemble_prompt
app = FastAPI() app = FastAPI()
load_dotenv()
# Useful for debugging purposes when you don't want to waste GPT4-Vision credits
# Setting to True will stream a mock response instead of calling the OpenAI API
SHOULD_MOCK_AI_RESPONSE = False
def write_logs(prompt_messages, completion): def get_openai_api_key(params):
# Get the logs path from environment, default to the current working directory return params.get("openAiApiKey") or os.environ.get("OPENAI_API_KEY")
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
# Create run_logs directory if it doesn't exist within the specified logs path
def create_logs_directory(logs_path):
logs_directory = os.path.join(logs_path, "run_logs") logs_directory = os.path.join(logs_path, "run_logs")
if not os.path.exists(logs_directory): if not os.path.exists(logs_directory):
os.makedirs(logs_directory) os.makedirs(logs_directory)
return logs_directory
print("Writing to logs directory:", logs_directory)
# Generate a unique filename using the current timestamp within the logs directory def write_logs(logs_directory, prompt_messages, completion):
filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json")
# Write the messages dict into a new file for each run
with open(filename, "w") as f: with open(filename, "w") as f:
f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) f.write(json.dumps({"prompt": prompt_messages, "completion": completion}))
async def send_status_message(websocket, message):
await websocket.send_json({"type": "status", "value": message})
async def process_chunk(websocket, content):
await websocket.send_json({"type": "chunk", "value": content})
async def generate_code(websocket, params, openai_api_key):
should_generate_images = params.get("isImageGenerationEnabled", True)
await send_status_message(websocket, "Generating code...")
prompt_messages = assemble_prompt(params["image"])
image_cache = {}
if params["generationType"] == "update":
for index, text in enumerate(params["history"]):
prompt_messages += [
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
]
image_cache = create_alt_url_mapping(params["history"][-2])
if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(lambda x: process_chunk(websocket, x))
else:
completion = await stream_openai_response(
prompt_messages, api_key=openai_api_key, callback=lambda x: process_chunk(websocket, x)
)
logs_directory = create_logs_directory(os.environ.get("LOGS_PATH", os.getcwd()))
write_logs(logs_directory, prompt_messages, completion)
try:
if should_generate_images:
await send_status_message(websocket, "Generating images...")
updated_html = await generate_images(completion, api_key=openai_api_key, image_cache=image_cache)
else:
updated_html = completion
await websocket.send_json({"type": "setCode", "value": updated_html})
await send_status_message(websocket, "Code generation complete.")
except Exception as e:
traceback.print_exc()
print("Image generation failed", e)
await send_status_message(websocket, "Image generation failed but code is complete.")
finally:
await websocket.close()
@app.websocket("/generate-code") @app.websocket("/generate-code")
async def stream_code_test(websocket: WebSocket): async def stream_code_test(websocket: WebSocket):
await websocket.accept() await websocket.accept()
params = await websocket.receive_json() params = await websocket.receive_json()
# Get the OpenAI API key from the request. Fall back to environment variable if not provided. openai_api_key = get_openai_api_key(params)
# If neither is provided, we throw an error.
if params["openAiApiKey"]:
openai_api_key = params["openAiApiKey"]
print("Using OpenAI API key from client-side settings dialog")
else:
openai_api_key = os.environ.get("OPENAI_API_KEY")
if openai_api_key:
print("Using OpenAI API key from environment variable")
if not openai_api_key: if not openai_api_key:
print("OpenAI API key not found")
await websocket.send_json( await websocket.send_json(
{ {
"type": "error", "type": "error",
@ -67,64 +98,4 @@ async def stream_code_test(websocket: WebSocket):
) )
return return
should_generate_images = ( await generate_code(websocket, params, openai_api_key)
params["isImageGenerationEnabled"]
if "isImageGenerationEnabled" in params
else True
)
print("generating code...")
await websocket.send_json({"type": "status", "value": "Generating code..."})
async def process_chunk(content):
await websocket.send_json({"type": "chunk", "value": content})
prompt_messages = assemble_prompt(params["image"])
# Image cache for updates so that we don't have to regenerate images
image_cache = {}
if params["generationType"] == "update":
# Transform into message format
# TODO: Move this to frontend
for index, text in enumerate(params["history"]):
prompt_messages += [
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
]
image_cache = create_alt_url_mapping(params["history"][-2])
if SHOULD_MOCK_AI_RESPONSE:
completion = await mock_completion(process_chunk)
else:
completion = await stream_openai_response(
prompt_messages,
api_key=openai_api_key,
callback=lambda x: process_chunk(x),
)
# Write the messages dict into a log so that we can debug later
write_logs(prompt_messages, completion)
try:
if should_generate_images:
await websocket.send_json(
{"type": "status", "value": "Generating images..."}
)
updated_html = await generate_images(
completion, api_key=openai_api_key, image_cache=image_cache
)
else:
updated_html = completion
await websocket.send_json({"type": "setCode", "value": updated_html})
await websocket.send_json(
{"type": "status", "value": "Code generation complete."}
)
except Exception as e:
traceback.print_exc()
print("Image generation failed", e)
await websocket.send_json(
{"type": "status", "value": "Image generation failed but code is complete."}
)
finally:
await websocket.close()