Merge branch 'main' into hosted

This commit is contained in:
Abi Raja 2023-12-07 11:52:37 -05:00
commit 1fb390e48c
10 changed files with 1028 additions and 36 deletions

View File

@ -1,6 +1,7 @@
# Load environment variables first
from dotenv import load_dotenv
load_dotenv()
@ -14,6 +15,7 @@ from fastapi.responses import HTMLResponse
import openai
from llm import stream_openai_response
from mock import mock_completion
from utils import pprint_prompt
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_prompt
from routes import screenshot
@ -197,7 +199,6 @@ async def stream_code(websocket: WebSocket):
prompt_messages += [
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
]
image_cache = create_alt_url_mapping(params["history"][-2])
if SHOULD_MOCK_AI_RESPONSE:

View File

@ -1,4 +1,9 @@
import copy
import json
def pprint_prompt(prompt_messages):
print(json.dumps(truncate_data_strings(prompt_messages), indent=4))
def truncate_data_strings(data):
@ -10,9 +15,12 @@ def truncate_data_strings(data):
# Recursively call the function if the value is a dictionary or a list
if isinstance(value, (dict, list)):
cloned_data[key] = truncate_data_strings(value)
# Truncate the string if it starts with 'data:'
elif isinstance(value, str) and value.startswith("data:"):
cloned_data[key] = value[:20]
# Truncate the string if it it's long and add ellipsis and length
elif isinstance(value, str):
cloned_data[key] = value[:40]
if len(value) > 40:
cloned_data[key] += "..." + f" ({len(value)} chars)"
elif isinstance(cloned_data, list):
# Process each item in the list
cloned_data = [truncate_data_strings(item) for item in cloned_data]

View File

@ -9,7 +9,8 @@
"build": "tsc && vite build",
"build-hosted": "tsc && vite build --mode prod",
"lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
"preview": "vite preview"
"preview": "vite preview",
"test": "vitest"
},
"dependencies": {
"@codemirror/lang-html": "^6.4.6",
@ -21,6 +22,7 @@
"@radix-ui/react-label": "^2.0.2",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-separator": "^1.0.3",
"@radix-ui/react-slot": "^1.0.2",
@ -57,7 +59,8 @@
"tailwindcss": "^3.3.5",
"typescript": "^5.0.2",
"vite": "^4.4.5",
"vite-plugin-html": "^3.2.0"
"vite-plugin-html": "^3.2.0",
"vitest": "^1.0.1"
},
"engines": {
"node": ">=14.18.0"

View File

@ -30,16 +30,22 @@ import { USER_CLOSE_WEB_SOCKET_CODE } from "./constants";
import CodeTab from "./components/CodeTab";
import OutputSettingsSection from "./components/OutputSettingsSection";
import { addEvent } from "./lib/analytics";
import { History } from "./components/history/history_types";
import HistoryDisplay from "./components/history/HistoryDisplay";
import { extractHistoryTree } from "./components/history/utils";
import toast from "react-hot-toast";
const IS_OPENAI_DOWN = false;
function App() {
const [appState, setAppState] = useState<AppState>(AppState.INITIAL);
const [generatedCode, setGeneratedCode] = useState<string>("");
const [referenceImages, setReferenceImages] = useState<string[]>([]);
const [executionConsole, setExecutionConsole] = useState<string[]>([]);
const [updateInstruction, setUpdateInstruction] = useState("");
const [history, setHistory] = useState<string[]>([]);
// Settings
const [settings, setSettings] = usePersistedState<Settings>(
{
openAiApiKey: null,
@ -55,6 +61,11 @@ function App() {
"setting"
);
// App history
const [appHistory, setAppHistory] = useState<History>([]);
// Tracks the currently viewed version from app history
const [currentVersion, setCurrentVersion] = useState<number | null>(null);
const [shouldIncludeResultImage, setShouldIncludeResultImage] =
useState<boolean>(false);
@ -109,7 +120,7 @@ function App() {
setGeneratedCode("");
setReferenceImages([]);
setExecutionConsole([]);
setHistory([]);
setAppHistory([]);
};
const stop = () => {
@ -118,7 +129,10 @@ function App() {
setAppState(AppState.CODE_READY);
};
function doGenerateCode(params: CodeGenerationParams) {
function doGenerateCode(
params: CodeGenerationParams,
parentVersion: number | null
) {
setExecutionConsole([]);
setAppState(AppState.CODING);
@ -129,9 +143,48 @@ function App() {
wsRef,
updatedParams,
(token) => setGeneratedCode((prev) => prev + token),
(code) => setGeneratedCode(code),
(code) => {
setGeneratedCode(code);
if (params.generationType === "create") {
setAppHistory([
{
type: "ai_create",
parentIndex: null,
code,
inputs: { image_url: referenceImages[0] },
},
]);
setCurrentVersion(0);
} else {
setAppHistory((prev) => {
// Validate parent version
if (parentVersion === null) {
toast.error(
"No parent version set. Contact support or open a Github issue."
);
return prev;
}
const newHistory: History = [
...prev,
{
type: "ai_edit",
parentIndex: parentVersion,
code,
inputs: {
prompt: updateInstruction,
},
},
];
setCurrentVersion(newHistory.length - 1);
return newHistory;
});
}
},
(line) => setExecutionConsole((prev) => [...prev, line]),
() => setAppState(AppState.CODE_READY)
() => {
setAppState(AppState.CODE_READY);
}
);
}
@ -142,33 +195,52 @@ function App() {
setReferenceImages(referenceImages);
if (referenceImages.length > 0) {
doGenerateCode({
generationType: "create",
image: referenceImages[0],
});
doGenerateCode(
{
generationType: "create",
image: referenceImages[0],
},
currentVersion
);
}
}
// Subsequent updates
async function doUpdate() {
const updatedHistory = [...history, generatedCode, updateInstruction];
if (shouldIncludeResultImage) {
const resultImage = await takeScreenshot();
doGenerateCode({
generationType: "update",
image: referenceImages[0],
resultImage: resultImage,
history: updatedHistory,
});
} else {
doGenerateCode({
generationType: "update",
image: referenceImages[0],
history: updatedHistory,
});
if (currentVersion === null) {
toast.error(
"No current version set. Contact support or open a Github issue."
);
return;
}
const updatedHistory = [
...extractHistoryTree(appHistory, currentVersion),
updateInstruction,
];
if (shouldIncludeResultImage) {
const resultImage = await takeScreenshot();
doGenerateCode(
{
generationType: "update",
image: referenceImages[0],
resultImage: resultImage,
history: updatedHistory,
},
currentVersion
);
} else {
doGenerateCode(
{
generationType: "update",
image: referenceImages[0],
history: updatedHistory,
},
currentVersion
);
}
setHistory(updatedHistory);
setGeneratedCode("");
setUpdateInstruction("");
}
@ -320,6 +392,23 @@ function App() {
</div>
</>
)}
{
<HistoryDisplay
history={appHistory}
currentVersion={currentVersion}
revertToVersion={(index) => {
if (
index < 0 ||
index >= appHistory.length ||
!appHistory[index]
)
return;
setCurrentVersion(index);
setGeneratedCode(appHistory[index].code);
}}
shouldDisableReverts={appState === AppState.CODING}
/>
}
</div>
</div>

View File

@ -0,0 +1,75 @@
import { ScrollArea } from "@/components/ui/scroll-area";
import { History, HistoryItemType } from "./history_types";
import toast from "react-hot-toast";
import classNames from "classnames";
interface Props {
history: History;
currentVersion: number | null;
revertToVersion: (version: number) => void;
shouldDisableReverts: boolean;
}
function displayHistoryItemType(itemType: HistoryItemType) {
switch (itemType) {
case "ai_create":
return "Create";
case "ai_edit":
return "Edit";
default:
// TODO: Error out since this is exhaustive
return "Unknown";
}
}
export default function HistoryDisplay({
history,
currentVersion,
revertToVersion,
shouldDisableReverts,
}: Props) {
return history.length === 0 ? null : (
<div className="flex flex-col h-screen">
<h1 className="font-bold mb-2">Versions</h1>
<ScrollArea className="flex-1 overflow-y-auto">
<ul className="space-y-0 flex flex-col-reverse">
{history.map((item, index) => (
<li
key={index}
className={classNames(
"flex items-center space-x-2 justify-between p-2",
"border-b cursor-pointer",
{
" hover:bg-black hover:text-white": index !== currentVersion,
"bg-slate-500 text-white": index === currentVersion,
}
)}
onClick={() =>
shouldDisableReverts
? toast.error(
"Please wait for code generation to complete before viewing an older version."
)
: revertToVersion(index)
}
>
<div className="flex gap-x-1">
<h2 className="text-sm">{displayHistoryItemType(item.type)}</h2>
{item.parentIndex !== null && item.parentIndex !== index - 1 ? (
<h2 className="text-sm">
(parent: v{(item.parentIndex || 0) + 1})
</h2>
) : null}
</div>
<h2 className="text-sm">
{item.type === "ai_edit"
? item.inputs.prompt
: item.inputs.image_url}
</h2>
<h2 className="text-sm">v{index + 1}</h2>
</li>
))}
</ul>
</ScrollArea>
</div>
);
}

View File

@ -0,0 +1,26 @@
export type HistoryItemType = "ai_create" | "ai_edit";
type CommonHistoryItem = {
parentIndex: null | number;
code: string;
};
export type HistoryItem =
| ({
type: "ai_create";
inputs: AiCreateInputs;
} & CommonHistoryItem)
| ({
type: "ai_edit";
inputs: AiEditInputs;
} & CommonHistoryItem);
export type AiCreateInputs = {
image_url: string;
};
export type AiEditInputs = {
prompt: string;
};
export type History = HistoryItem[];

View File

@ -0,0 +1,103 @@
import { expect, test } from "vitest";
import { extractHistoryTree } from "./utils";
import type { History } from "./history_types";
const basicLinearHistory: History = [
{
type: "ai_create",
parentIndex: null,
code: "<html>1. create</html>",
inputs: {
image_url: "",
},
},
{
type: "ai_edit",
parentIndex: 0,
code: "<html>2. edit with better icons</html>",
inputs: {
prompt: "use better icons",
},
},
{
type: "ai_edit",
parentIndex: 1,
code: "<html>3. edit with better icons and red text</html>",
inputs: {
prompt: "make text red",
},
},
];
const basicBranchingHistory: History = [
...basicLinearHistory,
{
type: "ai_edit",
parentIndex: 1,
code: "<html>4. edit with better icons and green text</html>",
inputs: {
prompt: "make text green",
},
},
];
const longerBranchingHistory: History = [
...basicBranchingHistory,
{
type: "ai_edit",
parentIndex: 3,
code: "<html>5. edit with better icons and green, bold text</html>",
inputs: {
prompt: "make text bold",
},
},
];
test("should only include history from this point onward", () => {
expect(extractHistoryTree(basicLinearHistory, 2)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
"make text red",
"<html>3. edit with better icons and red text</html>",
]);
expect(extractHistoryTree(basicLinearHistory, 0)).toEqual([
"<html>1. create</html>",
]);
// Test branching
expect(extractHistoryTree(basicBranchingHistory, 3)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
"make text green",
"<html>4. edit with better icons and green text</html>",
]);
expect(extractHistoryTree(longerBranchingHistory, 4)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
"make text green",
"<html>4. edit with better icons and green text</html>",
"make text bold",
"<html>5. edit with better icons and green, bold text</html>",
]);
expect(extractHistoryTree(longerBranchingHistory, 2)).toEqual([
"<html>1. create</html>",
"use better icons",
"<html>2. edit with better icons</html>",
"make text red",
"<html>3. edit with better icons and red text</html>",
]);
// Errors - TODO: Handle these
// Bad index
// TODO: Throw an exception instead?
expect(extractHistoryTree(basicLinearHistory, 100)).toEqual([]);
expect(extractHistoryTree(basicLinearHistory, -2)).toEqual([]);
// Bad tree
});

View File

@ -0,0 +1,32 @@
import { History, HistoryItem } from "./history_types";
export function extractHistoryTree(
history: History,
version: number
): string[] {
const flatHistory: string[] = [];
let currentIndex: number | null = version;
while (currentIndex !== null) {
const item: HistoryItem = history[currentIndex];
if (item) {
if (item.type === "ai_create") {
// Don't include the image for ai_create
flatHistory.unshift(item.code);
} else {
flatHistory.unshift(item.code);
flatHistory.unshift(item.inputs.prompt);
}
// Move to the parent of the current item
currentIndex = item.parentIndex;
} else {
// TODO: Throw an exception here?
// Break the loop if the item is not found (should not happen in a well-formed history)
break;
}
}
return flatHistory;
}

View File

@ -0,0 +1,46 @@
import * as React from "react"
import * as ScrollAreaPrimitive from "@radix-ui/react-scroll-area"
import { cn } from "@/lib/utils"
const ScrollArea = React.forwardRef<
React.ElementRef<typeof ScrollAreaPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.Root>
>(({ className, children, ...props }, ref) => (
<ScrollAreaPrimitive.Root
ref={ref}
className={cn("relative overflow-hidden", className)}
{...props}
>
<ScrollAreaPrimitive.Viewport className="h-full w-full rounded-[inherit]">
{children}
</ScrollAreaPrimitive.Viewport>
<ScrollBar />
<ScrollAreaPrimitive.Corner />
</ScrollAreaPrimitive.Root>
))
ScrollArea.displayName = ScrollAreaPrimitive.Root.displayName
const ScrollBar = React.forwardRef<
React.ElementRef<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>,
React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>
>(({ className, orientation = "vertical", ...props }, ref) => (
<ScrollAreaPrimitive.ScrollAreaScrollbar
ref={ref}
orientation={orientation}
className={cn(
"flex touch-none select-none transition-colors",
orientation === "vertical" &&
"h-full w-2.5 border-l border-l-transparent p-[1px]",
orientation === "horizontal" &&
"h-2.5 flex-col border-t border-t-transparent p-[1px]",
className
)}
{...props}
>
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border" />
</ScrollAreaPrimitive.ScrollAreaScrollbar>
))
ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName
export { ScrollArea, ScrollBar }

File diff suppressed because it is too large Load Diff