update sdxl-lightning references to flux

This commit is contained in:
Abi Raja 2024-09-20 16:45:16 +02:00
parent 6899c7792e
commit a53afd350d
3 changed files with 6 additions and 6 deletions

View File

@ -11,7 +11,7 @@ async def process_tasks(
prompts: List[str],
api_key: str,
base_url: str | None,
model: Literal["dalle3", "sdxl-lightning"],
model: Literal["dalle3", "flux"],
):
import time
@ -98,7 +98,7 @@ async def generate_images(
api_key: str,
base_url: Union[str, None],
image_cache: Dict[str, str],
model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
model: Literal["dalle3", "flux"] = "dalle3",
) -> str:
# Find all images
soup = BeautifulSoup(code, "html.parser")

View File

@ -64,7 +64,7 @@ async def perform_image_generation(
return completion
if replicate_api_key:
image_generation_model = "sdxl-lightning"
image_generation_model = "flux"
api_key = replicate_api_key
else:
if not openai_api_key:

View File

@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
async def generate_and_save_images(
prompts: List[str],
model: Literal["dalle3", "sdxl-lightning"],
model: Literal["dalle3", "flux"],
api_key: Optional[str],
) -> None:
# Ensure the output directory exists
@ -64,7 +64,7 @@ async def generate_and_save_images(
image_data: bytes = await response.read()
# Save the image with a filename based on the input eval
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
prefix = "replicate_" if model == "flux" else "dalle3_"
filename: str = (
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
)
@ -78,7 +78,7 @@ async def generate_and_save_images(
async def main() -> None:
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
if __name__ == "__main__":