Compare commits

...

7 Commits

Author SHA1 Message Date
Abi Raja
a53afd350d update sdxl-lightning references to flux 2024-09-20 16:45:16 +02:00
Abi Raja
6899c7792e when all generations fail, print the all the underlying exceptions for debugging 2024-09-20 13:56:20 +02:00
Abi Raja
9199fee21d keep generating even if one of the models fails 2024-09-17 15:56:35 +02:00
Abi Raja
8717298def switch to flux schnell for replicate image gen 2024-09-17 12:20:19 +02:00
Abi Raja
a4087e613f make variant 1 claude 2024-09-16 15:21:57 +02:00
Abi Raja
cf6b94d675 update favicon/branding 2024-09-14 17:44:57 +02:00
Abi Raja
589507846b remove model selection dropdown since that happens on the background now 2024-09-13 14:41:00 +02:00
10 changed files with 64 additions and 128 deletions

View File

@ -7,19 +7,19 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: local
hooks:
- id: poetry-pytest
name: Run pytest with Poetry
entry: poetry run --directory backend pytest
language: system
pass_filenames: false
always_run: true
files: ^backend/
# - id: poetry-pyright
# name: Run pyright with Poetry
# entry: poetry run --directory backend pyright
# language: system
# pass_filenames: false
# always_run: true
# files: ^backend/
# - repo: local
# hooks:
# - id: poetry-pytest
# name: Run pytest with Poetry
# entry: poetry run --directory backend pytest
# language: system
# pass_filenames: false
# always_run: true
# files: ^backend/
# # - id: poetry-pyright
# # name: Run pyright with Poetry
# # entry: poetry run --directory backend pyright
# # language: system
# # pass_filenames: false
# # always_run: true
# # files: ^backend/

View File

@ -11,7 +11,7 @@ async def process_tasks(
prompts: List[str],
api_key: str,
base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"],
model: Literal["dalle3", "flux"],
):
import time
@ -54,18 +54,14 @@ async def generate_image_dalle(
async def generate_image_replicate(prompt: str, api_key: str) -> str:
# We use SDXL Lightning
# We use Flux Schnell
return await call_replicate(
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
{
"width": 1024,
"height": 1024,
"prompt": prompt,
"scheduler": "K_EULER",
"num_outputs": 1,
"guidance_scale": 0,
"negative_prompt": "worst quality, low quality",
"num_inference_steps": 4,
"aspect_ratio": "1:1",
"output_format": "png",
"output_quality": 100,
},
api_key,
)
@ -102,7 +98,7 @@ async def generate_images(
api_key: str,
base_url: Union[str, None],
image_cache: Dict[str, str],
model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
model: Literal["dalle3", "flux"] = "dalle3",
) -> str:
# Find all images
soup = BeautifulSoup(code, "html.parser")

View File

@ -2,20 +2,21 @@ import asyncio
import httpx
async def call_replicate(
replicate_model_version: str, input: dict[str, str | int], api_token: str
) -> str:
url = "https://api.replicate.com/v1/predictions"
async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
}
data = {"version": replicate_model_version, "input": input}
data = {"input": input}
async with httpx.AsyncClient() as client:
try:
response = await client.post(url, headers=headers, json=data)
response = await client.post(
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
headers=headers,
json=data,
)
response.raise_for_status()
response_json = response.json()
@ -24,16 +25,18 @@ async def call_replicate(
if not prediction_id:
raise ValueError("Prediction ID not found in initial response.")
# Polling every 1 second until the status is succeeded or error
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
num_polls = 0
max_polls = 100
while num_polls < max_polls:
num_polls += 1
await asyncio.sleep(0.2)
await asyncio.sleep(0.1)
# Check the status
status_check_url = f"{url}/{prediction_id}"
status_check_url = (
f"https://api.replicate.com/v1/predictions/{prediction_id}"
)
status_response = await client.get(status_check_url, headers=headers)
status_response.raise_for_status()
status_response_json = status_response.json()

View File

@ -1,5 +1,6 @@
import asyncio
from dataclasses import dataclass
import traceback
from fastapi import APIRouter, WebSocket
import openai
from codegen.utils import extract_html_content
@ -63,7 +64,7 @@ async def perform_image_generation(
return completion
if replicate_api_key:
image_generation_model = "sdxl-lightning"
image_generation_model = "flux"
api_key = replicate_api_key
else:
if not openai_api_key:
@ -271,7 +272,7 @@ async def stream_code(websocket: WebSocket):
# we decide which models to run
variant_models = []
if openai_api_key and anthropic_api_key:
variant_models = ["openai", "anthropic"]
variant_models = ["anthropic", "openai"]
elif openai_api_key:
variant_models = ["openai", "openai"]
elif anthropic_api_key:
@ -312,7 +313,29 @@ async def stream_code(websocket: WebSocket):
)
)
completions = await asyncio.gather(*tasks)
# Run the models in parallel and capture exceptions if any
completions = await asyncio.gather(*tasks, return_exceptions=True)
# If all generations failed, throw an error
all_generations_failed = all(
isinstance(completion, Exception) for completion in completions
)
if all_generations_failed:
await throw_error("Error generating code. Please contact support.")
# Print the all the underlying exceptions for debugging
for completion in completions:
traceback.print_exception(
type(completion), completion, completion.__traceback__
)
raise Exception("All generations failed")
# If some completions failed, replace them with empty strings
for index, completion in enumerate(completions):
if isinstance(completion, Exception):
completions[index] = ""
print("Generation failed for variant", index)
print("Models used for generation: ", variant_models)
except openai.AuthenticationError as e:

View File

@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
async def generate_and_save_images(
prompts: List[str],
model: Literal["dalle3", "sdxl-lightning"],
model: Literal["dalle3", "flux"],
api_key: Optional[str],
) -> None:
# Ensure the output directory exists
@ -64,7 +64,7 @@ async def generate_and_save_images(
image_data: bytes = await response.read()
# Save the image with a filename based on the input eval
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
prefix = "replicate_" if model == "flux" else "dalle3_"
filename: str = (
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
)
@ -78,7 +78,7 @@ async def generate_and_save_images(
async def main() -> None:
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
if __name__ == "__main__":

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 KiB

After

Width:  |  Height:  |  Size: 16 KiB

View File

@ -358,11 +358,7 @@ function App() {
</div>
{/* Generation settings like stack and model */}
<GenerationSettings
settings={settings}
setSettings={setSettings}
selectedCodeGenerationModel={model}
/>
<GenerationSettings settings={settings} setSettings={setSettings} />
{/* Show auto updated message when older models are choosen */}
{showBetterModelMessage && <DeprecationMessage />}

View File

@ -2,20 +2,16 @@ import React from "react";
import { useAppStore } from "../../store/app-store";
import { AppState, Settings } from "../../types";
import OutputSettingsSection from "./OutputSettingsSection";
import ModelSettingsSection from "./ModelSettingsSection";
import { Stack } from "../../lib/stacks";
import { CodeGenerationModel } from "../../lib/models";
interface GenerationSettingsProps {
settings: Settings;
setSettings: React.Dispatch<React.SetStateAction<Settings>>;
selectedCodeGenerationModel: CodeGenerationModel;
}
export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
settings,
setSettings,
selectedCodeGenerationModel,
}) => {
const { appState } = useAppStore();
@ -26,13 +22,6 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
}));
}
function setCodeGenerationModel(codeGenerationModel: CodeGenerationModel) {
setSettings((prev: Settings) => ({
...prev,
codeGenerationModel,
}));
}
const shouldDisableUpdates =
appState === AppState.CODING || appState === AppState.CODE_READY;
@ -43,12 +32,6 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
setStack={setStack}
shouldDisableUpdates={shouldDisableUpdates}
/>
<ModelSettingsSection
codeGenerationModel={selectedCodeGenerationModel}
setCodeGenerationModel={setCodeGenerationModel}
shouldDisableUpdates={shouldDisableUpdates}
/>
</div>
);
};

View File

@ -1,65 +0,0 @@
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
} from "../ui/select";
import {
CODE_GENERATION_MODEL_DESCRIPTIONS,
CodeGenerationModel,
} from "../../lib/models";
import { Badge } from "../ui/badge";
interface Props {
codeGenerationModel: CodeGenerationModel;
setCodeGenerationModel: (codeGenerationModel: CodeGenerationModel) => void;
shouldDisableUpdates?: boolean;
}
function ModelSettingsSection({
codeGenerationModel,
setCodeGenerationModel,
shouldDisableUpdates = false,
}: Props) {
return (
<div className="flex flex-col gap-y-2 justify-between text-sm">
<div className="grid grid-cols-3 items-center gap-4">
<span>AI Model:</span>
<Select
value={codeGenerationModel}
onValueChange={(value: string) =>
setCodeGenerationModel(value as CodeGenerationModel)
}
disabled={shouldDisableUpdates}
>
<SelectTrigger className="col-span-2" id="output-settings-js">
<span className="font-semibold">
{CODE_GENERATION_MODEL_DESCRIPTIONS[codeGenerationModel].name}
</span>
</SelectTrigger>
<SelectContent>
<SelectGroup>
{Object.values(CodeGenerationModel).map((model) => (
<SelectItem key={model} value={model}>
<div className="flex items-center">
<span className="font-semibold">
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].name}
</span>
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].inBeta && (
<Badge className="ml-2" variant="secondary">
Beta
</Badge>
)}
</div>
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</div>
</div>
);
}
export default ModelSettingsSection;