From 6a28ee2d3cac362d8367c4efbf1d07f98c88743e Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Sat, 9 Dec 2023 15:34:16 -0500 Subject: [PATCH 1/7] strictly type python backend --- .vscode/settings.json | 3 +++ backend/image_generation.py | 20 +++++++++++--------- backend/llm.py | 13 +++++++------ backend/main.py | 27 +++++++++++++++++++-------- backend/{mock.py => mock_llm.py} | 3 ++- backend/prompts.py | 13 ++++++++++--- backend/routes/screenshot.py | 4 +++- backend/utils.py | 20 +++++++++++--------- 8 files changed, 66 insertions(+), 37 deletions(-) create mode 100644 .vscode/settings.json rename backend/{mock.py => mock_llm.py} (98%) diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..d6e2638 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.typeCheckingMode": "strict" +} diff --git a/backend/image_generation.py b/backend/image_generation.py index bb272f8..d3e71b1 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -1,15 +1,15 @@ import asyncio -import os import re +from typing import Dict, List, Union from openai import AsyncOpenAI from bs4 import BeautifulSoup -async def process_tasks(prompts, api_key, base_url): +async def process_tasks(prompts: List[str], api_key: str, base_url: str): tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) - processed_results = [] + processed_results: List[Union[str, None]] = [] for result in results: if isinstance(result, Exception): print(f"An exception occurred: {result}") @@ -20,9 +20,9 @@ async def process_tasks(prompts, api_key, base_url): return processed_results -async def generate_image(prompt, api_key, base_url): +async def generate_image(prompt: str, api_key: str, base_url: str): client = AsyncOpenAI(api_key=api_key, base_url=base_url) - image_params = { + image_params: Dict[str, Union[str, int]] = { "model": "dall-e-3", "quality": "standard", "style": "natural", @@ -35,7 +35,7 @@ async def generate_image(prompt, api_key, base_url): return res.data[0].url -def extract_dimensions(url): +def extract_dimensions(url: str): # Regular expression to match numbers in the format '300x200' matches = re.findall(r"(\d+)x(\d+)", url) @@ -48,11 +48,11 @@ def extract_dimensions(url): return (100, 100) -def create_alt_url_mapping(code): +def create_alt_url_mapping(code: str) -> Dict[str, str]: soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") - mapping = {} + mapping: Dict[str, str] = {} for image in images: if not image["src"].startswith("https://placehold.co"): @@ -61,7 +61,9 @@ def create_alt_url_mapping(code): return mapping -async def generate_images(code, api_key, base_url, image_cache): +async def generate_images( + code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str] +): # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") diff --git a/backend/llm.py b/backend/llm.py index e2b41c4..66e3a47 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -1,16 +1,16 @@ -import os -from typing import Awaitable, Callable +from typing import Awaitable, Callable, List from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk MODEL_GPT_4_VISION = "gpt-4-vision-preview" async def stream_openai_response( - messages, + messages: List[ChatCompletionMessageParam], api_key: str, base_url: str | None, callback: Callable[[str], Awaitable[None]], -): +) -> str: client = AsyncOpenAI(api_key=api_key, base_url=base_url) model = MODEL_GPT_4_VISION @@ -23,9 +23,10 @@ async def stream_openai_response( params["max_tokens"] = 4096 params["temperature"] = 0 - completion = await client.chat.completions.create(**params) + stream = await client.chat.completions.create(**params) # type: ignore full_response = "" - async for chunk in completion: + async for chunk in stream: # type: ignore + assert isinstance(chunk, ChatCompletionChunk) content = chunk.choices[0].delta.content or "" full_response += content await callback(content) diff --git a/backend/main.py b/backend/main.py index 517eef7..593ec3a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,8 +14,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse import openai from llm import stream_openai_response -from mock import mock_completion +from openai.types.chat import ChatCompletionMessageParam +from mock_llm import mock_completion from utils import pprint_prompt +from typing import Dict, List from image_generation import create_alt_url_mapping, generate_images from prompts import assemble_prompt from routes import screenshot @@ -53,7 +55,7 @@ async def get_status(): ) -def write_logs(prompt_messages, completion): +def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str): # Get the logs path from environment, default to the current working directory logs_path = os.environ.get("LOGS_PATH", os.getcwd()) @@ -84,7 +86,8 @@ async def stream_code(websocket: WebSocket): await websocket.send_json({"type": "error", "value": message}) await websocket.close() - params = await websocket.receive_json() + # TODO: Are the values always strings? + params: Dict[str, str] = await websocket.receive_json() print("Received params") @@ -154,7 +157,7 @@ async def stream_code(websocket: WebSocket): print("generating code...") await websocket.send_json({"type": "status", "value": "Generating code..."}) - async def process_chunk(content): + async def process_chunk(content: str): await websocket.send_json({"type": "chunk", "value": content}) # Assemble the prompt @@ -176,15 +179,23 @@ async def stream_code(websocket: WebSocket): return # Image cache for updates so that we don't have to regenerate images - image_cache = {} + image_cache: Dict[str, str] = {} if params["generationType"] == "update": # Transform into message format # TODO: Move this to frontend for index, text in enumerate(params["history"]): - prompt_messages += [ - {"role": "assistant" if index % 2 == 0 else "user", "content": text} - ] + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) image_cache = create_alt_url_mapping(params["history"][-2]) if SHOULD_MOCK_AI_RESPONSE: diff --git a/backend/mock.py b/backend/mock_llm.py similarity index 98% rename from backend/mock.py rename to backend/mock_llm.py index 90dc7d3..0102bad 100644 --- a/backend/mock.py +++ b/backend/mock_llm.py @@ -1,7 +1,8 @@ import asyncio +from typing import Awaitable, Callable -async def mock_completion(process_chunk): +async def mock_completion(process_chunk: Callable[[str], Awaitable[None]]) -> str: code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE for i in range(0, len(code_to_return), 10): diff --git a/backend/prompts.py b/backend/prompts.py index c9e48cb..f52c195 100644 --- a/backend/prompts.py +++ b/backend/prompts.py @@ -1,3 +1,8 @@ +from typing import List, Union + +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam + + TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer You take screenshots of a reference web page from the user, and then build single page apps @@ -117,8 +122,10 @@ Generate code for a web page that looks exactly like this. def assemble_prompt( - image_data_url, generated_code_config: str, result_image_data_url=None -): + image_data_url: str, + generated_code_config: str, + result_image_data_url: Union[str, None] = None, +) -> List[ChatCompletionMessageParam]: # Set the system prompt based on the output settings system_content = TAILWIND_SYSTEM_PROMPT if generated_code_config == "html_tailwind": @@ -132,7 +139,7 @@ def assemble_prompt( else: raise Exception("Code config is not one of available options") - user_content = [ + user_content: List[ChatCompletionContentPartParam] = [ { "type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}, diff --git a/backend/routes/screenshot.py b/backend/routes/screenshot.py index 7efcfb8..258cd7e 100644 --- a/backend/routes/screenshot.py +++ b/backend/routes/screenshot.py @@ -11,7 +11,9 @@ def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str: return f"data:{mime_type};base64,{base64_image}" -async def capture_screenshot(target_url, api_key, device="desktop") -> bytes: +async def capture_screenshot( + target_url: str, api_key: str, device: str = "desktop" +) -> bytes: api_base_url = "https://api.screenshotone.com/take" params = { diff --git a/backend/utils.py b/backend/utils.py index 17d6423..6c28e14 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -1,28 +1,30 @@ import copy import json +from typing import List +from openai.types.chat import ChatCompletionMessageParam -def pprint_prompt(prompt_messages): +def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]): print(json.dumps(truncate_data_strings(prompt_messages), indent=4)) -def truncate_data_strings(data): +def truncate_data_strings(data: List[ChatCompletionMessageParam]): # type: ignore # Deep clone the data to avoid modifying the original object cloned_data = copy.deepcopy(data) if isinstance(cloned_data, dict): - for key, value in cloned_data.items(): + for key, value in cloned_data.items(): # type: ignore # Recursively call the function if the value is a dictionary or a list if isinstance(value, (dict, list)): - cloned_data[key] = truncate_data_strings(value) + cloned_data[key] = truncate_data_strings(value) # type: ignore # Truncate the string if it it's long and add ellipsis and length elif isinstance(value, str): - cloned_data[key] = value[:40] + cloned_data[key] = value[:40] # type: ignore if len(value) > 40: - cloned_data[key] += "..." + f" ({len(value)} chars)" + cloned_data[key] += "..." + f" ({len(value)} chars)" # type: ignore - elif isinstance(cloned_data, list): + elif isinstance(cloned_data, list): # type: ignore # Process each item in the list - cloned_data = [truncate_data_strings(item) for item in cloned_data] + cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore - return cloned_data + return cloned_data # type: ignore From 435402bc8530f44ba1db2cd6a4700e716cb3ff0b Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Sat, 9 Dec 2023 15:46:42 -0500 Subject: [PATCH 2/7] split main.py into appropriate routes files --- backend/config.py | 11 ++ backend/main.py | 258 +------------------------------- backend/routes/generate_code.py | 235 +++++++++++++++++++++++++++++ backend/routes/home.py | 12 ++ 4 files changed, 263 insertions(+), 253 deletions(-) create mode 100644 backend/config.py create mode 100644 backend/routes/generate_code.py create mode 100644 backend/routes/home.py diff --git a/backend/config.py b/backend/config.py new file mode 100644 index 0000000..7dd21f9 --- /dev/null +++ b/backend/config.py @@ -0,0 +1,11 @@ +# Useful for debugging purposes when you don't want to waste GPT4-Vision credits +# Setting to True will stream a mock response instead of calling the OpenAI API +# TODO: Should only be set to true when value is 'True', not any abitrary truthy value +import os + + +SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False)) + +# Set to True when running in production (on the hosted version) +# Used as a feature flag to enable or disable certain features +IS_PROD = os.environ.get("IS_PROD", False) diff --git a/backend/main.py b/backend/main.py index 593ec3a..3ea43d3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,27 +1,12 @@ # Load environment variables first from dotenv import load_dotenv - load_dotenv() -import json -import os -import traceback -from datetime import datetime -from fastapi import FastAPI, WebSocket +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import HTMLResponse -import openai -from llm import stream_openai_response -from openai.types.chat import ChatCompletionMessageParam -from mock_llm import mock_completion -from utils import pprint_prompt -from typing import Dict, List -from image_generation import create_alt_url_mapping, generate_images -from prompts import assemble_prompt -from routes import screenshot -from access_token import validate_access_token +from routes import screenshot, generate_code, home app = FastAPI(openapi_url=None, docs_url=None, redoc_url=None) @@ -34,240 +19,7 @@ app.add_middleware( allow_headers=["*"], ) - -# Useful for debugging purposes when you don't want to waste GPT4-Vision credits -# Setting to True will stream a mock response instead of calling the OpenAI API -# TODO: Should only be set to true when value is 'True', not any abitrary truthy value -SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False)) - -# Set to True when running in production (on the hosted version) -# Used as a feature flag to enable or disable certain features -IS_PROD = os.environ.get("IS_PROD", False) - - +# Add routes +app.include_router(generate_code.router) app.include_router(screenshot.router) - - -@app.get("/") -async def get_status(): - return HTMLResponse( - content="

Your backend is running correctly. Please open the front-end URL (default is http://localhost:5173) to use screenshot-to-code.

" - ) - - -def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str): - # Get the logs path from environment, default to the current working directory - logs_path = os.environ.get("LOGS_PATH", os.getcwd()) - - # Create run_logs directory if it doesn't exist within the specified logs path - logs_directory = os.path.join(logs_path, "run_logs") - if not os.path.exists(logs_directory): - os.makedirs(logs_directory) - - print("Writing to logs directory:", logs_directory) - - # Generate a unique filename using the current timestamp within the logs directory - filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") - - # Write the messages dict into a new file for each run - with open(filename, "w") as f: - f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) - - -@app.websocket("/generate-code") -async def stream_code(websocket: WebSocket): - await websocket.accept() - - print("Incoming websocket connection...") - - async def throw_error( - message: str, - ): - await websocket.send_json({"type": "error", "value": message}) - await websocket.close() - - # TODO: Are the values always strings? - params: Dict[str, str] = await websocket.receive_json() - - print("Received params") - - # Read the code config settings from the request. Fall back to default if not provided. - generated_code_config = "" - if "generatedCodeConfig" in params and params["generatedCodeConfig"]: - generated_code_config = params["generatedCodeConfig"] - print(f"Generating {generated_code_config} code") - - # Get the OpenAI API key from the request. Fall back to environment variable if not provided. - # If neither is provided, we throw an error. - openai_api_key = None - if "accessCode" in params and params["accessCode"]: - print("Access code - using platform API key") - res = await validate_access_token(params["accessCode"]) - if res["success"]: - openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY") - else: - await websocket.send_json( - { - "type": "error", - "value": res["failure_reason"], - } - ) - return - else: - if params["openAiApiKey"]: - openai_api_key = params["openAiApiKey"] - print("Using OpenAI API key from client-side settings dialog") - else: - openai_api_key = os.environ.get("OPENAI_API_KEY") - if openai_api_key: - print("Using OpenAI API key from environment variable") - - if not openai_api_key: - print("OpenAI API key not found") - await websocket.send_json( - { - "type": "error", - "value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.", - } - ) - return - - # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. - openai_base_url = None - # Disable user-specified OpenAI Base URL in prod - if not os.environ.get("IS_PROD"): - if "openAiBaseURL" in params and params["openAiBaseURL"]: - openai_base_url = params["openAiBaseURL"] - print("Using OpenAI Base URL from client-side settings dialog") - else: - openai_base_url = os.environ.get("OPENAI_BASE_URL") - if openai_base_url: - print("Using OpenAI Base URL from environment variable") - - if not openai_base_url: - print("Using official OpenAI URL") - - # Get the image generation flag from the request. Fall back to True if not provided. - should_generate_images = ( - params["isImageGenerationEnabled"] - if "isImageGenerationEnabled" in params - else True - ) - - print("generating code...") - await websocket.send_json({"type": "status", "value": "Generating code..."}) - - async def process_chunk(content: str): - await websocket.send_json({"type": "chunk", "value": content}) - - # Assemble the prompt - try: - if params.get("resultImage") and params["resultImage"]: - prompt_messages = assemble_prompt( - params["image"], generated_code_config, params["resultImage"] - ) - else: - prompt_messages = assemble_prompt(params["image"], generated_code_config) - except: - await websocket.send_json( - { - "type": "error", - "value": "Error assembling prompt. Contact support at support@picoapps.xyz", - } - ) - await websocket.close() - return - - # Image cache for updates so that we don't have to regenerate images - image_cache: Dict[str, str] = {} - - if params["generationType"] == "update": - # Transform into message format - # TODO: Move this to frontend - for index, text in enumerate(params["history"]): - if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - else: - message: ChatCompletionMessageParam = { - "role": "user", - "content": text, - } - prompt_messages.append(message) - image_cache = create_alt_url_mapping(params["history"][-2]) - - if SHOULD_MOCK_AI_RESPONSE: - completion = await mock_completion(process_chunk) - else: - try: - completion = await stream_openai_response( - prompt_messages, - api_key=openai_api_key, - base_url=openai_base_url, - callback=lambda x: process_chunk(x), - ) - except openai.AuthenticationError as e: - print("[GENERATE_CODE] Authentication failed", e) - error_message = ( - "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard." - + ( - " Alternatively, you can purchase code generation credits directly on this website." - if IS_PROD - else "" - ) - ) - return await throw_error(error_message) - except openai.NotFoundError as e: - print("[GENERATE_CODE] Model not found", e) - error_message = ( - e.message - + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md" - + ( - " Alternatively, you can purchase code generation credits directly on this website." - if IS_PROD - else "" - ) - ) - return await throw_error(error_message) - except openai.RateLimitError as e: - print("[GENERATE_CODE] Rate limit exceeded", e) - error_message = ( - "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'" - + ( - " Alternatively, you can purchase code generation credits directly on this website." - if IS_PROD - else "" - ) - ) - return await throw_error(error_message) - - # Write the messages dict into a log so that we can debug later - write_logs(prompt_messages, completion) - - try: - if should_generate_images: - await websocket.send_json( - {"type": "status", "value": "Generating images..."} - ) - updated_html = await generate_images( - completion, - api_key=openai_api_key, - base_url=openai_base_url, - image_cache=image_cache, - ) - else: - updated_html = completion - await websocket.send_json({"type": "setCode", "value": updated_html}) - await websocket.send_json( - {"type": "status", "value": "Code generation complete."} - ) - except Exception as e: - traceback.print_exc() - print("Image generation failed", e) - await websocket.send_json( - {"type": "status", "value": "Image generation failed but code is complete."} - ) - - await websocket.close() +app.include_router(home.router) diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py new file mode 100644 index 0000000..a44aeac --- /dev/null +++ b/backend/routes/generate_code.py @@ -0,0 +1,235 @@ +import os +import traceback +from fastapi import APIRouter, WebSocket +import openai +from config import IS_PROD, SHOULD_MOCK_AI_RESPONSE +from llm import stream_openai_response +from openai.types.chat import ChatCompletionMessageParam +from mock_llm import mock_completion +from typing import Dict, List +from image_generation import create_alt_url_mapping, generate_images +from prompts import assemble_prompt +from access_token import validate_access_token +from datetime import datetime +import json + + +router = APIRouter() + + +def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str): + # Get the logs path from environment, default to the current working directory + logs_path = os.environ.get("LOGS_PATH", os.getcwd()) + + # Create run_logs directory if it doesn't exist within the specified logs path + logs_directory = os.path.join(logs_path, "run_logs") + if not os.path.exists(logs_directory): + os.makedirs(logs_directory) + + print("Writing to logs directory:", logs_directory) + + # Generate a unique filename using the current timestamp within the logs directory + filename = datetime.now().strftime(f"{logs_directory}/messages_%Y%m%d_%H%M%S.json") + + # Write the messages dict into a new file for each run + with open(filename, "w") as f: + f.write(json.dumps({"prompt": prompt_messages, "completion": completion})) + + +@router.websocket("/generate-code") +async def stream_code(websocket: WebSocket): + await websocket.accept() + + print("Incoming websocket connection...") + + async def throw_error( + message: str, + ): + await websocket.send_json({"type": "error", "value": message}) + await websocket.close() + + # TODO: Are the values always strings? + params: Dict[str, str] = await websocket.receive_json() + + print("Received params") + + # Read the code config settings from the request. Fall back to default if not provided. + generated_code_config = "" + if "generatedCodeConfig" in params and params["generatedCodeConfig"]: + generated_code_config = params["generatedCodeConfig"] + print(f"Generating {generated_code_config} code") + + # Get the OpenAI API key from the request. Fall back to environment variable if not provided. + # If neither is provided, we throw an error. + openai_api_key = None + if "accessCode" in params and params["accessCode"]: + print("Access code - using platform API key") + res = await validate_access_token(params["accessCode"]) + if res["success"]: + openai_api_key = os.environ.get("PLATFORM_OPENAI_API_KEY") + else: + await websocket.send_json( + { + "type": "error", + "value": res["failure_reason"], + } + ) + return + else: + if params["openAiApiKey"]: + openai_api_key = params["openAiApiKey"] + print("Using OpenAI API key from client-side settings dialog") + else: + openai_api_key = os.environ.get("OPENAI_API_KEY") + if openai_api_key: + print("Using OpenAI API key from environment variable") + + if not openai_api_key: + print("OpenAI API key not found") + await websocket.send_json( + { + "type": "error", + "value": "No OpenAI API key found. Please add your API key in the settings dialog or add it to backend/.env file.", + } + ) + return + + # Get the OpenAI Base URL from the request. Fall back to environment variable if not provided. + openai_base_url = None + # Disable user-specified OpenAI Base URL in prod + if not os.environ.get("IS_PROD"): + if "openAiBaseURL" in params and params["openAiBaseURL"]: + openai_base_url = params["openAiBaseURL"] + print("Using OpenAI Base URL from client-side settings dialog") + else: + openai_base_url = os.environ.get("OPENAI_BASE_URL") + if openai_base_url: + print("Using OpenAI Base URL from environment variable") + + if not openai_base_url: + print("Using official OpenAI URL") + + # Get the image generation flag from the request. Fall back to True if not provided. + should_generate_images = ( + params["isImageGenerationEnabled"] + if "isImageGenerationEnabled" in params + else True + ) + + print("generating code...") + await websocket.send_json({"type": "status", "value": "Generating code..."}) + + async def process_chunk(content: str): + await websocket.send_json({"type": "chunk", "value": content}) + + # Assemble the prompt + try: + if params.get("resultImage") and params["resultImage"]: + prompt_messages = assemble_prompt( + params["image"], generated_code_config, params["resultImage"] + ) + else: + prompt_messages = assemble_prompt(params["image"], generated_code_config) + except: + await websocket.send_json( + { + "type": "error", + "value": "Error assembling prompt. Contact support at support@picoapps.xyz", + } + ) + await websocket.close() + return + + # Image cache for updates so that we don't have to regenerate images + image_cache: Dict[str, str] = {} + + if params["generationType"] == "update": + # Transform into message format + # TODO: Move this to frontend + for index, text in enumerate(params["history"]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) + image_cache = create_alt_url_mapping(params["history"][-2]) + + if SHOULD_MOCK_AI_RESPONSE: + completion = await mock_completion(process_chunk) + else: + try: + completion = await stream_openai_response( + prompt_messages, + api_key=openai_api_key, + base_url=openai_base_url, + callback=lambda x: process_chunk(x), + ) + except openai.AuthenticationError as e: + print("[GENERATE_CODE] Authentication failed", e) + error_message = ( + "Incorrect OpenAI key. Please make sure your OpenAI API key is correct, or create a new OpenAI API key on your OpenAI dashboard." + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) + ) + return await throw_error(error_message) + except openai.NotFoundError as e: + print("[GENERATE_CODE] Model not found", e) + error_message = ( + e.message + + ". Please make sure you have followed the instructions correctly to obtain an OpenAI key with GPT vision access: https://github.com/abi/screenshot-to-code/blob/main/Troubleshooting.md" + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) + ) + return await throw_error(error_message) + except openai.RateLimitError as e: + print("[GENERATE_CODE] Rate limit exceeded", e) + error_message = ( + "OpenAI error - 'You exceeded your current quota, please check your plan and billing details.'" + + ( + " Alternatively, you can purchase code generation credits directly on this website." + if IS_PROD + else "" + ) + ) + return await throw_error(error_message) + + # Write the messages dict into a log so that we can debug later + write_logs(prompt_messages, completion) + + try: + if should_generate_images: + await websocket.send_json( + {"type": "status", "value": "Generating images..."} + ) + updated_html = await generate_images( + completion, + api_key=openai_api_key, + base_url=openai_base_url, + image_cache=image_cache, + ) + else: + updated_html = completion + await websocket.send_json({"type": "setCode", "value": updated_html}) + await websocket.send_json( + {"type": "status", "value": "Code generation complete."} + ) + except Exception as e: + traceback.print_exc() + print("Image generation failed", e) + await websocket.send_json( + {"type": "status", "value": "Image generation failed but code is complete."} + ) + + await websocket.close() diff --git a/backend/routes/home.py b/backend/routes/home.py new file mode 100644 index 0000000..c9f66b4 --- /dev/null +++ b/backend/routes/home.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter +from fastapi.responses import HTMLResponse + + +router = APIRouter() + + +@router.get("/") +async def get_status(): + return HTMLResponse( + content="

Your backend is running correctly. Please open the front-end URL (default is http://localhost:5173) to use screenshot-to-code.

" + ) From 52fee9e49b1237a798909d21260682ff3ce3876a Mon Sep 17 00:00:00 2001 From: Abi Raja Date: Sat, 9 Dec 2023 21:00:18 -0500 Subject: [PATCH 3/7] initial implementation of importing from code --- backend/imported_code_prompts.py | 16 ++++ backend/prompts.py | 19 +++++ backend/routes/generate_code.py | 80 ++++++++++++------- frontend/src/App.tsx | 54 +++++++++---- frontend/src/components/ImportCodeSection.tsx | 49 ++++++++++++ .../src/components/history/HistoryDisplay.tsx | 14 +++- .../src/components/history/history_types.ts | 10 ++- frontend/src/components/history/utils.ts | 4 +- frontend/src/generateCode.ts | 1 + 9 files changed, 199 insertions(+), 48 deletions(-) create mode 100644 backend/imported_code_prompts.py create mode 100644 frontend/src/components/ImportCodeSection.tsx diff --git a/backend/imported_code_prompts.py b/backend/imported_code_prompts.py new file mode 100644 index 0000000..28b5fa3 --- /dev/null +++ b/backend/imported_code_prompts.py @@ -0,0 +1,16 @@ +IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """ +You are an expert Tailwind developer. + +- Do not add comments in the code such as "" and "" in place of writing the full code. WRITE THE FULL CODE. +- Repeat elements as needed. For example, if there are 15 items, the code should have 15 items. DO NOT LEAVE comments like "" or bad things will happen. +- For images, use placeholder images from https://placehold.co and include a detailed description of the image in the alt text so that an image generation AI can generate the image later. + +In terms of libraries, + +- Use this script to include Tailwind: +- You can use Google Fonts +- Font Awesome for icons: + +Return only the full code in tags. +Do not include markdown "```" or "```html" at the start or end. +""" diff --git a/backend/prompts.py b/backend/prompts.py index f52c195..554f62f 100644 --- a/backend/prompts.py +++ b/backend/prompts.py @@ -2,6 +2,8 @@ from typing import List, Union from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam +from imported_code_prompts import IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT + TAILWIND_SYSTEM_PROMPT = """ You are an expert Tailwind developer @@ -121,6 +123,23 @@ Generate code for a web page that looks exactly like this. """ +def assemble_imported_code_prompt( + code: str, result_image_data_url: Union[str, None] = None +) -> List[ChatCompletionMessageParam]: + system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT + return [ + { + "role": "system", + "content": system_content, + }, + { + "role": "user", + "content": "Here is the code of the app: " + code, + }, + ] + # TODO: Use result_image_data_url + + def assemble_prompt( image_data_url: str, generated_code_config: str, diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py index a44aeac..eae9b98 100644 --- a/backend/routes/generate_code.py +++ b/backend/routes/generate_code.py @@ -8,11 +8,13 @@ from openai.types.chat import ChatCompletionMessageParam from mock_llm import mock_completion from typing import Dict, List from image_generation import create_alt_url_mapping, generate_images -from prompts import assemble_prompt +from prompts import assemble_imported_code_prompt, assemble_prompt from access_token import validate_access_token from datetime import datetime import json +from utils import pprint_prompt # type: ignore + router = APIRouter() @@ -122,43 +124,63 @@ async def stream_code(websocket: WebSocket): async def process_chunk(content: str): await websocket.send_json({"type": "chunk", "value": content}) - # Assemble the prompt - try: - if params.get("resultImage") and params["resultImage"]: - prompt_messages = assemble_prompt( - params["image"], generated_code_config, params["resultImage"] - ) - else: - prompt_messages = assemble_prompt(params["image"], generated_code_config) - except: - await websocket.send_json( - { - "type": "error", - "value": "Error assembling prompt. Contact support at support@picoapps.xyz", - } - ) - await websocket.close() - return - # Image cache for updates so that we don't have to regenerate images image_cache: Dict[str, str] = {} - if params["generationType"] == "update": - # Transform into message format - # TODO: Move this to frontend - for index, text in enumerate(params["history"]): + # If this generation started off with imported code, we need to assemble the prompt differently + if params.get("isImportedFromCode") and params["isImportedFromCode"]: + original_imported_code = params["history"][0] + prompt_messages = assemble_imported_code_prompt(original_imported_code) + for index, text in enumerate(params["history"][1:]): if index % 2 == 0: - message: ChatCompletionMessageParam = { - "role": "assistant", - "content": text, - } - else: message: ChatCompletionMessageParam = { "role": "user", "content": text, } + else: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } prompt_messages.append(message) - image_cache = create_alt_url_mapping(params["history"][-2]) + else: + # Assemble the prompt + try: + if params.get("resultImage") and params["resultImage"]: + prompt_messages = assemble_prompt( + params["image"], generated_code_config, params["resultImage"] + ) + else: + prompt_messages = assemble_prompt( + params["image"], generated_code_config + ) + except: + await websocket.send_json( + { + "type": "error", + "value": "Error assembling prompt. Contact support at support@picoapps.xyz", + } + ) + await websocket.close() + return + + if params["generationType"] == "update": + # Transform the history tree into message format + # TODO: Move this to frontend + for index, text in enumerate(params["history"]): + if index % 2 == 0: + message: ChatCompletionMessageParam = { + "role": "assistant", + "content": text, + } + else: + message: ChatCompletionMessageParam = { + "role": "user", + "content": text, + } + prompt_messages.append(message) + + image_cache = create_alt_url_mapping(params["history"][-2]) if SHOULD_MOCK_AI_RESPONSE: completion = await mock_completion(process_chunk) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 95cd8d3..03ec3ba 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -33,6 +33,7 @@ import { History } from "./components/history/history_types"; import HistoryDisplay from "./components/history/HistoryDisplay"; import { extractHistoryTree } from "./components/history/utils"; import toast from "react-hot-toast"; +import ImportCodeSection from "./components/ImportCodeSection"; const IS_OPENAI_DOWN = false; @@ -43,6 +44,7 @@ function App() { const [referenceImages, setReferenceImages] = useState([]); const [executionConsole, setExecutionConsole] = useState([]); const [updateInstruction, setUpdateInstruction] = useState(""); + const [isImportedFromCode, setIsImportedFromCode] = useState(false); // Settings const [settings, setSettings] = usePersistedState( @@ -118,6 +120,8 @@ function App() { setReferenceImages([]); setExecutionConsole([]); setAppHistory([]); + setCurrentVersion(null); + setIsImportedFromCode(false); }; const stop = () => { @@ -231,6 +235,7 @@ function App() { image: referenceImages[0], resultImage: resultImage, history: updatedHistory, + isImportedFromCode, }, currentVersion ); @@ -240,6 +245,7 @@ function App() { generationType: "update", image: referenceImages[0], history: updatedHistory, + isImportedFromCode, }, currentVersion ); @@ -256,6 +262,21 @@ function App() { })); }; + function importFromCode(code: string) { + setAppState(AppState.CODE_READY); + setGeneratedCode(code); + setAppHistory([ + { + type: "code_create", + parentIndex: null, + code, + inputs: { code }, + }, + ]); + setCurrentVersion(0); + setIsImportedFromCode(true); + } + return (
{IS_RUNNING_ON_CLOUD && } @@ -364,22 +385,24 @@ function App() { {/* Reference image display */}
-
-
- Reference + {referenceImages.length > 0 && ( +
+
+ Reference +
+
+ Original Screenshot +
-
- Original Screenshot -
-
+ )}

Console @@ -424,6 +447,7 @@ function App() { doCreate={doCreate} screenshotOneApiKey={settings.screenshotOneApiKey} /> +

)} diff --git a/frontend/src/components/ImportCodeSection.tsx b/frontend/src/components/ImportCodeSection.tsx new file mode 100644 index 0000000..0199413 --- /dev/null +++ b/frontend/src/components/ImportCodeSection.tsx @@ -0,0 +1,49 @@ +import { useState } from "react"; +import { Button } from "./ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "./ui/dialog"; +import { Textarea } from "./ui/textarea"; + +interface Props { + importFromCode: (code: string) => void; +} + +function ImportCodeSection({ importFromCode }: Props) { + const [code, setCode] = useState(""); + return ( + + + + + + + Paste in your HTML code + + Make sure that the code you're importing is valid HTML. + + + +