Compare commits

..

No commits in common. "main" and "multiple-generations" have entirely different histories.

11 changed files with 132 additions and 64 deletions

View File

@ -33,6 +33,10 @@ We also just added experimental support for taking a video/screen recording of a
[Follow me on Twitter for updates](https://twitter.com/_abi_). [Follow me on Twitter for updates](https://twitter.com/_abi_).
## Sponsors
<a href="https://konghq.com/products/kong-konnect?utm_medium=referral&utm_source=github&utm_campaign=platform&utm_content=screenshot-to-code" target="_blank" title="Kong - powering the API world"><img src="https://picoapps.xyz/s2c-sponsors/Kong-GitHub-240x100.png"></a>
## 🚀 Hosted Version ## 🚀 Hosted Version
[Try it live on the hosted version (paid)](https://screenshottocode.com). [Try it live on the hosted version (paid)](https://screenshottocode.com).

View File

@ -7,19 +7,19 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: check-yaml - id: check-yaml
- id: check-added-large-files - id: check-added-large-files
# - repo: local - repo: local
# hooks: hooks:
# - id: poetry-pytest - id: poetry-pytest
# name: Run pytest with Poetry name: Run pytest with Poetry
# entry: poetry run --directory backend pytest entry: poetry run --directory backend pytest
language: system
pass_filenames: false
always_run: true
files: ^backend/
# - id: poetry-pyright
# name: Run pyright with Poetry
# entry: poetry run --directory backend pyright
# language: system # language: system
# pass_filenames: false # pass_filenames: false
# always_run: true # always_run: true
# files: ^backend/ # files: ^backend/
# # - id: poetry-pyright
# # name: Run pyright with Poetry
# # entry: poetry run --directory backend pyright
# # language: system
# # pass_filenames: false
# # always_run: true
# # files: ^backend/

View File

@ -11,7 +11,7 @@ async def process_tasks(
prompts: List[str], prompts: List[str],
api_key: str, api_key: str,
base_url: str | None, base_url: str | None,
model: Literal["dalle3", "flux"], model: Literal["dalle3", "sdxl-lightning"],
): ):
import time import time
@ -54,14 +54,18 @@ async def generate_image_dalle(
async def generate_image_replicate(prompt: str, api_key: str) -> str: async def generate_image_replicate(prompt: str, api_key: str) -> str:
# We use Flux Schnell # We use SDXL Lightning
return await call_replicate( return await call_replicate(
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
{ {
"width": 1024,
"height": 1024,
"prompt": prompt, "prompt": prompt,
"scheduler": "K_EULER",
"num_outputs": 1, "num_outputs": 1,
"aspect_ratio": "1:1", "guidance_scale": 0,
"output_format": "png", "negative_prompt": "worst quality, low quality",
"output_quality": 100, "num_inference_steps": 4,
}, },
api_key, api_key,
) )
@ -98,7 +102,7 @@ async def generate_images(
api_key: str, api_key: str,
base_url: Union[str, None], base_url: Union[str, None],
image_cache: Dict[str, str], image_cache: Dict[str, str],
model: Literal["dalle3", "flux"] = "dalle3", model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
) -> str: ) -> str:
# Find all images # Find all images
soup = BeautifulSoup(code, "html.parser") soup = BeautifulSoup(code, "html.parser")

View File

@ -2,21 +2,20 @@ import asyncio
import httpx import httpx
async def call_replicate(input: dict[str, str | int], api_token: str) -> str: async def call_replicate(
replicate_model_version: str, input: dict[str, str | int], api_token: str
) -> str:
url = "https://api.replicate.com/v1/predictions"
headers = { headers = {
"Authorization": f"Bearer {api_token}", "Authorization": f"Bearer {api_token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
data = {"input": input} data = {"version": replicate_model_version, "input": input}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
response = await client.post( response = await client.post(url, headers=headers, json=data)
"https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions",
headers=headers,
json=data,
)
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = response.json()
@ -25,18 +24,16 @@ async def call_replicate(input: dict[str, str | int], api_token: str) -> str:
if not prediction_id: if not prediction_id:
raise ValueError("Prediction ID not found in initial response.") raise ValueError("Prediction ID not found in initial response.")
# Polling every 0.1 seconds until the status is succeeded or error (upto 10s) # Polling every 1 second until the status is succeeded or error
num_polls = 0 num_polls = 0
max_polls = 100 max_polls = 100
while num_polls < max_polls: while num_polls < max_polls:
num_polls += 1 num_polls += 1
await asyncio.sleep(0.1) await asyncio.sleep(0.2)
# Check the status # Check the status
status_check_url = ( status_check_url = f"{url}/{prediction_id}"
f"https://api.replicate.com/v1/predictions/{prediction_id}"
)
status_response = await client.get(status_check_url, headers=headers) status_response = await client.get(status_check_url, headers=headers)
status_response.raise_for_status() status_response.raise_for_status()
status_response_json = status_response.json() status_response_json = status_response.json()

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import traceback
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
import openai import openai
from codegen.utils import extract_html_content from codegen.utils import extract_html_content
@ -64,7 +63,7 @@ async def perform_image_generation(
return completion return completion
if replicate_api_key: if replicate_api_key:
image_generation_model = "flux" image_generation_model = "sdxl-lightning"
api_key = replicate_api_key api_key = replicate_api_key
else: else:
if not openai_api_key: if not openai_api_key:
@ -272,7 +271,7 @@ async def stream_code(websocket: WebSocket):
# we decide which models to run # we decide which models to run
variant_models = [] variant_models = []
if openai_api_key and anthropic_api_key: if openai_api_key and anthropic_api_key:
variant_models = ["anthropic", "openai"] variant_models = ["openai", "anthropic"]
elif openai_api_key: elif openai_api_key:
variant_models = ["openai", "openai"] variant_models = ["openai", "openai"]
elif anthropic_api_key: elif anthropic_api_key:
@ -313,29 +312,7 @@ async def stream_code(websocket: WebSocket):
) )
) )
# Run the models in parallel and capture exceptions if any completions = await asyncio.gather(*tasks)
completions = await asyncio.gather(*tasks, return_exceptions=True)
# If all generations failed, throw an error
all_generations_failed = all(
isinstance(completion, Exception) for completion in completions
)
if all_generations_failed:
await throw_error("Error generating code. Please contact support.")
# Print the all the underlying exceptions for debugging
for completion in completions:
traceback.print_exception(
type(completion), completion, completion.__traceback__
)
raise Exception("All generations failed")
# If some completions failed, replace them with empty strings
for index, completion in enumerate(completions):
if isinstance(completion, Exception):
completions[index] = ""
print("Generation failed for variant", index)
print("Models used for generation: ", variant_models) print("Models used for generation: ", variant_models)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:

View File

@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
async def generate_and_save_images( async def generate_and_save_images(
prompts: List[str], prompts: List[str],
model: Literal["dalle3", "flux"], model: Literal["dalle3", "sdxl-lightning"],
api_key: Optional[str], api_key: Optional[str],
) -> None: ) -> None:
# Ensure the output directory exists # Ensure the output directory exists
@ -64,7 +64,7 @@ async def generate_and_save_images(
image_data: bytes = await response.read() image_data: bytes = await response.read()
# Save the image with a filename based on the input eval # Save the image with a filename based on the input eval
prefix = "replicate_" if model == "flux" else "dalle3_" prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
filename: str = ( filename: str = (
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png" f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
) )
@ -78,7 +78,7 @@ async def generate_and_save_images(
async def main() -> None: async def main() -> None:
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY) # await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN) await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
if __name__ == "__main__": if __name__ == "__main__":

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -358,7 +358,11 @@ function App() {
</div> </div>
{/* Generation settings like stack and model */} {/* Generation settings like stack and model */}
<GenerationSettings settings={settings} setSettings={setSettings} /> <GenerationSettings
settings={settings}
setSettings={setSettings}
selectedCodeGenerationModel={model}
/>
{/* Show auto updated message when older models are choosen */} {/* Show auto updated message when older models are choosen */}
{showBetterModelMessage && <DeprecationMessage />} {showBetterModelMessage && <DeprecationMessage />}

View File

@ -2,16 +2,20 @@ import React from "react";
import { useAppStore } from "../../store/app-store"; import { useAppStore } from "../../store/app-store";
import { AppState, Settings } from "../../types"; import { AppState, Settings } from "../../types";
import OutputSettingsSection from "./OutputSettingsSection"; import OutputSettingsSection from "./OutputSettingsSection";
import ModelSettingsSection from "./ModelSettingsSection";
import { Stack } from "../../lib/stacks"; import { Stack } from "../../lib/stacks";
import { CodeGenerationModel } from "../../lib/models";
interface GenerationSettingsProps { interface GenerationSettingsProps {
settings: Settings; settings: Settings;
setSettings: React.Dispatch<React.SetStateAction<Settings>>; setSettings: React.Dispatch<React.SetStateAction<Settings>>;
selectedCodeGenerationModel: CodeGenerationModel;
} }
export const GenerationSettings: React.FC<GenerationSettingsProps> = ({ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
settings, settings,
setSettings, setSettings,
selectedCodeGenerationModel,
}) => { }) => {
const { appState } = useAppStore(); const { appState } = useAppStore();
@ -22,6 +26,13 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
})); }));
} }
function setCodeGenerationModel(codeGenerationModel: CodeGenerationModel) {
setSettings((prev: Settings) => ({
...prev,
codeGenerationModel,
}));
}
const shouldDisableUpdates = const shouldDisableUpdates =
appState === AppState.CODING || appState === AppState.CODE_READY; appState === AppState.CODING || appState === AppState.CODE_READY;
@ -32,6 +43,12 @@ export const GenerationSettings: React.FC<GenerationSettingsProps> = ({
setStack={setStack} setStack={setStack}
shouldDisableUpdates={shouldDisableUpdates} shouldDisableUpdates={shouldDisableUpdates}
/> />
<ModelSettingsSection
codeGenerationModel={selectedCodeGenerationModel}
setCodeGenerationModel={setCodeGenerationModel}
shouldDisableUpdates={shouldDisableUpdates}
/>
</div> </div>
); );
}; };

View File

@ -0,0 +1,65 @@
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
} from "../ui/select";
import {
CODE_GENERATION_MODEL_DESCRIPTIONS,
CodeGenerationModel,
} from "../../lib/models";
import { Badge } from "../ui/badge";
interface Props {
codeGenerationModel: CodeGenerationModel;
setCodeGenerationModel: (codeGenerationModel: CodeGenerationModel) => void;
shouldDisableUpdates?: boolean;
}
function ModelSettingsSection({
codeGenerationModel,
setCodeGenerationModel,
shouldDisableUpdates = false,
}: Props) {
return (
<div className="flex flex-col gap-y-2 justify-between text-sm">
<div className="grid grid-cols-3 items-center gap-4">
<span>AI Model:</span>
<Select
value={codeGenerationModel}
onValueChange={(value: string) =>
setCodeGenerationModel(value as CodeGenerationModel)
}
disabled={shouldDisableUpdates}
>
<SelectTrigger className="col-span-2" id="output-settings-js">
<span className="font-semibold">
{CODE_GENERATION_MODEL_DESCRIPTIONS[codeGenerationModel].name}
</span>
</SelectTrigger>
<SelectContent>
<SelectGroup>
{Object.values(CodeGenerationModel).map((model) => (
<SelectItem key={model} value={model}>
<div className="flex items-center">
<span className="font-semibold">
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].name}
</span>
{CODE_GENERATION_MODEL_DESCRIPTIONS[model].inBeta && (
<Badge className="ml-2" variant="secondary">
Beta
</Badge>
)}
</div>
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</div>
</div>
);
}
export default ModelSettingsSection;