Compare commits

..

9 Commits

Author SHA1 Message Date
Abi Raja
a53afd350d update sdxl-lightning references to flux 2024-09-20 16:45:16 +02:00
Abi Raja
6899c7792e when all generations fail, print the all the underlying exceptions for debugging 2024-09-20 13:56:20 +02:00
Abi Raja
9199fee21d keep generating even if one of the models fails 2024-09-17 15:56:35 +02:00
Abi Raja
8717298def switch to flux schnell for replicate image gen 2024-09-17 12:20:19 +02:00
Abi Raja
a4087e613f make variant 1 claude 2024-09-16 15:21:57 +02:00
Abi Raja
cf6b94d675 update favicon/branding 2024-09-14 17:44:57 +02:00
Abi Raja
589507846b remove model selection dropdown since that happens on the background now 2024-09-13 14:41:00 +02:00
Abi Raja
24995c302e
Merge pull request #394 from abi/multiple-generations
Support multiple variants for each generation
2024-09-06 11:43:44 -04:00
Abi Raja
67ce707c3c
Update README.md 2024-08-30 08:47:27 -04:00
11 changed files with 64 additions and 132 deletions

View File

@ -33,10 +33,6 @@ We also just added experimental support for taking a video/screen recording of a
[Follow me on Twitter for updates](https://twitter.com/_abi_). [Follow me on Twitter for updates](https://twitter.com/_abi_).
## Sponsors
<a href="https://konghq.com/products/kong-konnect?utm_medium=referral&utm_source=github&utm_campaign=platform&utm_content=screenshot-to-code" target="_blank" title="Kong - powering the API world"><img src="https://picoapps.xyz/s2c-sponsors/Kong-GitHub-240x100.png"></a>
## 🚀 Hosted Version ## 🚀 Hosted Version
[Try it live on the hosted version (paid)](https://screenshottocode.com). [Try it live on the hosted version (paid)](https://screenshottocode.com).

View File

@ -7,19 +7,19 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: check-yaml - id: check-yaml
- id: check-added-large-files - id: check-added-large-files
- repo: local # - repo: local
hooks: # hooks:
- id: poetry-pytest # - id: poetry-pytest
name: Run pytest with Poetry # name: Run pytest with Poetry
entry: poetry run --directory backend pytest # entry: poetry run --directory backend pytest
language: system # language: system
pass_filenames: false # pass_filenames: false
always_run: true # always_run: true
files: ^backend/ # files: ^backend/
# - id: poetry-pyright # # - id: poetry-pyright
# name: Run pyright with Poetry # # name: Run pyright with Poetry
# entry: poetry run --directory backend pyright # # entry: poetry run --directory backend pyright
# language: system # # language: system
# pass_filenames: false # # pass_filenames: false
# always_run: true # # always_run: true
# files: ^backend/ # # files: ^backend/

View File

@ -11,7 +11,7 @@ async def process_tasks(
prompts: List[str], prompts: List[str],
api_key: str, api_key: str,
base_url: str | None, base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"], model: Literal["dalle3", "flux"],
): ):
import time import time
@ -54,18 +54,14 @@ async def generate_image_dalle(
async def generate_image_replicate(prompt: str, api_key: str) -> str: async def generate_image_replicate(prompt: str, api_key: str) -> str:
# We use SDXL Lightning # We use Flux Schnell
return await call_replicate( return await call_replicate(
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
{ {
"width": 1024,
"height": 1024,
"prompt": prompt, "prompt": prompt,
"scheduler": "K_EULER",
"num_outputs": 1, "num_outputs": 1,
"guidance_scale": 0, "aspect_ratio": "1:1",
"negative_prompt": "worst quality, low quality", "output_format": "png",
"num_inference_steps": 4, "output_quality": 100,
}, },
api_key, api_key,
) )
@ -102,7 +98,7 @@ async def generate_images(
api_key: str, api_key: str,
base_url: Union[str, None], base_url: Union[str, None],
image_cache: Dict[str, str], image_cache: Dict[str, str],
model: Literal["dalle3", "sdxl-lightning"] = "dalle3", model: Literal["dalle3", "flux"] = "dalle3",
) -> str: ) -> str:
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")

View File

@ -2,20 +2,21 @@ import asyncio
import httpx import httpx
async def call_replicate( async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
replicate_model_version: str, input: dict[str, str | int], api_token: str
) -> str:
url = "https://api.replicate.com/v1/predictions"
headers = { headers = {
"Authorization": f"Bearer {api_token}", "Authorization": f"Bearer {api_token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
data = {"version": replicate_model_version, "input": input} data = {"input": input}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
response = await client.post(url, headers=headers, json=data) response = await client.post(
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
headers=headers,
json=data,
)
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = response.json()
@ -24,16 +25,18 @@ async def call_replicate(
if not prediction_id: if not prediction_id:
raise ValueError("Prediction ID not found in initial response.") raise ValueError("Prediction ID not found in initial response.")
# Polling every 1 second until the status is succeeded or error # Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
num_polls = 0 num_polls = 0
max_polls = 100 max_polls = 100
while num_polls < max_polls: while num_polls < max_polls:
num_polls += 1 num_polls += 1
await asyncio.sleep(0.2) await asyncio.sleep(0.1)
# Check the status # Check the status
status_check_url = f"{url}/{prediction_id}" status_check_url = (
f"https://api.replicate.com/v1/predictions/{prediction_id}"
)
status_response = await client.get(status_check_url, headers=headers) status_response = await client.get(status_check_url, headers=headers)
status_response.raise_for_status() status_response.raise_for_status()
status_response_json = status_response.json() status_response_json = status_response.json()

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from codegen.utils import extract_html_content from codegen.utils import extract_html_content
@ -63,7 +64,7 @@ async def perform_image_generation(
return completion return completion
if replicate_api_key: if replicate_api_key:
image_generation_model = "sdxl-lightning" image_generation_model = "flux"
api_key = replicate_api_key api_key = replicate_api_key
else: else:
if not openai_api_key: if not openai_api_key:
@ -271,7 +272,7 @@ async def stream_code(websocket: WebSocket):
# we decide which models to run # we decide which models to run
variant_models = [] variant_models = []
if openai_api_key and anthropic_api_key: if openai_api_key and anthropic_api_key:
variant_models = ["openai", "anthropic"] variant_models = ["anthropic", "openai"]
elif openai_api_key: elif openai_api_key:
variant_models = ["openai", "openai"] variant_models = ["openai", "openai"]
elif anthropic_api_key: elif anthropic_api_key:
@ -312,7 +313,29 @@ async def stream_code(websocket: WebSocket):
) )
) )
completions = await asyncio.gather(*tasks) # Run the models in parallel and capture exceptions if any
completions = await asyncio.gather(*tasks, return_exceptions=True)
# If all generations failed, throw an error
all_generations_failed = all(
isinstance(completion, Exception) for completion in completions
)
if all_generations_failed:
await throw_error("Error generating code. Please contact support.")
# Print the all the underlying exceptions for debugging
for completion in completions:
traceback.print_exception(
type(completion), completion, completion.__traceback__
)
raise Exception("All generations failed")
# If some completions failed, replace them with empty strings
for index, completion in enumerate(completions):
if isinstance(completion, Exception):
completions[index] = ""
print("Generation failed for variant", index)
print("Models used for generation: ", variant_models) print("Models used for generation: ", variant_models)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:

View File

@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
async def generate_and_save_images( async def generate_and_save_images(
prompts: List[str], prompts: List[str],
model: Literal["dalle3", "sdxl-lightning"], model: Literal["dalle3", "flux"],
api_key: Optional[str], api_key: Optional[str],
) -> None: ) -> None:
# Ensure the output directory exists # Ensure the output directory exists
@ -64,7 +64,7 @@ async def generate_and_save_images(
image_data: bytes = await response.read() image_data: bytes = await response.read()
# Save the image with a filename based on the input eval # Save the image with a filename based on the input eval
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_" prefix = "replicate_" if model == "flux" else "dalle3_"
filename: str = ( filename: str = (
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png" f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
) )
@ -78,7 +78,7 @@ async def generate_and_save_images(
async def main() -> None: async def main() -> None:
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY) # await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN) await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
if __name__ == "__main__": if __name__ == "__main__":

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 KiB

After

Width:  |  Height:  |  Size: 16 KiB

View File

@ -358,11 +358,7 @@ function App() {
</div> </div>
{/* Generation settings like stack and model */} {/* Generation settings like stack and model */}
<GenerationSettings <GenerationSettings settings={settings} setSettings={setSettings} />
settings={settings}
setSettings={setSettings}
selectedCodeGenerationModel={model}
/>
{/* Show auto updated message when older models are choosen */} {/* Show auto updated message when older models are choosen */}
{showBetterModelMessage && <DeprecationMessage />} {showBetterModelMessage && <DeprecationMessage />}

View File

@ -2,20 +2,16 @@ import React from "react";
import { useAppStore } from "../../store/app-store"; import { useAppStore } from "../../store/app-store";
import { AppState, Settings } from "../../types"; import { AppState, Settings } from "../../types";
import OutputSettingsSection from "./OutputSettingsSection"; import OutputSettingsSection from "./OutputSettingsSection";
import ModelSettingsSection from "./ModelSettingsSection";
import { Stack } from "../../lib/stacks"; import { Stack } from "../../lib/stacks";
import { CodeGenerationModel } from "../../lib/models";
interface GenerationSettingsProps { interface GenerationSettingsProps {
settings: Settings; settings: Settings;
setSettings: React.Dispatch<React.SetStateAction<Settings>>; setSettings: React.Dispatch<React.SetStateAction<Settings>>;
selectedCodeGenerationModel: CodeGenerationModel;
} }
export const GenerationSettings: React.FC<GenerationSettingsProps> = ({ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
settings, settings,
setSettings, setSettings,
selectedCodeGenerationModel,
}) => { }) => {
const { appState } = useAppStore(); const { appState } = useAppStore();
@ -26,13 +22,6 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
})); }));
} }
function setCodeGenerationModel(codeGenerationModel: CodeGenerationModel) {
setSettings((prev: Settings) => ({
...prev,
codeGenerationModel,
}));
}
const shouldDisableUpdates = const shouldDisableUpdates =
appState === AppState.CODING || appState === AppState.CODE_READY; appState === AppState.CODING || appState === AppState.CODE_READY;
@ -43,12 +32,6 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
setStack={setStack} setStack={setStack}
shouldDisableUpdates={shouldDisableUpdates} shouldDisableUpdates={shouldDisableUpdates}
/> />
<ModelSettingsSection
codeGenerationModel={selectedCodeGenerationModel}
setCodeGenerationModel={setCodeGenerationModel}
shouldDisableUpdates={shouldDisableUpdates}
/>
</div> </div>
); );
}; };

View File

@ -1,65 +0,0 @@
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
} from "../ui/select";
import {
CODE_GENERATION_MODEL_DESCRIPTIONS,
CodeGenerationModel,
} from "../../lib/models";
import { Badge } from "../ui/badge";
interface Props {
codeGenerationModel: CodeGenerationModel;
setCodeGenerationModel: (codeGenerationModel: CodeGenerationModel) => void;
shouldDisableUpdates?: boolean;
}
function ModelSettingsSection({
codeGenerationModel,
setCodeGenerationModel,
shouldDisableUpdates = false,
}: Props) {
return (
<div className="flex flex-col gap-y-2 justify-between text-sm">
<div className="grid grid-cols-3 items-center gap-4">
<span>AI Model:</span>
<Select
value={codeGenerationModel}
onValueChange={(value: string) =>
setCodeGenerationModel(value as CodeGenerationModel)
}
disabled={shouldDisableUpdates}
>
<SelectTrigger className="col-span-2" id="output-settings-js">
<span className="font-semibold">
{CODE_GENERATION_MODEL_DESCRIPTIONS[codeGenerationModel].name}
</span>
</SelectTrigger>
<SelectContent>
<SelectGroup>
{Object.values(CodeGenerationModel).map((model) => (
<SelectItem key={model} value={model}>
<div className="flex items-center">
<span className="font-semibold">
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].name}
</span>
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].inBeta && (
<Badge className="ml-2" variant="secondary">
Beta
</Badge>
)}
</div>
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</div>
</div>
);
}
export default ModelSettingsSection;