Compare commits
No commits in common. "main" and "multiple-generations" have entirely different histories.
main
...
multiple-g
@ -33,6 +33,10 @@ We also just added experimental support for taking a video/screen recording of a
|
|||||||
|
|
||||||
[Follow me on Twitter for updates](https://twitter.com/_abi_).
|
[Follow me on Twitter for updates](https://twitter.com/_abi_).
|
||||||
|
|
||||||
|
## Sponsors
|
||||||
|
|
||||||
|
<a href="https://konghq.com/products/kong-konnect?utm_medium=referral&utm_source=github&utm_campaign=platform&utm_content=screenshot-to-code" target="_blank" title="Kong - powering the API world"><img src="https://picoapps.xyz/s2c-sponsors/Kong-GitHub-240x100.png"></a>
|
||||||
|
|
||||||
## 🚀 Hosted Version
|
## 🚀 Hosted Version
|
||||||
|
|
||||||
[Try it live on the hosted version (paid)](https://screenshottocode.com).
|
[Try it live on the hosted version (paid)](https://screenshottocode.com).
|
||||||
|
|||||||
@ -7,19 +7,19 @@ repos:
|
|||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
# - repo: local
|
- repo: local
|
||||||
# hooks:
|
hooks:
|
||||||
# - id: poetry-pytest
|
- id: poetry-pytest
|
||||||
# name: Run pytest with Poetry
|
name: Run pytest with Poetry
|
||||||
# entry: poetry run --directory backend pytest
|
entry: poetry run --directory backend pytest
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
always_run: true
|
||||||
|
files: ^backend/
|
||||||
|
# - id: poetry-pyright
|
||||||
|
# name: Run pyright with Poetry
|
||||||
|
# entry: poetry run --directory backend pyright
|
||||||
# language: system
|
# language: system
|
||||||
# pass_filenames: false
|
# pass_filenames: false
|
||||||
# always_run: true
|
# always_run: true
|
||||||
# files: ^backend/
|
# files: ^backend/
|
||||||
# # - id: poetry-pyright
|
|
||||||
# # name: Run pyright with Poetry
|
|
||||||
# # entry: poetry run --directory backend pyright
|
|
||||||
# # language: system
|
|
||||||
# # pass_filenames: false
|
|
||||||
# # always_run: true
|
|
||||||
# # files: ^backend/
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ async def process_tasks(
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str | None,
|
base_url: str | None,
|
||||||
model: Literal["dalle3", "flux"],
|
model: Literal["dalle3", "sdxl-lightning"],
|
||||||
):
|
):
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@ -54,14 +54,18 @@ async def generate_image_dalle(
|
|||||||
|
|
||||||
async def generate_image_replicate(prompt: str, api_key: str) -> str:
|
async def generate_image_replicate(prompt: str, api_key: str) -> str:
|
||||||
|
|
||||||
# We use Flux Schnell
|
# We use SDXL Lightning
|
||||||
return await call_replicate(
|
return await call_replicate(
|
||||||
|
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
|
||||||
{
|
{
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
"scheduler": "K_EULER",
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"aspect_ratio": "1:1",
|
"guidance_scale": 0,
|
||||||
"output_format": "png",
|
"negative_prompt": "worst quality, low quality",
|
||||||
"output_quality": 100,
|
"num_inference_steps": 4,
|
||||||
},
|
},
|
||||||
api_key,
|
api_key,
|
||||||
)
|
)
|
||||||
@ -98,7 +102,7 @@ async def generate_images(
|
|||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: Union[str, None],
|
base_url: Union[str, None],
|
||||||
image_cache: Dict[str, str],
|
image_cache: Dict[str, str],
|
||||||
model: Literal["dalle3", "flux"] = "dalle3",
|
model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
|
||||||
) -> str:
|
) -> str:
|
||||||
# Find all images
|
# Find all images
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
|
|||||||
@ -2,21 +2,20 @@ import asyncio
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
|
async def call_replicate(
|
||||||
|
replicate_model_version: str, input: dict[str, str | int], api_token: str
|
||||||
|
) -> str:
|
||||||
|
url = "https://api.replicate.com/v1/predictions"
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {api_token}",
|
"Authorization": f"Bearer {api_token}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
data = {"input": input}
|
data = {"version": replicate_model_version, "input": input}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await client.post(url, headers=headers, json=data)
|
||||||
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
|
|
||||||
headers=headers,
|
|
||||||
json=data,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
@ -25,18 +24,16 @@ async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
|
|||||||
if not prediction_id:
|
if not prediction_id:
|
||||||
raise ValueError("Prediction ID not found in initial response.")
|
raise ValueError("Prediction ID not found in initial response.")
|
||||||
|
|
||||||
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s)
|
# Polling every 1 second until the status is succeeded or error
|
||||||
num_polls = 0
|
num_polls = 0
|
||||||
max_polls = 100
|
max_polls = 100
|
||||||
while num_polls < max_polls:
|
while num_polls < max_polls:
|
||||||
num_polls += 1
|
num_polls += 1
|
||||||
|
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
# Check the status
|
# Check the status
|
||||||
status_check_url = (
|
status_check_url = f"{url}/{prediction_id}"
|
||||||
f"https://api.replicate.com/v1/predictions/{prediction_id}"
|
|
||||||
)
|
|
||||||
status_response = await client.get(status_check_url, headers=headers)
|
status_response = await client.get(status_check_url, headers=headers)
|
||||||
status_response.raise_for_status()
|
status_response.raise_for_status()
|
||||||
status_response_json = status_response.json()
|
status_response_json = status_response.json()
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import traceback
|
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
from codegen.utils import extract_html_content
|
from codegen.utils import extract_html_content
|
||||||
@ -64,7 +63,7 @@ async def perform_image_generation(
|
|||||||
return completion
|
return completion
|
||||||
|
|
||||||
if replicate_api_key:
|
if replicate_api_key:
|
||||||
image_generation_model = "flux"
|
image_generation_model = "sdxl-lightning"
|
||||||
api_key = replicate_api_key
|
api_key = replicate_api_key
|
||||||
else:
|
else:
|
||||||
if not openai_api_key:
|
if not openai_api_key:
|
||||||
@ -272,7 +271,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
# we decide which models to run
|
# we decide which models to run
|
||||||
variant_models = []
|
variant_models = []
|
||||||
if openai_api_key and anthropic_api_key:
|
if openai_api_key and anthropic_api_key:
|
||||||
variant_models = ["anthropic", "openai"]
|
variant_models = ["openai", "anthropic"]
|
||||||
elif openai_api_key:
|
elif openai_api_key:
|
||||||
variant_models = ["openai", "openai"]
|
variant_models = ["openai", "openai"]
|
||||||
elif anthropic_api_key:
|
elif anthropic_api_key:
|
||||||
@ -313,29 +312,7 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the models in parallel and capture exceptions if any
|
completions = await asyncio.gather(*tasks)
|
||||||
completions = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
# If all generations failed, throw an error
|
|
||||||
all_generations_failed = all(
|
|
||||||
isinstance(completion, Exception) for completion in completions
|
|
||||||
)
|
|
||||||
if all_generations_failed:
|
|
||||||
await throw_error("Error generating code. Please contact support.")
|
|
||||||
|
|
||||||
# Print the all the underlying exceptions for debugging
|
|
||||||
for completion in completions:
|
|
||||||
traceback.print_exception(
|
|
||||||
type(completion), completion, completion.__traceback__
|
|
||||||
)
|
|
||||||
raise Exception("All generations failed")
|
|
||||||
|
|
||||||
# If some completions failed, replace them with empty strings
|
|
||||||
for index, completion in enumerate(completions):
|
|
||||||
if isinstance(completion, Exception):
|
|
||||||
completions[index] = ""
|
|
||||||
print("Generation failed for variant", index)
|
|
||||||
|
|
||||||
print("Models used for generation: ", variant_models)
|
print("Models used for generation: ", variant_models)
|
||||||
|
|
||||||
except openai.AuthenticationError as e:
|
except openai.AuthenticationError as e:
|
||||||
|
|||||||
@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
|
|||||||
|
|
||||||
async def generate_and_save_images(
|
async def generate_and_save_images(
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
model: Literal["dalle3", "flux"],
|
model: Literal["dalle3", "sdxl-lightning"],
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
# Ensure the output directory exists
|
# Ensure the output directory exists
|
||||||
@ -64,7 +64,7 @@ async def generate_and_save_images(
|
|||||||
image_data: bytes = await response.read()
|
image_data: bytes = await response.read()
|
||||||
|
|
||||||
# Save the image with a filename based on the input eval
|
# Save the image with a filename based on the input eval
|
||||||
prefix = "replicate_" if model == "flux" else "dalle3_"
|
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
|
||||||
filename: str = (
|
filename: str = (
|
||||||
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
|
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
|
||||||
)
|
)
|
||||||
@ -78,7 +78,7 @@ async def generate_and_save_images(
|
|||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
|
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
|
||||||
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
|
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 2.2 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 1.6 KiB |
@ -358,7 +358,11 @@ function App() {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Generation settings like stack and model */}
|
{/* Generation settings like stack and model */}
|
||||||
<GenerationSettings settings={settings} setSettings={setSettings} />
|
<GenerationSettings
|
||||||
|
settings={settings}
|
||||||
|
setSettings={setSettings}
|
||||||
|
selectedCodeGenerationModel={model}
|
||||||
|
/>
|
||||||
|
|
||||||
{/* Show auto updated message when older models are choosen */}
|
{/* Show auto updated message when older models are choosen */}
|
||||||
{showBetterModelMessage && <DeprecationMessage />}
|
{showBetterModelMessage && <DeprecationMessage />}
|
||||||
|
|||||||
@ -2,16 +2,20 @@ import React from "react";
|
|||||||
import { useAppStore } from "../../store/app-store";
|
import { useAppStore } from "../../store/app-store";
|
||||||
import { AppState, Settings } from "../../types";
|
import { AppState, Settings } from "../../types";
|
||||||
import OutputSettingsSection from "./OutputSettingsSection";
|
import OutputSettingsSection from "./OutputSettingsSection";
|
||||||
|
import ModelSettingsSection from "./ModelSettingsSection";
|
||||||
import { Stack } from "../../lib/stacks";
|
import { Stack } from "../../lib/stacks";
|
||||||
|
import { CodeGenerationModel } from "../../lib/models";
|
||||||
|
|
||||||
interface GenerationSettingsProps {
|
interface GenerationSettingsProps {
|
||||||
settings: Settings;
|
settings: Settings;
|
||||||
setSettings: React.Dispatch<React.SetStateAction<Settings>>;
|
setSettings: React.Dispatch<React.SetStateAction<Settings>>;
|
||||||
|
selectedCodeGenerationModel: CodeGenerationModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
||||||
settings,
|
settings,
|
||||||
setSettings,
|
setSettings,
|
||||||
|
selectedCodeGenerationModel,
|
||||||
}) => {
|
}) => {
|
||||||
const { appState } = useAppStore();
|
const { appState } = useAppStore();
|
||||||
|
|
||||||
@ -22,6 +26,13 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function setCodeGenerationModel(codeGenerationModel: CodeGenerationModel) {
|
||||||
|
setSettings((prev: Settings) => ({
|
||||||
|
...prev,
|
||||||
|
codeGenerationModel,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
const shouldDisableUpdates =
|
const shouldDisableUpdates =
|
||||||
appState === AppState.CODING || appState === AppState.CODE_READY;
|
appState === AppState.CODING || appState === AppState.CODE_READY;
|
||||||
|
|
||||||
@ -32,6 +43,12 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
|
|||||||
setStack={setStack}
|
setStack={setStack}
|
||||||
shouldDisableUpdates={shouldDisableUpdates}
|
shouldDisableUpdates={shouldDisableUpdates}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<ModelSettingsSection
|
||||||
|
codeGenerationModel={selectedCodeGenerationModel}
|
||||||
|
setCodeGenerationModel={setCodeGenerationModel}
|
||||||
|
shouldDisableUpdates={shouldDisableUpdates}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
65
frontend/src/components/settings/ModelSettingsSection.tsx
Normal file
65
frontend/src/components/settings/ModelSettingsSection.tsx
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectGroup,
|
||||||
|
SelectItem,
|
||||||
|
SelectTrigger,
|
||||||
|
} from "../ui/select";
|
||||||
|
import {
|
||||||
|
CODE_GENERATION_MODEL_DESCRIPTIONS,
|
||||||
|
CodeGenerationModel,
|
||||||
|
} from "../../lib/models";
|
||||||
|
import { Badge } from "../ui/badge";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
codeGenerationModel: CodeGenerationModel;
|
||||||
|
setCodeGenerationModel: (codeGenerationModel: CodeGenerationModel) => void;
|
||||||
|
shouldDisableUpdates?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
function ModelSettingsSection({
|
||||||
|
codeGenerationModel,
|
||||||
|
setCodeGenerationModel,
|
||||||
|
shouldDisableUpdates = false,
|
||||||
|
}: Props) {
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-y-2 justify-between text-sm">
|
||||||
|
<div className="grid grid-cols-3 items-center gap-4">
|
||||||
|
<span>AI Model:</span>
|
||||||
|
<Select
|
||||||
|
value={codeGenerationModel}
|
||||||
|
onValueChange={(value: string) =>
|
||||||
|
setCodeGenerationModel(value as CodeGenerationModel)
|
||||||
|
}
|
||||||
|
disabled={shouldDisableUpdates}
|
||||||
|
>
|
||||||
|
<SelectTrigger className="col-span-2" id="output-settings-js">
|
||||||
|
<span className="font-semibold">
|
||||||
|
{CODE_GENERATION_MODEL_DESCRIPTIONS[codeGenerationModel].name}
|
||||||
|
</span>
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
<SelectGroup>
|
||||||
|
{Object.values(CodeGenerationModel).map((model) => (
|
||||||
|
<SelectItem key={model} value={model}>
|
||||||
|
<div className="flex items-center">
|
||||||
|
<span className="font-semibold">
|
||||||
|
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].name}
|
||||||
|
</span>
|
||||||
|
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].inBeta && (
|
||||||
|
<Badge className="ml-2" variant="secondary">
|
||||||
|
Beta
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ModelSettingsSection;
|
||||||
Loading…
Reference in New Issue
Block a user