Compare commits
7 Commits
hosted-mul
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a53afd350d | ||
|
|
6899c7792e | ||
|
|
9199fee21d | ||
|
|
8717298def | ||
|
|
a4087e613f | ||
|
|
cf6b94d675 | ||
|
|
589507846b |
@ -7,19 +7,19 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: poetry-pytest
|
||||
name: Run pytest with Poetry
|
||||
entry: poetry run --directory backend pytest
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
files: ^backend/
|
||||
# - id: poetry-pyright
|
||||
# name: Run pyright with Poetry
|
||||
# entry: poetry run --directory backend pyright
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
# always_run: true
|
||||
# files: ^backend/
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: poetry-pytest
|
||||
# name: Run pytest with Poetry
|
||||
# entry: poetry run --directory backend pytest
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
# always_run: true
|
||||
# files: ^backend/
|
||||
# # - id: poetry-pyright
|
||||
# # name: Run pyright with Poetry
|
||||
# # entry: poetry run --directory backend pyright
|
||||
# # language: system
|
||||
# # pass_filenames: false
|
||||
# # always_run: true
|
||||
# # files: ^backend/
|
||||
|
||||
@ -11,7 +11,7 @@ async def process_tasks(
|
||||
prompts: List[str],
|
||||
api_key: str,
|
||||
base_url: str | None,
|
||||
model: Literal["dalle3", "sdxl-lightning"],
|
||||
model: Literal["dalle3", "flux"],
|
||||
):
|
||||
import time
|
||||
|
||||
@ -54,18 +54,14 @@ async def generate_image_dalle(
|
||||
|
||||
async def generate_image_replicate(prompt: str, api_key: str) -> str:
|
||||
|
||||
# We use SDXL Lightning
|
||||
# We use Flux Schnell
|
||||
return await call_replicate(
|
||||
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
|
||||
{
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"prompt": prompt,
|
||||
"scheduler": "K_EULER",
|
||||
"num_outputs": 1,
|
||||
"guidance_scale": 0,
|
||||
"negative_prompt": "worst quality, low quality",
|
||||
"num_inference_steps": 4,
|
||||
"aspect_ratio": "1:1",
|
||||
"output_format": "png",
|
||||
"output_quality": 100,
|
||||
},
|
||||
api_key,
|
||||
)
|
||||
@ -102,7 +98,7 @@ async def generate_images(
|
||||
api_key: str,
|
||||
base_url: Union[str, None],
|
||||
image_cache: Dict[str, str],
|
||||
model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
|
||||
model: Literal["dalle3", "flux"] = "dalle3",
|
||||
) -> str:
|
||||
# Find all images
|
||||
soup = BeautifulSoup(code, "html.parser")
|
||||
|
||||
@ -2,20 +2,21 @@ import asyncio
|
||||
import httpx
|
||||
|
||||
|
||||
async def call_replicate(
|
||||
replicate_model_version: str, input: dict[str, str | int], api_token: str
|
||||
) -> str:
|
||||
url = "https://api.replicate.com/v1/predictions"
|
||||
async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {"version": replicate_model_version, "input": input}
|
||||
data = {"input": input}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response = await client.post(
|
||||
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
|
||||
headers=headers,
|
||||
json=data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
@ -24,16 +25,18 @@ async def call_replicate(
|
||||
if not prediction_id:
|
||||
raise ValueError("Prediction ID not found in initial response.")
|
||||
|
||||
# Polling every 1 second until the status is succeeded or error
|
||||
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
|
||||
num_polls = 0
|
||||
max_polls = 100
|
||||
while num_polls < max_polls:
|
||||
num_polls += 1
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check the status
|
||||
status_check_url = f"{url}/{prediction_id}"
|
||||
status_check_url = (
|
||||
f"https://api.replicate.com/v1/predictions/{prediction_id}"
|
||||
)
|
||||
status_response = await client.get(status_check_url, headers=headers)
|
||||
status_response.raise_for_status()
|
||||
status_response_json = status_response.json()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import traceback
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import openai
|
||||
from codegen.utils import extract_html_content
|
||||
@ -63,7 +64,7 @@ async def perform_image_generation(
|
||||
return completion
|
||||
|
||||
if replicate_api_key:
|
||||
image_generation_model = "sdxl-lightning"
|
||||
image_generation_model = "flux"
|
||||
api_key = replicate_api_key
|
||||
else:
|
||||
if not openai_api_key:
|
||||
@ -271,7 +272,7 @@ async def stream_code(websocket: WebSocket):
|
||||
# we decide which models to run
|
||||
variant_models = []
|
||||
if openai_api_key and anthropic_api_key:
|
||||
variant_models = ["openai", "anthropic"]
|
||||
variant_models = ["anthropic", "openai"]
|
||||
elif openai_api_key:
|
||||
variant_models = ["openai", "openai"]
|
||||
elif anthropic_api_key:
|
||||
@ -312,7 +313,29 @@ async def stream_code(websocket: WebSocket):
|
||||
)
|
||||
)
|
||||
|
||||
completions = await asyncio.gather(*tasks)
|
||||
# Run the models in parallel and capture exceptions if any
|
||||
completions = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# If all generations failed, throw an error
|
||||
all_generations_failed = all(
|
||||
isinstance(completion, Exception) for completion in completions
|
||||
)
|
||||
if all_generations_failed:
|
||||
await throw_error("Error generating code. Please contact support.")
|
||||
|
||||
# Print the all the underlying exceptions for debugging
|
||||
for completion in completions:
|
||||
traceback.print_exception(
|
||||
type(completion), completion, completion.__traceback__
|
||||
)
|
||||
raise Exception("All generations failed")
|
||||
|
||||
# If some completions failed, replace them with empty strings
|
||||
for index, completion in enumerate(completions):
|
||||
if isinstance(completion, Exception):
|
||||
completions[index] = ""
|
||||
print("Generation failed for variant", index)
|
||||
|
||||
print("Models used for generation: ", variant_models)
|
||||
|
||||
except openai.AuthenticationError as e:
|
||||
|
||||
@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
|
||||
|
||||
async def generate_and_save_images(
|
||||
prompts: List[str],
|
||||
model: Literal["dalle3", "sdxl-lightning"],
|
||||
model: Literal["dalle3", "flux"],
|
||||
api_key: Optional[str],
|
||||
) -> None:
|
||||
# Ensure the output directory exists
|
||||
@ -64,7 +64,7 @@ async def generate_and_save_images(
|
||||
image_data: bytes = await response.read()
|
||||
|
||||
# Save the image with a filename based on the input eval
|
||||
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
|
||||
prefix = "replicate_" if model == "flux" else "dalle3_"
|
||||
filename: str = (
|
||||
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
|
||||
)
|
||||
@ -78,7 +78,7 @@ async def generate_and_save_images(
|
||||
|
||||
async def main() -> None:
|
||||
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
|
||||
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
|
||||
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 16 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 1.6 KiB After Width: | Height: | Size: 16 KiB |
@ -358,11 +358,7 @@ function App() {
|
||||
</div>
|
||||
|
||||
{/* Generation settings like stack and model */}
|
||||
<GenerationSettings
|
||||
settings={settings}
|
||||
setSettings={setSettings}
|
||||
selectedCodeGenerationModel={model}
|
||||
/>
|
||||
<GenerationSettings settings={settings} setSettings={setSettings} />
|
||||
|
||||
{/* Show auto updated message when older models are choosen */}
|
||||
{showBetterModelMessage && <DeprecationMessage />}
|
||||
|
||||
@ -2,20 +2,16 @@ import React from "react";
|
||||
import { useAppStore } from "../../store/app-store";
|
||||
import { AppState, Settings } from "../../types";
|
||||
import OutputSettingsSection from "./OutputSettingsSection";
|
||||
import ModelSettingsSection from "./ModelSettingsSection";
|
||||
import { Stack } from "../../lib/stacks";
|
||||
import { CodeGenerationModel } from "../../lib/models";
|
||||
|
||||
interface GenerationSettingsProps {
|
||||
settings: Settings;
|
||||
setSettings: React.Dispatch<React.SetStateAction<Settings>>;
|
||||
selectedCodeGenerationModel: CodeGenerationModel;
|
||||
}
|
||||
|
||||
export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
||||
settings,
|
||||
setSettings,
|
||||
selectedCodeGenerationModel,
|
||||
}) => {
|
||||
const { appState } = useAppStore();
|
||||
|
||||
@ -26,13 +22,6 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
||||
}));
|
||||
}
|
||||
|
||||
function setCodeGenerationModel(codeGenerationModel: CodeGenerationModel) {
|
||||
setSettings((prev: Settings) => ({
|
||||
...prev,
|
||||
codeGenerationModel,
|
||||
}));
|
||||
}
|
||||
|
||||
const shouldDisableUpdates =
|
||||
appState === AppState.CODING || appState === AppState.CODE_READY;
|
||||
|
||||
@ -43,12 +32,6 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
||||
setStack={setStack}
|
||||
shouldDisableUpdates={shouldDisableUpdates}
|
||||
/>
|
||||
|
||||
<ModelSettingsSection
|
||||
codeGenerationModel={selectedCodeGenerationModel}
|
||||
setCodeGenerationModel={setCodeGenerationModel}
|
||||
shouldDisableUpdates={shouldDisableUpdates}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@ -1,65 +0,0 @@
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectGroup,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
} from "../ui/select";
|
||||
import {
|
||||
CODE_GENERATION_MODEL_DESCRIPTIONS,
|
||||
CodeGenerationModel,
|
||||
} from "../../lib/models";
|
||||
import { Badge } from "../ui/badge";
|
||||
|
||||
interface Props {
|
||||
codeGenerationModel: CodeGenerationModel;
|
||||
setCodeGenerationModel: (codeGenerationModel: CodeGenerationModel) => void;
|
||||
shouldDisableUpdates?: boolean;
|
||||
}
|
||||
|
||||
function ModelSettingsSection({
|
||||
codeGenerationModel,
|
||||
setCodeGenerationModel,
|
||||
shouldDisableUpdates = false,
|
||||
}: Props) {
|
||||
return (
|
||||
<div className="flex flex-col gap-y-2 justify-between text-sm">
|
||||
<div className="grid grid-cols-3 items-center gap-4">
|
||||
<span>AI Model:</span>
|
||||
<Select
|
||||
value={codeGenerationModel}
|
||||
onValueChange={(value: string) =>
|
||||
setCodeGenerationModel(value as CodeGenerationModel)
|
||||
}
|
||||
disabled={shouldDisableUpdates}
|
||||
>
|
||||
<SelectTrigger className="col-span-2" id="output-settings-js">
|
||||
<span className="font-semibold">
|
||||
{CODE_GENERATION_MODEL_DESCRIPTIONS[codeGenerationModel].name}
|
||||
</span>
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
{Object.values(CodeGenerationModel).map((model) => (
|
||||
<SelectItem key={model} value={model}>
|
||||
<div className="flex items-center">
|
||||
<span className="font-semibold">
|
||||
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].name}
|
||||
</span>
|
||||
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].inBeta && (
|
||||
<Badge className="ml-2" variant="secondary">
|
||||
Beta
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default ModelSettingsSection;
|
||||
Loading…
Reference in New Issue
Block a user