Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a53afd350d | ||
|
|
6899c7792e |
@ -11,7 +11,7 @@ async def process_tasks(
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str | None,
|
base_url: str | None,
|
||||||
model: Literal["dalle3", "sdxl-lightning"],
|
model: Literal["dalle3", "flux"],
|
||||||
):
|
):
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ async def generate_images(
|
|||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: Union[str, None],
|
base_url: Union[str, None],
|
||||||
image_cache: Dict[str, str],
|
image_cache: Dict[str, str],
|
||||||
model: Literal["dalle3", "sdxl-lightning"] = "dalle3",
|
model: Literal["dalle3", "flux"] = "dalle3",
|
||||||
) -> str:
|
) -> str:
|
||||||
# Find all images
|
# Find all images
|
||||||
soup = BeautifulSoup(code, "html.parser")
|
soup = BeautifulSoup(code, "html.parser")
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import traceback
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, WebSocket
|
||||||
import openai
|
import openai
|
||||||
from codegen.utils import extract_html_content
|
from codegen.utils import extract_html_content
|
||||||
@ -63,7 +64,7 @@ async def perform_image_generation(
|
|||||||
return completion
|
return completion
|
||||||
|
|
||||||
if replicate_api_key:
|
if replicate_api_key:
|
||||||
image_generation_model = "sdxl-lightning"
|
image_generation_model = "flux"
|
||||||
api_key = replicate_api_key
|
api_key = replicate_api_key
|
||||||
else:
|
else:
|
||||||
if not openai_api_key:
|
if not openai_api_key:
|
||||||
@ -321,6 +322,12 @@ async def stream_code(websocket: WebSocket):
|
|||||||
)
|
)
|
||||||
if all_generations_failed:
|
if all_generations_failed:
|
||||||
await throw_error("Error generating code. Please contact support.")
|
await throw_error("Error generating code. Please contact support.")
|
||||||
|
|
||||||
|
# Print the all the underlying exceptions for debugging
|
||||||
|
for completion in completions:
|
||||||
|
traceback.print_exception(
|
||||||
|
type(completion), completion, completion.__traceback__
|
||||||
|
)
|
||||||
raise Exception("All generations failed")
|
raise Exception("All generations failed")
|
||||||
|
|
||||||
# If some completions failed, replace them with empty strings
|
# If some completions failed, replace them with empty strings
|
||||||
|
|||||||
@ -41,7 +41,7 @@ OUTPUT_DIR: str = "generated_images"
|
|||||||
|
|
||||||
async def generate_and_save_images(
|
async def generate_and_save_images(
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
model: Literal["dalle3", "sdxl-lightning"],
|
model: Literal["dalle3", "flux"],
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
# Ensure the output directory exists
|
# Ensure the output directory exists
|
||||||
@ -64,7 +64,7 @@ async def generate_and_save_images(
|
|||||||
image_data: bytes = await response.read()
|
image_data: bytes = await response.read()
|
||||||
|
|
||||||
# Save the image with a filename based on the input eval
|
# Save the image with a filename based on the input eval
|
||||||
prefix = "replicate_" if model == "sdxl-lightning" else "dalle3_"
|
prefix = "replicate_" if model == "flux" else "dalle3_"
|
||||||
filename: str = (
|
filename: str = (
|
||||||
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
|
f"{prefix}{prompts[i][:50].replace(' ', '_').replace(':', '')}.png"
|
||||||
)
|
)
|
||||||
@ -78,7 +78,7 @@ async def generate_and_save_images(
|
|||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
|
# await generate_and_save_images(EVALS, "dalle3", OPENAI_API_KEY)
|
||||||
await generate_and_save_images(EVALS, "sdxl-lightning", REPLICATE_API_TOKEN)
|
await generate_and_save_images(EVALS, "flux", REPLICATE_API_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user