first commit
LICENSE (new file, 21 lines)
MIT License

Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (new file, 29 lines)
# Stable Diffusion for Remote Sensing Image Generation

#### Author: Zhiqiang Yuan @ AIRCAS, [Send an email](yuan_zhi_qiang@sina.cn)

A simple project for text-to-image remote sensing image generation.
We will release the code for **using text to control regions in super-large RS image generation** later.

## Environment configuration

Follow the [original training repo](https://github.com/justinpinkney/stable-diffusion.git).

## Pretrained weights

We used [RSITMD](https://github.com/xiaoyuan1996/AMFMN) as training data and fine-tuned Stable Diffusion for 10 epochs on 1 x A100 GPU.
With a batch size of 4, GPU memory consumption is roughly 40+ GB during training and 20+ GB during sampling.
The pretrained weights are released at [last-pruned.ckpt](https://github.com/xiaoyuan1996/AMFMN).

## Usage

Download the pretrained weights to the current directory, then run:

```commandline
bash sample.sh
```

We will update the training code as soon as possible.

## Examples

**Caption:** some boats driving in the sea

![](assets/shows1.png)
assets/shows1.png (binary, new file, 2.0 MiB)
configs/autoencoder/autoencoder_kl_16x16x16.yaml (new file, 54 lines)
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 16
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 16
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
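The `num_down = len(ch_mult)-1` comment above is what the autoencoder filenames encode: each `ch_mult` entry beyond the first adds one downsampling stage, halving the spatial side, and `z_channels` gives the latent depth. A minimal sketch (the helper name `latent_shape` is ours, not from the repo):

```python
def latent_shape(resolution, ch_mult, z_channels):
    """Latent tensor shape implied by an AutoencoderKL ddconfig.

    num_down = len(ch_mult) - 1; each downsampling stage halves the
    spatial side, and z_channels sets the latent channel depth.
    """
    num_down = len(ch_mult) - 1
    side = resolution // (2 ** num_down)
    return (side, side, z_channels)

# The four autoencoder configs in this commit:
print(latent_shape(256, [1, 1, 2, 2, 4], 16))     # -> (16, 16, 16)
print(latent_shape(256, [1, 2, 4, 4], 4))         # -> (32, 32, 4)
print(latent_shape(256, [1, 2, 4], 3))            # -> (64, 64, 3)
print(latent_shape(256, [1, 1, 2, 2, 4, 4], 64))  # -> (8, 8, 64)
```

The outputs match the `16x16x16`, `32x32x4`, `64x64x3`, and `8x8x64` file names.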
configs/autoencoder/autoencoder_kl_32x32x4.yaml (new file, 53 lines)
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_64x64x3.yaml (new file, 54 lines)
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_8x8x64.yaml (new file, 53 lines)
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 64
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 64
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16,8]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
configs/latent-diffusion/celebahq-ldm-vq-4.yaml (new file, 86 lines)
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
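The comment inside `attention_resolutions` above says these values are downsampling factors of the latent, not pixel sizes. The actual spatial resolutions that get attention follow from a simple division (function name ours, for illustration):

```python
def attention_spatial_res(latent_side, attention_resolutions):
    # attention_resolutions are downsampling factors of the latent grid,
    # so attention runs at latent_side // factor for each entry.
    return [latent_side // f for f in attention_resolutions]

print(attention_spatial_res(64, [8, 4, 2]))  # f4 latents -> [8, 16, 32]
print(attention_spatial_res(32, [4, 2, 1]))  # f8 latents -> [8, 16, 32]
```

Both the f4 config here and the f8 config below therefore attend at the same spatial resolutions (8, 16, 32), despite listing different factors.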
configs/latent-diffusion/cin-ldm-vq-f8.yaml (new file, 98 lines)
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 32 for f8
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 4
        n_embed: 16384
        ckpt_path: configs/first_stage_models/vq-f8/model.yaml
        ddconfig:
          double_z: false
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 32
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        key: class_label
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 12
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
    validation:
      target: ldm.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/cin256-v2.yaml (new file, 68 lines)
model:
  base_learning_rate: 0.0001
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 64
    channels: 3
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 192
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 5
        num_heads: 1
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512

    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        n_classes: 1001
        embed_dim: 512
        key: class_label
configs/latent-diffusion/ffhq-ldm-vq-4.yaml (new file, 85 lines)
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 42
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml (new file, 85 lines)
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: ldm.data.lsun.LSUNBedroomsTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNBedroomsValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml (new file, 91 lines)
model:
  base_learning_rate: 5.0e-5   # set to target_lr by starting main.py with '--scale_lr False'
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    loss_type: l1
    first_stage_key: "image"
    cond_stage_key: "image"
    image_size: 32
    channels: 4
    cond_stage_trainable: False
    concat_mode: False
    scale_by_std: True
    monitor: 'val/loss_simple_ema'

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [1.]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 192
        attention_resolutions: [ 1, 2, 4, 8 ]  # 32, 16, 8, 4
        num_res_blocks: 2
        channel_mult: [ 1,2,2,4,4 ]  # 32, 16, 8, 4, 2
        num_heads: 8
        use_scale_shift_norm: True
        resblock_updown: True

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: "val/rec_loss"
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config: "__is_unconditional__"

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 96
    num_workers: 5
    wrap: False
    train:
      target: ldm.data.lsun.LSUNChurchesTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNChurchesValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
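The `base_learning_rate` comment in the config above refers to the launcher's default LR scaling: unless training is started with `--scale_lr False`, `main.py` multiplies the base rate by the gradient-accumulation steps, number of GPUs, and batch size. A sketch of that rule (function name ours, for illustration):

```python
def scaled_lr(base_lr, accumulate_grad_batches, ngpu, batch_size):
    # Default LR scaling: effective lr grows with the effective batch size
    # (accumulation steps x GPUs x per-GPU batch size).
    # Starting main.py with '--scale_lr False' uses base_lr directly instead.
    return accumulate_grad_batches * ngpu * batch_size * base_lr

# e.g. this config on one GPU, batch_size 96, no accumulation:
print(round(scaled_lr(5.0e-5, 1, 1, 96), 6))  # -> 0.0048
```

This is why the configs keep `base_learning_rate` small: the effective rate is typically orders of magnitude larger once scaled.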
configs/latent-diffusion/txt2img-1p4B-eval.yaml (new file, 71 lines)
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32
configs/stable-diffusion/RSITMD.yaml (new file, 133 lines)
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.hf_dataset_RSITMD
      params:
        name: RSITMD-captions
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "There is a baseball field beside the green amusement park around the red track."
        - "The green playground around the red runway is a baseball field."
        - "The oval building connected with the hedge is located beside the lawn, which has small paths and trees."
        - "Many buildings are surrounded by a large, oval building."
        output_size: 512
        n_gpus: 1 # small hack to make sure we see all our samples

lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
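The `scheduler_config` in this file uses `warm_up_steps: [1]` because it resumes from a pretrained checkpoint (the comment says to use 10000 when starting from scratch). During warm-up, `LambdaLinearScheduler` ramps the LR multiplier linearly from `f_start` to `f_max`; the enormous `cycle_lengths` value then makes the post-warm-up linear decay toward `f_min` negligible. A simplified sketch of that behaviour (our simplification, not the repo's exact scheduler code):

```python
def warmup_multiplier(step, warm_up_steps, f_start=1.0e-6, f_max=1.0, f_min=1.0):
    # Linear ramp f_start -> f_max over warm_up_steps. After warm-up, the
    # huge cycle length makes the decay toward f_min effectively flat, so
    # we just return f_min (which equals f_max = 1.0 in these configs).
    if step < warm_up_steps:
        return f_start + (f_max - f_start) * step / warm_up_steps
    return f_min

print(warmup_multiplier(0, 10000))      # -> 1e-06
print(warmup_multiplier(10000, 10000))  # -> 1.0
```

The multiplier is applied on top of the (possibly scaled) base learning rate, so with `warm_up_steps: [1]` fine-tuning runs at the full rate almost immediately.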
configs/stable-diffusion/dev.yaml (new file, 131 lines)
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 32 # 320 # TODO increase
        attention_resolutions: [ ] # is equal to fixed spatial resolution: 32, 16, 8
        num_res_blocks: 2
        channel_mult: [ 1, ]
        #num_head_channels: 32
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 32
        use_checkpoint: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 32
        n_layer: 1 #32 # TODO: increase

data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 10
    num_workers: 4
    n_nodes: 1
    train:
      shards: '{000000..000010}.tar -' # TODO: wild guess, change
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

      shuffle: 5000
      n_examples: 16519100 # TODO: find out
    validation:
      shards: '{000011..000012}.tar -' # TODO: wild guess, change
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.CenterCrop
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
shuffle: 0
|
||||||
|
n_examples: 60000 # TODO: find out
|
||||||
|
val_num_workers: 2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000 # 5000
|
||||||
|
max_images: 0
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: True
|
||||||
|
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
replace_sampler_ddp: False
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 20000 # every 20k training steps
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
|
||||||
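Every `target:`/`params:` pair in these configs is resolved by dynamic import and then called with the params as keyword arguments. A minimal sketch of that pattern (the repo's own helper lives in `ldm.util`; the name and the `datetime.timedelta` example config here are illustrative):

```python
import importlib

def instantiate_from_config(config):
    # Resolve the dotted "target" path to a class/callable and call it
    # with the "params" mapping as keyword arguments.
    module_name, attr_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), attr_name)
    return cls(**config.get("params", {}))

# Illustrative config using a stdlib class as the target:
cfg = {"target": "datetime.timedelta", "params": {"hours": 2}}
delta = instantiate_from_config(cfg)
print(delta.total_seconds())  # 7200.0
```

This is why every sub-model (UNet, autoencoder, conditioner, transforms) can be swapped from YAML without touching the training code.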
128
configs/stable-diffusion/dev_mn.yaml
Normal file
@@ -0,0 +1,128 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 32 # 320 # TODO increase
        attention_resolutions: [ ] # is equal to fixed spatial resolution: 32, 16, 8
        num_res_blocks: 2
        channel_mult: [ 1, ]
        #num_head_channels: 32
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 32
        use_checkpoint: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 32
        n_layer: 1 #32 # TODO: increase


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 4
    num_workers: 4
    n_nodes: 4
    train:
      shards: '{000000..231339}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231346..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500 # 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 1000 # every 20k training steps
    num_sanity_val_steps: 0
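The `shards:` strings use brace ranges that the WebDataset loader expands into individual tar names (so `'{000000..231339}.tar'` covers 231,340 shards). A rough sketch of that expansion; the real loader relies on the `braceexpand` package, and `expand_shards` here is a hypothetical helper:

```python
import re

def expand_shards(spec):
    # Expand a "{lo..hi}" brace range, preserving zero padding,
    # roughly as webdataset/braceexpand would.
    m = re.match(r"\{(\d+)\.\.(\d+)\}(.*)", spec)
    lo, hi, suffix = m.groups()
    width = len(lo)
    return [f"{i:0{width}d}{suffix}" for i in range(int(lo), int(hi) + 1)]

print(expand_shards("{000000..000002}.tar"))
# ['000000.tar', '000001.tar', '000002.tar']
```

The trailing `" -"` in the config strings is part of the `pipe:` download command, not the brace pattern itself.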
109
configs/stable-diffusion/dev_mn_dummy.yaml
Normal file
@@ -0,0 +1,109 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 32 # 320 # TODO increase
        attention_resolutions: [ ] # is equal to fixed spatial resolution: 32, 16, 8
        num_res_blocks: 2
        channel_mult: [ 1, ]
        #num_head_channels: 32
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 32
        use_checkpoint: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 32
        n_layer: 1 #32 # TODO: increase


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.dummy.DummyData
      params:
        length: 20000
        size: [256, 256, 3]
    validation:
      target: ldm.data.dummy.DummyData
      params:
        length: 10000
        size: [256, 256, 3]


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500 # 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 1000 # every 20k training steps
    num_sanity_val_steps: 0
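The dummy config above only needs a dataset that yields `length` random images of the configured `size`, which is useful for smoke-testing the training loop without touching real data. A minimal stand-in, assuming `ldm.data.dummy.DummyData` behaves roughly like this (this is a sketch, not the repo's class):

```python
import numpy as np

class DummyData:
    """Yields random uint8 "images" of a fixed shape, matching the
    config's length/size parameters."""

    def __init__(self, length, size):
        self.length = length
        self.size = tuple(size)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Random pixels stand in for a decoded image under the "jpg" key
        # used elsewhere in these configs.
        return {"jpg": np.random.randint(0, 256, self.size, dtype=np.uint8)}

ds = DummyData(length=20000, size=[256, 256, 3])
print(len(ds), ds[0]["jpg"].shape)  # 20000 (256, 256, 3)
```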
130
configs/stable-diffusion/face_clip.yaml
Normal file
@@ -0,0 +1,130 @@
model:
  base_learning_rate: 5e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "image"
    image_size: 64
    channels: 4
    cond_stage_trainable: true # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2000 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FaceClipEncoder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 2
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /mnt/data_rome/ffhq/images1024x1024
        ext: png
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /mnt/data_rome/celeba1000/HQ
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3

lightning:
  find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
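The `LambdaLinearScheduler` entries that recur in these configs describe a learning-rate multiplier: a linear ramp from `f_start` to `f_max` over `warm_up_steps`, then a linear glide toward `f_min` across the (effectively infinite) `cycle_lengths`. A sketch of that schedule, under the assumption that the repo's implementation follows this shape; `lr_multiplier` is an illustrative name:

```python
def lr_multiplier(step, warm_up=2000, f_start=1e-6, f_max=1.0,
                  f_min=1.0, cycle_length=10_000_000_000_000):
    # Linear warmup from f_start to f_max, then linear interpolation
    # from f_max toward f_min over the remainder of the cycle.
    if step < warm_up:
        return f_start + (f_max - f_start) * step / warm_up
    frac = (step - warm_up) / max(cycle_length - warm_up, 1)
    return f_max + (f_min - f_max) * min(frac, 1.0)

print(lr_multiplier(0))     # 1e-06
print(lr_multiplier(2000))  # 1.0
```

With `f_min == f_max == 1.0` as in these configs, the multiplier simply ramps up and then stays flat; the huge `cycle_lengths` value exists only to keep the decay term inert.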
157
configs/stable-diffusion/inpainting/v1-edgeinpainting.yaml
Normal file
@@ -0,0 +1,157 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    ckpt_path: "/fsx/stable-diffusion/stable-diffusion/checkpoints/v1pp/v1pp-flatlined-hr.ckpt"

    concat_keys:
    - mask
    - masked_image
    - smoothing_strength

    c_concat_log_start: 1
    c_concat_log_end: 5

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 10 # 4 data + 4 downscaled image + 1 mask + 1 smoothing strength
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 2
    num_workers: 4
    multinode: True
    min_size: 512
    max_pwatermark: 0.8
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddEdge
        params:
          mode: "512train-large"

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddEdge
        params:
          mode: "512train-large"


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]
          ddim_steps: 100 # todo check these out for inpainting,
          ddim_eta: 1.0 # todo check these out for inpainting,

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
@@ -0,0 +1,156 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    ckpt_path: "/fsx/stable-diffusion/stable-diffusion/checkpoints/v1pp/v1pphrflatlined2-pruned.ckpt"

    ucg_training:
      txt:
        p: 0.1
        val: ""

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 2
    num_workers: 4
    multinode: True
    min_size: 512
    max_pwatermark: 0.8
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]
          ddim_steps: 100 # todo check these out for inpainting,
          ddim_eta: 1.0 # todo check these out for inpainting,

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
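With `conditioning_key: hybrid`, the inpainting UNet's `in_channels: 9` comes from concatenating the 4-channel noised latent, a 4-channel encoding of the masked image, and a 1-channel downsampled mask along the channel axis. A shape-level sketch of that concatenation (array names are illustrative):

```python
import numpy as np

# Shapes follow the config: latent image_size 64, channels 4.
z = np.zeros((1, 4, 64, 64))             # noised latent
masked_image = np.zeros((1, 4, 64, 64))  # encoded masked image
mask = np.zeros((1, 1, 64, 64))          # downsampled binary mask

# "hybrid" conditioning concatenates these channel-wise before the UNet;
# the text embedding is injected separately via cross-attention.
unet_input = np.concatenate([z, masked_image, mask], axis=1)
print(unet_input.shape)  # (1, 9, 64, 64)
```

The edge-inpainting variant adds one more smoothing-strength channel on top of this, which is why its config sets `in_channels: 10`.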
@@ -0,0 +1,149 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    ckpt_path: "/fsx/stable-diffusion/stable-diffusion/checkpoints/v1pp/v1pp-flatlined-hr.ckpt"

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 2
    num_workers: 4
    multinode: True
    min_size: 512
    max_pwatermark: 0.8
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]
          ddim_steps: 100 # todo check these out for inpainting,
          ddim_eta: 1.0 # todo check these out for inpainting,

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
@ -0,0 +1,144 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 7.5e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: hybrid # important
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
ckpt_path: "/fsx/stable-diffusion/stable-diffusion/checkpoints2/v1pp/v1pp-flatline-pruned.ckpt"
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 4
    num_workers: 4
    multinode: True
    min_size: 512
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]
          ddim_steps: 100 # todo check these out for inpainting,
          ddim_eta: 1.0 # todo check these out for inpainting,

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
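Every block in these configs follows the same `target:`/`params:` convention: the dotted path in `target` names a class, and `params` are its constructor kwargs. A minimal sketch of how such a block is typically resolved (assumption: this helper is not taken from the repo itself; the real project has its own `instantiate_from_config` in `ldm.util`):

```python
# Hedged sketch of the target/params instantiation convention.
# The dotted path in config["target"] is imported and the class is
# constructed with config["params"] as keyword arguments.
import importlib

def instantiate_from_config(config: dict):
    """Import the class named by config['target'] and construct it."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

# Example with a stdlib target so the sketch stays self-contained:
frac = instantiate_from_config({
    "target": "fractions.Fraction",
    "params": {"numerator": 1, "denominator": 3},
})
```

The same mechanism is what lets `lossconfig: target: torch.nn.Identity` swap the autoencoder loss out entirely during fine-tuning.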
133
configs/stable-diffusion/pokemon.yaml
Normal file
@ -0,0 +1,133 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.hf_dataset
      params:
        name: lambdalabs/pokemon-blip-captions
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "A pokemon with green eyes, large wings, and a hat"
        - "A cute bunny rabbit"
        - "Yoda"
        - "An epic landscape photo of a mountain"
        output_size: 512
        n_gpus: 2 # small hack to make sure we see all our samples


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
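The `LambdaLinearScheduler` blocks above all share one shape: the learning-rate multiplier climbs linearly from `f_start` to `f_max` over `warm_up_steps`, then moves linearly toward `f_min` over the (astronomically long) `cycle_lengths` — so with `f_max == f_min == 1.`, the rate is simply constant after warm-up. A hedged re-implementation of just that behaviour (assumption: this mirrors `ldm.lr_scheduler.LambdaLinearScheduler` in spirit only, not line for line):

```python
# Sketch of the warm-up-then-flat multiplier the configs describe.
# base_learning_rate gets scaled by this value at each optimizer step.
def lr_multiplier(step, warm_up_steps=10_000,
                  f_start=1e-6, f_max=1.0, f_min=1.0,
                  cycle_length=10_000_000_000_000):
    if step < warm_up_steps:
        # linear ramp from f_start up to f_max
        return f_start + (f_max - f_start) * step / warm_up_steps
    # linear drift from f_max toward f_min over the rest of the cycle;
    # a no-op here because f_max == f_min == 1.0 in these configs
    frac = (step - warm_up_steps) / max(cycle_length - warm_up_steps, 1)
    return f_max + (f_min - f_max) * frac
```

This also explains the `warm_up_steps: [ 1 ]` / `[ 2500 ]` values flagged "for resuming": a resumed run should not re-ramp from `f_start`, so the warm-up is shortened or effectively disabled.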
135
configs/stable-diffusion/sd-image-condition-attn-finetune.yaml
Normal file
@ -0,0 +1,135 @@
model:
  base_learning_rate: 1.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "jpg"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    unet_trainable: attn
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:ssh -i ~/.ssh/id_rsa jpinkney@104.171.201.154 cat /mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 4
    num_workers: 2
    multinode: True
    min_size: 256
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00008}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 8
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 4
134
configs/stable-diffusion/sd-image-condition-finetune.yaml
Normal file
@ -0,0 +1,134 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "jpg"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "/mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 6
    num_workers: 4
    multinode: True
    min_size: 256
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00008}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 8
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
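The `shards: '{00000..01209}.tar'` strings in the data sections are WebDataset-style brace ranges that expand to a list of numbered tar files. A minimal expansion sketch (assumption: WebDataset itself delegates this to the `braceexpand` package; the version below re-implements only the simple zero-padded numeric range these configs use):

```python
# Hedged sketch: expand a single '{lo..hi}' numeric brace range in a
# shard pattern into the concrete list of shard names, preserving the
# zero-padding width of the lower bound.
import re

def expand_shards(pattern):
    m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if not m:
        return [pattern]  # no brace range: pattern names a single shard
    lo, hi = m.group(1), m.group(2)
    width = len(lo)
    return [pattern[:m.start()] + str(i).zfill(width) + pattern[m.end():]
            for i in range(int(lo), int(hi) + 1)]

shards = expand_shards("{00000..00003}.tar")
# shards == ["00000.tar", "00001.tar", "00002.tar", "00003.tar"]
```

This is also why the repeated "use enough shards" note matters: the validation range must expand to at least one shard per dataloader worker, or some workers iterate an empty split.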
138
configs/stable-diffusion/sd_finetune_256.yaml
Normal file
@ -0,0 +1,138 @@
model:
  base_learning_rate: 1e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    unet_trainable: "attn"
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "/mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 128
    num_workers: 8
    multinode: True
    min_size: 256
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00003}.tar'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
138
configs/stable-diffusion/sd_finetune_256_full.yaml
Normal file
@ -0,0 +1,138 @@
model:
  base_learning_rate: 5e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    unet_trainable: "all"
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "/mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 92
    num_workers: 8
    multinode: True
    min_size: 256
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00003}.tar'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
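The per-GPU `batch_size` values above (4, 6, 92, 128) only tell part of the story: with gradient accumulation and multi-node data parallelism, the optimizer sees a larger effective batch. A quick arithmetic sketch (assumption: this helper is illustrative, not from the repo; GPU and node counts are hypothetical examples):

```python
# Effective optimizer batch size under DDP with gradient accumulation:
# per-GPU batch, times accumulation steps, times total process count.
def effective_batch(batch_size, accumulate_grad_batches=1,
                    n_gpus=1, n_nodes=1):
    return batch_size * accumulate_grad_batches * n_gpus * n_nodes

# e.g. batch_size 92 with accumulate_grad_batches 1 on 8 GPUs -> 736
full_finetune = effective_batch(92, 1, n_gpus=8)
```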
@ -0,0 +1,135 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 4
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
@ -0,0 +1,131 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 50
    num_workers: 4
    multinode: True
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
127
configs/stable-diffusion/txt2img-1p4B-multinode-t5-encoder.yaml
Normal file
@ -0,0 +1,127 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenT5Embedder
      params:
        version: "google/t5-v1_1-xl"


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 12
    num_workers: 4
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 50000
    num_sanity_val_steps: 0
128
configs/stable-diffusion/txt2img-1p4B-multinode.yaml
Normal file
@ -0,0 +1,128 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 12
    num_workers: 4
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 50000
    num_sanity_val_steps: 0
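The `LambdaLinearScheduler` block repeated in these configs describes a linear warmup of the learning-rate multiplier from `f_start` to `f_max` over `warm_up_steps`; since `f_min` equals `f_max` here, the multiplier then stays constant for the rest of the (effectively infinite) cycle. A hedged sketch of that shape, not the repo's actual `ldm.lr_scheduler` implementation:

```python
def lr_multiplier(step, warm_up_steps=10_000, f_start=1e-6, f_max=1.0, f_min=1.0):
    """LR multiplier implied by the scheduler_config params above (single cycle)."""
    if step < warm_up_steps:
        # linear warmup from f_start to f_max
        return f_start + (f_max - f_start) * step / warm_up_steps
    # with f_min == f_max (as in these configs) there is no decay after warmup
    return f_min
```

The effective learning rate at a step is then `base_learning_rate * lr_multiplier(step)` (times any batch-size/GPU scaling the trainer applies).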
127
configs/stable-diffusion/txt2img-clip-encoder-dev.yaml
Normal file
@ -0,0 +1,127 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 56
    num_workers: 4
    multinode: True
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 50000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
129
configs/stable-diffusion/txt2img-ldm-frozen-dev.yaml
Normal file
@ -0,0 +1,129 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 52
    num_workers: 4
    multinode: False
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 50000
    num_sanity_val_steps: 0
129
configs/stable-diffusion/txt2img-ldm-unfrozen-dev.yaml
Normal file
@ -0,0 +1,129 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 12
    num_workers: 4
    multinode: False
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 50000
    num_sanity_val_steps: 0
130
configs/stable-diffusion/txt2img-ldm-vae-f8.yaml
Normal file
@ -0,0 +1,130 @@
model:
  base_learning_rate: 1.0e-04 # TODO: run with scale_lr False
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 128 # 320 # TODO increase
        attention_resolutions: [ 4, 2, 1 ] # is equal to fixed spatial resolution: 32 , 16 , 8
        num_res_blocks: 2
        channel_mult: [ 1,2,4,4 ]
        #num_head_channels: 32
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: True

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "/home/robin/projects/latent-diffusion/models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 3 #32 # TODO: increase


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 60
    num_workers: 4
    n_nodes: 2 # TODO: runs with two gpus
    train:
      shards: '{000000..000010}.tar -' # TODO: wild guess, change
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

      shuffle: 5000
      n_examples: 16519100 # TODO: find out
    validation:
      shards: '{000011..000012}.tar -' # TODO: wild guess, change
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512

      shuffle: 0
      n_examples: 60000 # TODO: find out
      val_num_workers: 2


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000 # 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: True

  trainer:
    replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 20000 # every 20k training steps
    num_sanity_val_steps: 0
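The `shards:` values in these data blocks use WebDataset brace notation: `'{000000..231317}.tar -'` names a zero-padded range of tar shards (the trailing `-` belongs to the `pipe:aws s3 cp ... -` command that streams each shard to stdout). A small illustrative expander, assuming a single simple numeric range per string:

```python
import re

def expand_shards(spec):
    """Expand a webdataset-style range like '{000000..000002}.tar' into shard names."""
    m = re.fullmatch(r"(.*)\{(\d+)\.\.(\d+)\}(.*)", spec)
    if m is None:
        return [spec]  # no brace range: single shard
    prefix, lo, hi, suffix = m.groups()
    width = len(lo)  # preserve zero-padding width of the lower bound
    return [f"{prefix}{i:0{width}d}{suffix}" for i in range(int(lo), int(hi) + 1)]
```

This is why the validation split only needs a different numeric range (`{231318..231349}`) against the same `tar_base`, and why the NOTE about "enough shards" matters: each dataloader worker gets a subset of the expanded list, so too few shards leaves some workers with nothing to iterate.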
@ -0,0 +1,133 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 16
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    #ckpt_path: "/home/mchorse/stable-diffusion-ckpts/768f16-2022-06-23-pruned.ckpt"

    #scheduler_config: # 10000 warmup steps
    #  target: ldm.lr_scheduler.LambdaLinearScheduler
    #  params:
    #    warm_up_steps: [ 10000 ]
    #    cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
    #    f_start: [ 1.e-6 ]
    #    f_max: [ 1. ]
    #    f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 3
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 1024

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 1024


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 2
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 2
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 4
@ -0,0 +1,127 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 16
    channels: 16
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 16 # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320 # TODO: scale model here
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 55
    num_workers: 4
    multinode: True
    min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
@ -0,0 +1,65 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 48
    channels: 16
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929  # magic number

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 48
        in_channels: 16
        out_channels: 16
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
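The spatial sizes in this config are linked: the kl-f16 autoencoder halves resolution once per downsampling stage, so with the `num_down = len(ch_mult)-1` rule noted in the config comment, `image_size: 48` in latent space corresponds to 48 × 16 = 768-pixel inputs, and `channels: 16` matches `z_channels: 16`. A quick consistency check of that arithmetic:

```python
ch_mult = [1, 1, 2, 2, 4]            # autoencoder ddconfig above
num_down = len(ch_mult) - 1          # per the config's own comment
downsample_factor = 2 ** num_down    # kl-f16 -> factor 16
pixel_size = 48 * downsample_factor  # latent image_size 48 -> pixel resolution
```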
@ -0,0 +1,133 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 48
    channels: 16
    cond_stage_trainable: false  # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929  # magic number

    ckpt_path: "/home/mchorse/stable-diffusion-ckpts/768f16-2022-06-23-pruned.ckpt"

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 48  # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 6
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 768
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 768

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 768
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 768


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
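All of these configs share the same `LambdaLinearScheduler` block: a linear warmup from `f_start` to `f_max` over `warm_up_steps`, then a linear ramp toward `f_min` over the (practically infinite) remainder of the cycle. A hedged sketch of the learning-rate multiplier this is assumed to produce, with the defaults taken from the configs above:

```python
def lr_multiplier(step, warm_up=10_000, f_start=1e-6, f_max=1.0, f_min=1.0,
                  cycle_length=10_000_000_000_000):
    # Sketch of the assumed LambdaLinearScheduler behaviour: linear warmup,
    # then linear interpolation toward f_min. With f_max == f_min == 1.0 (as
    # in these configs), the multiplier is flat after warmup.
    if step < warm_up:
        return f_start + (f_max - f_start) * step / warm_up
    frac = (step - warm_up) / (cycle_length - warm_up)
    return f_max + (f_min - f_max) * frac
```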
@ -0,0 +1,130 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 48
    channels: 16
    cond_stage_trainable: false  # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929  # magic number

    ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256f16-2022-06-15-216k-pruned.ckpt"

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 48  # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320  # TODO: scale model here
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 6
    num_workers: 4
    multinode: True
    min_size: 384  # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 768
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 768

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 768
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 768


lightning:
  find_unused_parameters: False

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
128
configs/stable-diffusion/txt2img-t5-encoder-dev.yaml
Normal file
@ -0,0 +1,128 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false  # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 2048
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenT5Embedder
      params:
        version: "google/t5-v1_1-xl"


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 40
    num_workers: 4
    multinode: False
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False
        log_first_step: False

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 50000
    num_sanity_val_steps: 0
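Shard lists like `'{000000..231317}.tar -'` use webdataset-style brace notation (the trailing `-` belongs to the pipe spec, not the filename range); webdataset itself expands these via the `braceexpand` package. A pure-Python sketch of how one such range expands into zero-padded shard names:

```python
import re

def expand_shards(spec):
    # Expand a single "{A..B}" range, zero-padded like its endpoints.
    m = re.search(r"\{(\d+)\.\.(\d+)\}", spec)
    lo, hi = m.group(1), m.group(2)
    width = len(lo)
    return [spec[:m.start()] + str(i).zfill(width) + spec[m.end():]
            for i in range(int(lo), int(hi) + 1)]

shards = expand_shards("{000000..000002}.tar")
```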
@ -0,0 +1,177 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    low_scale_key: "LR_image"  # TODO: adapt
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    #first_stage_key: "jpg"  # TODO: use this later
    cond_stage_key: "caption"
    #cond_stage_key: "txt"  # TODO: use this later
    image_size: 64
    channels: 16
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929  # magic number

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        scale_factor: 0.18215
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 100
        output_size: 64
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
              - 1
              - 2
              - 4
              - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000  # timesteps for noise conditioning
        image_size: 64  # not really needed
        in_channels: 20
        out_channels: 16
        model_channels: 32  # TODO: more
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    min_size: 256  # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
#    train:
#      shards: '{000000..231317}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 1024
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{231318..231349}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 1024

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 7
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 1024
        downscale_f: 4
        degradation: "cv_nearest"

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 10
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          sample: False
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          #unconditional_guidance_scale: 3.0
          #unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000 # really sorry # TODO: bring back in
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
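The upscaler UNet's `in_channels: 20` is consistent with the hybrid conditioning above: the 16-channel kl-f16 latent being denoised is concatenated channel-wise with the 4-channel kl-f8 latent produced by the `LowScaleEncoder`. This is an inference from the config values, not something the config states; the arithmetic:

```python
diffused_latent_channels = 16  # z_channels of the kl-f16 first stage
low_res_latent_channels = 4    # embed_dim of the kl-f8 low_scale encoder
unet_in_channels = diffused_latent_channels + low_res_latent_channels
```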
@ -0,0 +1,170 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    low_scale_key: "lr"
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 16
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929  # magic number

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        scale_factor: 0.18215
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 100
        output_size: 64
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
              - 1
              - 2
              - 4
              - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000  # timesteps for noise conditioning
        image_size: 64  # not really needed
        in_channels: 20
        out_channels: 16
        model_channels: 96
        attention_resolutions: [ 8, 4, 2 ]  # -> at 32, 16, 8
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 8, 8 ]
        # -> res, ds: (64, 1), (32, 2), (16, 4), (8, 8), (4, 16)
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 10
    num_workers: 4
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 1024
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 1024
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4


lightning:
  find_unused_parameters: False

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000  # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 4
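The `AddLR` postprocess with `factor: 4` presumably attaches a 4×-downscaled copy of each crop under the `lr` key that `low_scale_key` reads. A toy nearest-neighbour sketch of that idea on plain nested lists (the names and exact behaviour are assumptions, not the actual `ldm.data.laion.AddLR`, which operates on image tensors):

```python
def add_lr(sample, factor=4, image_key="jpg", lr_key="lr"):
    # Keep every `factor`-th row and column as a crude nearest-neighbour
    # downscale, and store the result under the low-resolution key.
    img = sample[image_key]
    sample[lr_key] = [row[::factor] for row in img[::factor]]
    return sample

# 8x8 toy "image" whose pixel value encodes its (row, col) position:
out = add_lr({"jpg": [[c + 8 * r for c in range(8)] for r in range(8)]})
```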
@ -0,0 +1,149 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false  # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]  # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32  # unused
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [ 8, 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 2 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ]  # converts the self-attention to a cross-attention layer if true
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data: # TODO
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.dummy.DummyData
      params:
        length: 20000
        size: [256, 256, 3]
    validation:
      target: ldm.data.dummy.DummyData
      params:
        length: 10000
        size: [256, 256, 3]

#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    train:
#      shards: '{00000..17279}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 512
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{17280..17535}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 512


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5  # TODO
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 200  # TODO: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
@ -0,0 +1,137 @@
model:
  base_learning_rate: 8.e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 416
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 2 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "/fsx/stable-diffusion/stable-diffusion/models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 8
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: True
        batch_frequency: 2500
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
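Every `target:`/`params:` pair in the configs above is resolved reflectively at runtime. Below is a minimal stdlib-only sketch of that pattern; the repo's actual helper is `instantiate_from_config` in `ldm/util.py`, and the `fractions.Fraction` target here is a hypothetical stand-in used only so the example runs without the repo installed:

```python
import importlib

def instantiate_from_config(config: dict):
    # Resolve "package.module.ClassName" from the config's "target" key,
    # import the module, and construct the class with the "params" mapping.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

# Hypothetical stand-in target; a real config would name e.g.
# ldm.models.diffusion.ddpm.LatentDiffusion.
frac = instantiate_from_config({
    "target": "fractions.Fraction",
    "params": {"numerator": 3, "denominator": 4},
})
print(frac)  # -> 3/4
```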
@ -0,0 +1,149 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:

    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [ 8, 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 2 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data: # TODO
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.dummy.DummyData
      params:
        length: 20000
        size: [512, 512, 3]
    validation:
      target: ldm.data.dummy.DummyData
      params:
        length: 10000
        size: [512, 512, 3]

#data:
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
#    batch_size: 4
#    num_workers: 4
#    multinode: True
#    train:
#      shards: '{00000..17279}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 512
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{17280..17535}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 512
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 512


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5 # TODO
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 1000 # TODO: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
@ -0,0 +1,135 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 416
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 2 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 1
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2500
        max_images: 2
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 2
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
214
configs/stable-diffusion/upscaling/upscale-v1-with-f16.yaml
Normal file
@ -0,0 +1,214 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    low_scale_key: "lr"
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 16
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        scale_factor: 0.18215
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 250
        output_size: null
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "/fsx/stable-diffusion/stable-diffusion/models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
              - 1
              - 2
              - 4
              - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 251 # timesteps for noise conditioning
        image_size: 64 # not really needed
        in_channels: 20
        out_channels: 16
        model_channels: 128
        attention_resolutions: [ 8, 4, 2 ] # -> at 32, 16, 8
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 6, 8 ]
        # -> res, ds: (64, 1), (32, 2), (16, 4), (8, 8), (4, 16)
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "/fsx/stable-diffusion/stable-diffusion/models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


#data: # TODO: finetune here later
#  target: ldm.data.laion.WebDataModuleFromConfig
#  params:
#    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
#    batch_size: 10
#    num_workers: 4
#    train:
#      shards: '{00000..17279}.tar -'
#      shuffle: 10000
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.RandomCrop
#        params:
#          size: 1024
#      postprocess:
#        target: ldm.data.laion.AddLR
#        params:
#          factor: 2
#
#    # NOTE use enough shards to avoid empty validation loops in workers
#    validation:
#      shards: '{17280..17535}.tar -'
#      shuffle: 0
#      image_key: jpg
#      image_transforms:
#      - target: torchvision.transforms.Resize
#        params:
#          size: 1024
#          interpolation: 3
#      - target: torchvision.transforms.CenterCrop
#        params:
#          size: 1024
#      postprocess:
#        target: ldm.data.laion.AddLR
#        params:
#          factor: 2

data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 28
    num_workers: 4
    multinode: True
    min_size: 512
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 2

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 2


lightning:
  find_unused_parameters: False

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
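For reference, the `base_learning_rate` values in these configs are typically not used verbatim: the upstream training script scales them by the effective batch size unless learning-rate scaling is disabled. A sketch under that assumption (the 8-GPU count below is hypothetical):

```python
def scaled_lr(base_lr: float, accumulate_grad_batches: int,
              n_gpus: int, batch_size: int) -> float:
    # Effective LR = accumulate_grad_batches * n_gpus * batch_size * base_lr,
    # mirroring the scale_lr behaviour of the upstream training script
    # (assumption: the scale-lr option is left at its default).
    return accumulate_grad_batches * n_gpus * batch_size * base_lr

# e.g. the f16 upscaling config above (base 5.0e-05, accumulate 2,
# batch_size 28) on a hypothetical 8-GPU node:
print(scaled_lr(5.0e-05, 2, 8, 28))
```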
181
configs/stable-diffusion/upscaling_256.yaml
Normal file
@ -0,0 +1,181 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    low_scale_key: "lr"

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        scale_factor: 0.18215
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 100
        output_size: null
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
              - 1
              - 2
              - 4
              - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000
        image_size: 16 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:ssh -v -i ~/.ssh/id_rsa jpinkney@104.171.201.154 cat /mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 48
    num_workers: 8
    multinode: True
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4
          output_size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00012}.tar'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 1
          output_size: 256


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
181
configs/stable-diffusion/upscaling_512.yaml
Normal file
@ -0,0 +1,181 @@
model:
  base_learning_rate: 1.0e-05
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    low_scale_key: "lr"

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        scale_factor: 0.18215
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 100
        output_size: null
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
              - 1
              - 2
              - 4
              - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 5000 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: 1000
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:ssh -i ~/.ssh/id_rsa jpinkney@104.171.201.154 cat /mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 3
    num_workers: 2
    multinode: True
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4
          output_size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00012}.tar'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4
          output_size: 512


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 4
69
configs/stable-diffusion/v1-inference.yaml
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
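The `scheduler_config` in the file above drives a linear learning-rate warmup: the base rate is multiplied by a factor that ramps from `f_start` to `f_max` over `warm_up_steps`, then stays between `f_max` and `f_min` for the (effectively infinite) cycle. A minimal sketch of that multiplier, assuming `LambdaLinearScheduler` interpolates linearly in both phases (the exact implementation lives in `ldm/lr_scheduler.py`):

```python
def lambda_linear(step, warm_up_steps=10000, f_start=1e-6,
                  f_max=1.0, f_min=1.0, cycle_length=10_000_000_000_000):
    """LR multiplier: linear warmup from f_start to f_max, then a linear
    drift from f_max toward f_min over the remainder of the cycle."""
    if step < warm_up_steps:
        return f_start + (f_max - f_start) * step / warm_up_steps
    t = (step - warm_up_steps) / (cycle_length - warm_up_steps)
    return f_max + (f_min - f_max) * t

# with f_min == f_max == 1.0, as in this config, the rate is simply
# warmed up over 10k steps and then held constant
```

Since `f_min` and `f_max` are both `1.` here, the schedule reduces to warmup followed by a constant rate.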
135
configs/stable-diffusion/v1_improvedaesthetics.yaml
Normal file
@ -0,0 +1,135 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 4
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
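With `batch_size: 4` per process and `accumulate_grad_batches: 2`, each optimizer step sees 8 samples per GPU; under DDP this further multiplies by the number of processes. A quick sanity check (the helper name is illustrative, not from the repo):

```python
def effective_batch(per_gpu_batch, accumulate_grad_batches, num_gpus=1, num_nodes=1):
    """Samples contributing to one optimizer step under
    gradient accumulation plus data-parallel training."""
    return per_gpu_batch * accumulate_grad_batches * num_gpus * num_nodes

# this config: batch_size 4, accumulate_grad_batches 2
print(effective_batch(4, 2))              # 8 per GPU
print(effective_batch(4, 2, num_gpus=8))  # 64 on one 8-GPU node
```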
135
configs/stable-diffusion/v1_laionhr.yaml
Normal file
@ -0,0 +1,135 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 4
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
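Shard specs like `'{17280..17535}.tar -'` use webdataset-style brace expansion; the trailing `-` marks the end of the list for the repo's loader. A sketch of the expansion with a hypothetical stdlib-only helper (not the repo's code):

```python
import re

def expand_shards(spec):
    """Expand a brace range like '{17280..17535}.tar' into shard names,
    preserving the zero-padding width of the lower bound."""
    m = re.match(r"\{(\d+)\.\.(\d+)\}(.*)", spec.removesuffix(" -"))
    lo, hi, suffix = m.group(1), m.group(2), m.group(3)
    width = len(lo)
    return [f"{i:0{width}d}{suffix}" for i in range(int(lo), int(hi) + 1)]

names = expand_shards("{17280..17535}.tar -")
print(len(names), names[0], names[-1])  # 256 17280.tar 17535.tar
```

The validation split therefore spans 256 shards, which is what the "use enough shards" note above is about.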
132
configs/stable-diffusion/v2_laionhr1024.yaml
Normal file
@ -0,0 +1,132 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 16
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    # NOTE disabled for resuming
    #scheduler_config: # 10000 warmup steps
    #  target: ldm.lr_scheduler.LambdaLinearScheduler
    #  params:
    #    warm_up_steps: [ 10000 ]
    #    cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
    #    f_start: [ 1.e-6 ]
    #    f_max: [ 1. ]
    #    f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 3
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 1024

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 1024


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 2
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 2
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 4
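The `scale_factor` ("magic number") rescales the VAE latents to roughly unit variance before the diffusion model sees them; in `LatentDiffusion` the encoded latent is multiplied by it, and the inverse is applied before decoding. A minimal sketch of that round trip (function names are illustrative, not the repo's API):

```python
import numpy as np

SCALE_FACTOR = 0.22765929  # the "magic number" from this config

def to_model_space(vae_latent):
    # the diffusion model operates on rescaled latents
    return SCALE_FACTOR * vae_latent

def to_vae_space(model_latent):
    # invert the scaling before handing back to the VAE decoder
    return model_latent / SCALE_FACTOR

z = np.ones((16, 64, 64))  # channels 16, image_size 64, as configured above
restored = to_vae_space(to_model_space(z))
assert np.allclose(restored, z)
```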
132
configs/stable-diffusion/v2_laionhr1024_2.yaml
Normal file
@ -0,0 +1,132 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 16
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    # NOTE disabled for resuming
    #scheduler_config: # 10000 warmup steps
    #  target: ldm.lr_scheduler.LambdaLinearScheduler
    #  params:
    #    warm_up_steps: [ 10000 ]
    #    cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
    #    f_start: [ 1.e-6 ]
    #    f_max: [ 1. ]
    #    f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion-high-resolution/"
    batch_size: 3
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 1024

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 1024
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 1024


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 2
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 2
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 2
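Note that `base_learning_rate` is not necessarily the rate the optimizer sees: in the standard latent-diffusion training script, the actual rate is scaled by accumulation, GPU count, and batch size unless that scaling is disabled (an assumption based on the upstream `main.py`, not confirmed by this diff). A sketch of that linear scaling rule:

```python
def scaled_lr(base_lr, accumulate_grad_batches, ngpu, batch_size):
    # linear LR scaling rule: rate grows with the effective batch size
    # (assumed behavior of the upstream latent-diffusion main.py)
    return accumulate_grad_batches * ngpu * batch_size * base_lr

# this config on a hypothetical 8-GPU node:
print(scaled_lr(7.5e-05, 2, 8, 3))
```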
131
configs/stable-diffusion/v2_pretraining.yaml
Normal file
@ -0,0 +1,131 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.001
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 16
    channels: 16
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.22765929 # magic number

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 16 # not really needed
        in_channels: 16
        out_channels: 16
        model_channels: 320 # TODO: scale model here
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 16
        monitor: val/rec_loss
        ckpt_path: "models/first_stage_models/kl-f16/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/"
    batch_size: 55
    num_workers: 4
    multinode: True
    min_size: 256
    train:
      shards: '{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
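The `ch_mult: [1,1,2,2,4]` comment (`num_down = len(ch_mult)-1`) explains the geometry of this kl-f16 config: four downsampling stages give a spatial factor of 2^4 = 16, so 256-px images become 16x16 latents with `z_channels: 16`, matching `image_size: 16` and `channels: 16` in the diffusion params above. A sketch of that relationship (helper name is illustrative):

```python
def latent_shape(resolution, ch_mult, z_channels):
    """Shape of the autoencoder latent for a square input image."""
    num_down = len(ch_mult) - 1   # comment from the config
    f = 2 ** num_down             # total spatial downsampling factor
    return (z_channels, resolution // f, resolution // f)

print(latent_shape(256, [1, 1, 2, 2, 4], 16))  # (16, 16, 16) -> kl-f16
print(latent_shape(256, [1, 2, 4, 4], 4))      # (4, 32, 32)  -> kl-f8
```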
137
configs/stable-diffusion/v3_pretraining.yaml
Normal file
@ -0,0 +1,137 @@
model:
  base_learning_rate: 8.e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 416
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: [ 2, 2, 2, 2 ]
        channel_mult: [ 1, 2, 4, 4 ]
        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ckpt_path: "/fsx/stable-diffusion/stable-diffusion/models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "__improvedaesthetic__"
    batch_size: 8
    num_workers: 4
    multinode: True
    train:
      shards: '{00000..17279}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 256

    # # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{17280..17535}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 256
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 256


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: True
        batch_frequency: 2500
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    #replace_sampler_ddp: False
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
BIN   data/DejaVuSans.ttf  (new file)
BIN   data/example_conditioning/superresolution/sample_0.jpg  (new file, 14 KiB)
1     data/example_conditioning/text_conditional/sample_0.txt  (new file)
      @ -0,0 +1 @@
      A basket of cerries
1000  data/imagenet_clsidx_to_label.txt  (executable file)
BIN   data/imagenet_train_hr_indices.p  (new file)
BIN   data/imagenet_val_hr_indices.p  (new file)
1000  data/index_synset.yaml  (new file)
BIN   data/inpainting_examples/6458524847_2f4c361183_k.png  (new file, 466 KiB)
BIN   data/inpainting_examples/6458524847_2f4c361183_k_mask.png  (new file, 7.4 KiB)
BIN   data/inpainting_examples/8399166846_f6fb4e4b8e_k.png  (new file, 539 KiB)
BIN   data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png  (new file, 7.6 KiB)
BIN   data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png  (new file, 450 KiB)
BIN   data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png  (new file, 12 KiB)
BIN   data/inpainting_examples/bench2.png  (new file, 553 KiB)
BIN   data/inpainting_examples/bench2_mask.png  (new file, 12 KiB)
BIN   data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png  (new file, 418 KiB)
BIN   data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png  (new file, 6.1 KiB)
BIN   data/inpainting_examples/billow926-12-Wc-Zgx6Y.png  (new file, 542 KiB)
BIN   data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png  (new file, 9.5 KiB)
BIN   data/inpainting_examples/overture-creations-5sI6fQgYIuo.png  (new file, 395 KiB)
BIN   data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png  (new file, 12 KiB)
BIN   data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png  (new file, 465 KiB)
BIN   data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png  (new file, 7.8 KiB)
0
ldm/data/__init__.py
Normal file
40
ldm/data/base.py
Normal file
@ -0,0 +1,40 @@
import os
import numpy as np
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass


class PRNGMixin(object):
    """
    Adds a prng property which is a numpy RandomState which gets
    reinitialized whenever the pid changes to avoid synchronized sampling
    behavior when used in conjunction with multiprocessing.
    """
    @property
    def prng(self):
        currentpid = os.getpid()
        if getattr(self, "_initpid", None) != currentpid:
            self._initpid = currentpid
            self._prng = np.random.RandomState()
        return self._prng
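`PRNGMixin` guards against forked DataLoader workers all inheriting the same RNG state: each process lazily creates its own `RandomState` on first access, keyed by pid. A minimal usage sketch (the mixin is restated here so the example is self-contained; `RandomShift` is a hypothetical transform, not repo code):

```python
import os
import numpy as np

class PRNGMixin(object):
    """Per-process numpy RandomState, reinitialized when the pid changes."""
    @property
    def prng(self):
        currentpid = os.getpid()
        if getattr(self, "_initpid", None) != currentpid:
            self._initpid = currentpid
            self._prng = np.random.RandomState()
        return self._prng

class RandomShift(PRNGMixin):
    def __call__(self, x):
        # each worker process draws from its own independent stream
        return x + self.prng.randint(0, 10)

t = RandomShift()
assert t.prng is t.prng  # cached within a single process
```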
253
ldm/data/coco.py
Normal file
@ -0,0 +1,253 @@
import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from abc import abstractmethod


class CocoBase(Dataset):
    """needed for (image, caption, segmentation) pairs"""
    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
                 crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True, crop_type=None):
        self.split = self.get_split()
        self.size = size
        if crop_size is None:
            self.crop_size = size
        else:
            self.crop_size = crop_size

        assert crop_type in [None, 'random', 'center']
        self.crop_type = crop_type
        self.use_segmenation = use_segmentation
        self.onehot = onehot_segmentation  # return segmentation as rgb or one hot
        self.stuffthing = use_stuffthing  # include thing in segmentation
        if self.onehot and not self.stuffthing:
            raise NotImplementedError("One hot mode is only supported for the "
                                      "stuffthings version because labels are stored "
                                      "a bit different.")

        data_json = datajson
        with open(data_json) as json_file:
            self.json_data = json.load(json_file)
        self.img_id_to_captions = dict()
        self.img_id_to_filepath = dict()
        self.img_id_to_segmentation_filepath = dict()

        assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
                                            f"captions_val{self.year()}.json"]
        # TODO currently hardcoded paths, would be better to follow logic in
        # cocstuff pixelmaps
        if self.use_segmenation:
            if self.stuffthing:
                self.segmentation_prefix = (
                    f"data/cocostuffthings/val{self.year()}" if
                    data_json.endswith(f"captions_val{self.year()}.json") else
                    f"data/cocostuffthings/train{self.year()}")
            else:
                self.segmentation_prefix = (
                    f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
                    data_json.endswith(f"captions_val{self.year()}.json") else
                    f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")

        imagedirs = self.json_data["images"]
        self.labels = {"image_ids": list()}
        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
            self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
            self.img_id_to_captions[imgdir["id"]] = list()
            pngfilename = imgdir["file_name"].replace("jpg", "png")
            if self.use_segmenation:
                self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
                    self.segmentation_prefix, pngfilename)
            if given_files is not None:
                if pngfilename in given_files:
                    self.labels["image_ids"].append(imgdir["id"])
            else:
                self.labels["image_ids"].append(imgdir["id"])

        capdirs = self.json_data["annotations"]
        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # there are on average 5 captions per image
            #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
            self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        if self.split == "validation":
            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        else:
            # default option for train is random crop
            if self.crop_type in [None, 'random']:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        self.preprocessor = albumentations.Compose(
|
[self.rescaler],
|
||||||
|
additional_targets={"segmentation": "image"})
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def year(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels["image_ids"])
|
||||||
|
|
||||||
|
def preprocess_image(self, image_path, segmentation_path=None):
|
||||||
|
image = Image.open(image_path)
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
if segmentation_path:
|
||||||
|
segmentation = Image.open(segmentation_path)
|
||||||
|
if not self.onehot and not segmentation.mode == "RGB":
|
||||||
|
segmentation = segmentation.convert("RGB")
|
||||||
|
segmentation = np.array(segmentation).astype(np.uint8)
|
||||||
|
if self.onehot:
|
||||||
|
assert self.stuffthing
|
||||||
|
# stored in caffe format: unlabeled==255. stuff and thing from
|
||||||
|
# 0-181. to be compatible with the labels in
|
||||||
|
# https://github.com/nightrome/cocostuff/blob/master/labels.txt
|
||||||
|
# we shift stuffthing one to the right and put unlabeled in zero
|
||||||
|
# as long as segmentation is uint8 shifting to right handles the
|
||||||
|
# latter too
|
||||||
|
assert segmentation.dtype == np.uint8
|
||||||
|
segmentation = segmentation + 1
|
||||||
|
|
||||||
|
processed = self.preprocessor(image=image, segmentation=segmentation)
|
||||||
|
|
||||||
|
image, segmentation = processed["image"], processed["segmentation"]
|
||||||
|
else:
|
||||||
|
image = self.preprocessor(image=image,)['image']
|
||||||
|
|
||||||
|
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
if segmentation_path:
|
||||||
|
if self.onehot:
|
||||||
|
assert segmentation.dtype == np.uint8
|
||||||
|
# make it one hot
|
||||||
|
n_labels = 183
|
||||||
|
flatseg = np.ravel(segmentation)
|
||||||
|
onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
|
||||||
|
onehot[np.arange(flatseg.size), flatseg] = True
|
||||||
|
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
|
||||||
|
segmentation = onehot
|
||||||
|
else:
|
||||||
|
segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return image, segmentation
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
|
||||||
|
if self.use_segmenation:
|
||||||
|
seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
|
||||||
|
image, segmentation = self.preprocess_image(img_path, seg_path)
|
||||||
|
else:
|
||||||
|
image = self.preprocess_image(img_path)
|
||||||
|
captions = self.img_id_to_captions[self.labels["image_ids"][i]]
|
||||||
|
# randomly draw one of all available captions per image
|
||||||
|
caption = captions[np.random.randint(0, len(captions))]
|
||||||
|
example = {"image": image,
|
||||||
|
#"caption": [str(caption[0])],
|
||||||
|
"caption": caption,
|
||||||
|
"img_path": img_path,
|
||||||
|
"filename_": img_path.split(os.sep)[-1]
|
||||||
|
}
|
||||||
|
if self.use_segmenation:
|
||||||
|
example.update({"seg_path": seg_path, 'segmentation': segmentation})
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
class CocoImagesAndCaptionsTrain2017(CocoBase):
|
||||||
|
"""returns a pair of (image, caption)"""
|
||||||
|
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,):
|
||||||
|
super().__init__(size=size,
|
||||||
|
dataroot="data/coco/train2017",
|
||||||
|
datajson="data/coco/annotations/captions_train2017.json",
|
||||||
|
onehot_segmentation=onehot_segmentation,
|
||||||
|
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
|
||||||
|
|
||||||
|
def get_split(self):
|
||||||
|
return "train"
|
||||||
|
|
||||||
|
def year(self):
|
||||||
|
return '2017'
|
||||||
|
|
||||||
|
|
||||||
|
class CocoImagesAndCaptionsValidation2017(CocoBase):
|
||||||
|
"""returns a pair of (image, caption)"""
|
||||||
|
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
|
||||||
|
given_files=None):
|
||||||
|
super().__init__(size=size,
|
||||||
|
dataroot="data/coco/val2017",
|
||||||
|
datajson="data/coco/annotations/captions_val2017.json",
|
||||||
|
onehot_segmentation=onehot_segmentation,
|
||||||
|
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
||||||
|
given_files=given_files)
|
||||||
|
|
||||||
|
def get_split(self):
|
||||||
|
return "validation"
|
||||||
|
|
||||||
|
def year(self):
|
||||||
|
return '2017'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CocoImagesAndCaptionsTrain2014(CocoBase):
|
||||||
|
"""returns a pair of (image, caption)"""
|
||||||
|
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'):
|
||||||
|
super().__init__(size=size,
|
||||||
|
dataroot="data/coco/train2014",
|
||||||
|
datajson="data/coco/annotations2014/annotations/captions_train2014.json",
|
||||||
|
onehot_segmentation=onehot_segmentation,
|
||||||
|
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
||||||
|
use_segmentation=False,
|
||||||
|
crop_type=crop_type)
|
||||||
|
|
||||||
|
def get_split(self):
|
||||||
|
return "train"
|
||||||
|
|
||||||
|
def year(self):
|
||||||
|
return '2014'
|
||||||
|
|
||||||
|
class CocoImagesAndCaptionsValidation2014(CocoBase):
|
||||||
|
"""returns a pair of (image, caption)"""
|
||||||
|
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
|
||||||
|
given_files=None,crop_type='center',**kwargs):
|
||||||
|
super().__init__(size=size,
|
||||||
|
dataroot="data/coco/val2014",
|
||||||
|
datajson="data/coco/annotations2014/annotations/captions_val2014.json",
|
||||||
|
onehot_segmentation=onehot_segmentation,
|
||||||
|
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
||||||
|
given_files=given_files,
|
||||||
|
use_segmentation=False,
|
||||||
|
crop_type=crop_type)
|
||||||
|
|
||||||
|
def get_split(self):
|
||||||
|
return "validation"
|
||||||
|
|
||||||
|
def year(self):
|
||||||
|
return '2014'
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
|
||||||
|
json_data = json.load(json_file)
|
||||||
|
capdirs = json_data["annotations"]
|
||||||
|
import pudb; pudb.set_trace()
|
||||||
|
#d2 = CocoImagesAndCaptionsTrain2014(size=256)
|
||||||
|
d2 = CocoImagesAndCaptionsValidation2014(size=256)
|
||||||
|
print("constructed dataset.")
|
||||||
|
print(f"length of {d2.__class__.__name__}: {len(d2)}")
|
||||||
|
|
||||||
|
ex2 = d2[0]
|
||||||
|
# ex3 = d3[0]
|
||||||
|
# print(ex1["image"].shape)
|
||||||
|
print(ex2["image"].shape)
|
||||||
|
# print(ex3["image"].shape)
|
||||||
|
# print(ex1["segmentation"].shape)
|
||||||
|
print(ex2["caption"].__class__.__name__)
|
||||||
34
ldm/data/dummy.py
Normal file
@ -0,0 +1,34 @@
import numpy as np
import random
import string
from torch.utils.data import Dataset, Subset


class DummyData(Dataset):
    def __init__(self, length, size):
        self.length = length
        self.size = size

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        x = np.random.randn(*self.size)
        y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
        return {"jpg": x, "txt": y}


class DummyDataWithEmbeddings(Dataset):
    def __init__(self, length, size, emb_size):
        self.length = length
        self.size = size
        self.emb_size = emb_size

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        x = np.random.randn(*self.size)
        y = np.random.randn(*self.emb_size).astype(np.float32)
        return {"jpg": x, "txt": y}
394
ldm/data/imagenet.py
Normal file
@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    return dict((v, k) for k, v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config) == dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if not os.path.exists(self.human2integer):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        Imagenet Superresolution Dataloader
        Performs the following ops in order:
        1. crops a crop of size s from image either as random or center crop
        2. resizes crop to size with cv2.area_interpolation
        3. degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param random_crop:
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False  # gets reset later in case interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)
        example["caption"] = example["human_label"]  # dummy caption
        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
0
ldm/data/inpainting/__init__.py
Normal file
166
ldm/data/inpainting/synthetic_mask.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
from PIL import Image, ImageDraw
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
settings = {
|
||||||
|
"256narrow": {
|
||||||
|
"p_irr": 1,
|
||||||
|
"min_n_irr": 4,
|
||||||
|
"max_n_irr": 50,
|
||||||
|
"max_l_irr": 40,
|
||||||
|
"max_w_irr": 10,
|
||||||
|
"min_n_box": None,
|
||||||
|
"max_n_box": None,
|
||||||
|
"min_s_box": None,
|
||||||
|
"max_s_box": None,
|
||||||
|
"marg": None,
|
||||||
|
},
|
||||||
|
"256train": {
|
||||||
|
"p_irr": 0.5,
|
||||||
|
"min_n_irr": 1,
|
||||||
|
"max_n_irr": 5,
|
||||||
|
"max_l_irr": 200,
|
||||||
|
"max_w_irr": 100,
|
||||||
|
"min_n_box": 1,
|
||||||
|
"max_n_box": 4,
|
||||||
|
"min_s_box": 30,
|
||||||
|
"max_s_box": 150,
|
||||||
|
"marg": 10,
|
||||||
|
},
|
||||||
|
"512train": { # TODO: experimental
|
||||||
|
"p_irr": 0.5,
|
||||||
|
"min_n_irr": 1,
|
||||||
|
"max_n_irr": 5,
|
||||||
|
"max_l_irr": 450,
|
||||||
|
"max_w_irr": 250,
|
||||||
|
"min_n_box": 1,
|
||||||
|
"max_n_box": 4,
|
||||||
|
"min_s_box": 30,
|
||||||
|
"max_s_box": 300,
|
||||||
|
"marg": 10,
|
||||||
|
},
|
||||||
|
"512train-large": { # TODO: experimental
|
||||||
|
"p_irr": 0.5,
|
||||||
|
"min_n_irr": 1,
|
||||||
|
"max_n_irr": 5,
|
||||||
|
"max_l_irr": 450,
|
||||||
|
"max_w_irr": 400,
|
||||||
|
"min_n_box": 1,
|
||||||
|
"max_n_box": 4,
|
||||||
|
"min_s_box": 75,
|
||||||
|
"max_s_box": 450,
|
||||||
|
"marg": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def gen_segment_mask(mask, start, end, brush_width):
|
||||||
|
mask = mask > 0
|
||||||
|
mask = (255 * mask).astype(np.uint8)
|
||||||
|
mask = Image.fromarray(mask)
|
||||||
|
draw = ImageDraw.Draw(mask)
|
||||||
|
draw.line([start, end], fill=255, width=brush_width, joint="curve")
|
||||||
|
mask = np.array(mask) / 255
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def gen_box_mask(mask, masked):
|
||||||
|
x_0, y_0, w, h = masked
|
||||||
|
mask[y_0:y_0 + h, x_0:x_0 + w] = 1
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def gen_round_mask(mask, masked, radius):
|
||||||
|
x_0, y_0, w, h = masked
|
||||||
|
xy = [(x_0, y_0), (x_0 + w, y_0 + w)]
|
||||||
|
|
||||||
|
mask = mask > 0
|
||||||
|
mask = (255 * mask).astype(np.uint8)
|
||||||
|
mask = Image.fromarray(mask)
|
||||||
|
draw = ImageDraw.Draw(mask)
|
||||||
|
draw.rounded_rectangle(xy, radius=radius, fill=255)
|
||||||
|
mask = np.array(mask) / 255
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def gen_large_mask(prng, img_h, img_w,
|
||||||
|
marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
|
||||||
|
min_n_box, max_n_box, min_s_box, max_s_box):
|
||||||
|
"""
|
||||||
|
img_h: int, an image height
|
||||||
|
img_w: int, an image width
|
||||||
|
marg: int, a margin for a box starting coordinate
|
||||||
|
p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask
|
||||||
|
|
||||||
|
min_n_irr: int, min number of segments
|
||||||
|
max_n_irr: int, max number of segments
|
||||||
|
max_l_irr: max length of a segment in polygonal chain
|
||||||
|
max_w_irr: max width of a segment in polygonal chain
|
||||||
|
|
||||||
|
    min_n_box: int, min bound for the number of box primitives
    max_n_box: int, max bound for the number of box primitives
    min_s_box: int, min length of a box side
    max_s_box: int, max length of a box side
    """
    mask = np.zeros((img_h, img_w))
    uniform = prng.randint

    if np.random.uniform(0, 1) < p_irr:  # generate polygonal chain
        n = uniform(min_n_irr, max_n_irr)  # sample number of segments

        for _ in range(n):
            y = uniform(0, img_h)  # sample a starting point
            x = uniform(0, img_w)

            a = uniform(0, 360)  # sample angle (note: passed to np.sin/np.cos below, which expect radians)
            l = uniform(10, max_l_irr)  # sample segment length
            w = uniform(5, max_w_irr)  # sample a segment width

            # draw segment starting from (x,y) to (x_,y_) using brush of width w
            x_ = x + l * np.sin(a)
            y_ = y + l * np.cos(a)

            mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
            x, y = x_, y_
    else:  # generate box masks
        n = uniform(min_n_box, max_n_box)  # sample number of rectangles

        for _ in range(n):
            h = uniform(min_s_box, max_s_box)  # sample box shape
            w = uniform(min_s_box, max_s_box)

            x_0 = uniform(marg, img_w - marg - w)  # sample upper-left coordinates of box
            y_0 = uniform(marg, img_h - marg - h)

            if np.random.uniform(0, 1) < 0.5:
                mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
            else:
                r = uniform(0, 60)  # sample radius
                mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
    return mask


make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])


MASK_MODES = {
    "256train": make_lama_mask,
    "256narrow": make_narrow_lama_mask,
    "512train": make_512_lama_mask,
    "512train-large": make_512_lama_mask_large
}


if __name__ == "__main__":
    import sys

    out = sys.argv[1]

    prng = np.random.RandomState(1)
    kwargs = settings["256train"]
    mask = gen_large_mask(prng, 256, 256, **kwargs)
    mask = (255 * mask).astype(np.uint8)
    mask = Image.fromarray(mask)
    mask.save(out)
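The box branch of `gen_large_mask` can be exercised standalone with numpy alone. A minimal sketch: `gen_box_mask_sketch` is a hypothetical stand-in for the repo's `gen_box_mask` helper (the parameter defaults here are illustrative, not the repo's `settings` values), but the sampling scheme mirrors the code above:

```python
import numpy as np

def gen_box_mask_sketch(prng, img_h, img_w, min_n_box=1, max_n_box=4,
                        min_s_box=20, max_s_box=60, marg=10):
    """Stand-in for the box branch of gen_large_mask above."""
    mask = np.zeros((img_h, img_w))
    uniform = prng.randint
    n = uniform(min_n_box, max_n_box)          # sample number of rectangles
    for _ in range(n):
        h = uniform(min_s_box, max_s_box)      # sample box shape
        w = uniform(min_s_box, max_s_box)
        x_0 = uniform(marg, img_w - marg - w)  # sample upper-left corner
        y_0 = uniform(marg, img_h - marg - h)
        mask[y_0:y_0 + h, x_0:x_0 + w] = 1     # paint the box
    return mask

prng = np.random.RandomState(1)
m = gen_box_mask_sketch(prng, 256, 256)
print(m.shape)  # (256, 256), values in {0.0, 1.0}
```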
537
ldm/data/laion.py
Normal file
@ -0,0 +1,537 @@
import webdataset as wds
import kornia
from PIL import Image
import io
import os
import torchvision
import glob
import random
import numpy as np
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
from einops import rearrange
import torch
import torch.nn as nn  # needed for the nn.Identity() defaults in DataWithWings
from webdataset.handlers import warn_and_continue

import mmh3  # murmur hash, used by DataWithWings._compute_hash

from ldm.util import instantiate_from_config
from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
from ldm.data.base import PRNGMixin

# NOTE: DataWithWings also relies on an `OnDiskKV` on-disk key-value store,
# which is not imported in this file as committed.


class DataWithWings(torch.utils.data.IterableDataset):
    def __init__(self, min_size, transform=None, target_transform=None):
        self.min_size = min_size
        self.transform = transform if transform is not None else nn.Identity()
        self.target_transform = target_transform if target_transform is not None else nn.Identity()
        self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
        self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
        self.pwatermark_threshold = 0.8
        self.punsafe_threshold = 0.5
        self.aesthetic_threshold = 5.
        self.total_samples = 0
        self.samples = 0
        location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode('pilrgb', handler=wds.warn_and_continue),
            wds.map(self._add_tags, handler=wds.ignore_and_continue),
            wds.select(self._filter_predicate),
            wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
            wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
        )

    @staticmethod
    def _compute_hash(url, text):
        if url is None:
            url = ''
        if text is None:
            text = ''
        total = (url + text).encode('utf-8')
        return mmh3.hash64(total)[0]

    def _add_tags(self, x):
        hsh = self._compute_hash(x['json']['url'], x['txt'])
        pwatermark, punsafe = self.kv[hsh]
        aesthetic = self.kv_aesthetic[hsh][0]
        return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}

    def _punsafe_to_class(self, punsafe):
        return torch.tensor(punsafe >= self.punsafe_threshold).long()

    def _filter_predicate(self, x):
        try:
            return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
        except Exception:
            return False

    def __iter__(self):
        return iter(self.inner_dataset)
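`DataWithWings._compute_hash` keys the watermark/aesthetic stores by hashing the concatenated URL and caption, normalizing `None` to the empty string. A self-contained sketch of the same normalization, with stdlib `hashlib` standing in for `mmh3` (which may not be installed):

```python
import hashlib

def compute_key(url, text):
    # same None-normalization as DataWithWings._compute_hash;
    # hashlib.md5 is a stand-in here for mmh3.hash64
    url = url or ''
    text = text or ''
    total = (url + text).encode('utf-8')
    return int.from_bytes(hashlib.md5(total).digest()[:8], 'little', signed=True)

k1 = compute_key("http://example.com/a.jpg", "a cat")
k2 = compute_key("http://example.com/a.jpg", "a cat")
print(k1 == k2)  # True: a deterministic key per (url, caption) pair
```

The point of the normalization is that a missing URL or caption still yields a stable key rather than a `TypeError`.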
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a list of samples (as dictionaries) and create a batch, preserving the keys.

    If `combine_tensors` is True, `torch.Tensor` and `ndarray` values are stacked
    into batched tensors/arrays; if `combine_scalars` is True, int/float values
    are combined into a single ndarray.

    :param list samples: list of sample dictionaries
    :param bool combine_tensors: whether to stack tensors/ndarrays into a batch
    :param bool combine_scalars: whether to combine scalars into an ndarray
    :returns: single sample consisting of a batch
    :rtype: dict
    """
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [] for key in keys}

    for s in samples:
        for key in batched:
            batched[key].append(s[key])

    result = {}
    for key in batched:
        if isinstance(batched[key][0], (int, float)):
            if combine_scalars:
                result[key] = np.array(list(batched[key]))
        elif isinstance(batched[key][0], torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(list(batched[key]))
        elif isinstance(batched[key][0], np.ndarray):
            if combine_tensors:
                result[key] = np.array(list(batched[key]))
        else:
            result[key] = list(batched[key])
    return result
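A subtle property of `dict_collation_fn` is the key intersection: any key not present in every sample is silently dropped from the batch. A minimal numpy-only sketch of that grouping behavior (the tensor branch is omitted):

```python
import numpy as np

def collate(samples):
    # same per-key grouping as dict_collation_fn above (torch branch omitted)
    keys = set.intersection(*[set(s.keys()) for s in samples])
    batched = {k: [s[k] for s in samples] for k in keys}
    out = {}
    for k, vals in batched.items():
        if isinstance(vals[0], (int, float, np.ndarray)):
            out[k] = np.array(vals)   # stack scalars / arrays
        else:
            out[k] = vals             # keep everything else as a list
    return out

batch = collate([
    {"jpg": np.zeros((2, 2)), "txt": "a", "extra": 1},
    {"jpg": np.ones((2, 2)), "txt": "b"},   # no "extra": the key is dropped
])
print(sorted(batch.keys()))   # ['jpg', 'txt']
print(batch["jpg"].shape)     # (2, 2, 2)
```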
class WebDataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, tar_base, batch_size, train=None, validation=None,
                 test=None, num_workers=4, multinode=True, min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config, train=True):
        if 'image_transforms' in dataset_config:
            image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        else:
            image_transforms = []

        image_transforms.extend([torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = torchvision.transforms.Compose(image_transforms)

        if 'transforms' in dataset_config:
            transforms_config = OmegaConf.to_container(dataset_config.transforms)
        else:
            transforms_config = dict()

        # NOTE: `load_partial_from_config` and `identity` are not imported in this
        # file as committed; they are expected to come from ldm.util.
        transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
                          if transforms_config[dkey] != 'identity' else identity
                          for dkey in transforms_config}
        img_key = dataset_config.get('image_key', 'jpeg')
        transform_dict.update({img_key: image_transforms})

        if 'postprocess' in dataset_config:
            postprocess = instantiate_from_config(dataset_config['postprocess'])
        else:
            postprocess = None

        shuffle = dataset_config.get('shuffle', 0)
        shardshuffle = shuffle > 0

        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        if self.tar_base == "__improvedaesthetic__":
            print("## Warning, loading the same improved aesthetic dataset "
                  "for all splits and ignoring shards parameter.")
            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
        else:
            tars = os.path.join(self.tar_base, dataset_config.shards)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**transform_dict, handler=wds.warn_and_continue)
                )
        if postprocess is not None:
            dset = dset.map(postprocess)
        dset = (dset
                .batched(self.batch_size, partial=False,
                         collation_fn=dict_collation_fn)
                )

        loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
                               num_workers=self.num_workers)

        return loader

    def filter_size(self, x):
        try:
            valid = True
            if self.min_size is not None and self.min_size > 1:
                try:
                    valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
                except Exception:
                    valid = False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                try:
                    valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
                except Exception:
                    valid = False
            return valid
        except Exception:
            return False

    def filter_keys(self, x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return self.make_loader(self.validation, train=False)

    def test_dataloader(self):
        return self.make_loader(self.test, train=False)
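The size/watermark gating in `filter_size` reduces to a plain-dict check: missing metadata counts as invalid. A standalone sketch with illustrative threshold values (`min_size=512`, `max_pwatermark=0.5` are not this class's defaults):

```python
def filter_size(x, min_size=512, max_pwatermark=0.5):
    # mirrors WebDataModuleFromConfig.filter_size: reject small or watermarked
    # images, and treat missing metadata as a rejection rather than an error
    try:
        j = x["json"]
        return (j["original_width"] >= min_size
                and j["original_height"] >= min_size
                and j["pwatermark"] <= max_pwatermark)
    except Exception:
        return False

ok = {"json": {"original_width": 800, "original_height": 600, "pwatermark": 0.1}}
small = {"json": {"original_width": 300, "original_height": 600, "pwatermark": 0.1}}
print(filter_size(ok), filter_size(small), filter_size({}))  # True False False
```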
from ldm.modules.image_degradation import degradation_fn_bsr_light
import cv2


class AddLR(object):
    def __init__(self, factor, output_size, initial_size=None, image_key="jpg"):
        self.factor = factor
        self.output_size = output_size
        self.image_key = image_key
        self.initial_size = initial_size

    def pt2np(self, x):
        x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        x = torch.from_numpy(x) / 127.5 - 1.0
        return x

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = self.pt2np(sample[self.image_key])
        if self.initial_size is not None:
            x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2)
        x = degradation_fn_bsr_light(x, sf=self.factor)['image']
        x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
        x = self.np2pt(x)
        sample['lr'] = x
        return sample
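`pt2np`/`np2pt` implement the round trip between the model's `[-1, 1]` float range and `uint8` pixels. The same arithmetic in plain numpy (a stand-in for the torch ops above) makes the mapping concrete:

```python
import numpy as np

x = np.array([-1.0, 0.0, 1.0])                            # model range
u8 = np.clip((x + 1.0) * 127.5, 0, 255).astype(np.uint8)  # pt2np: [-1,1] -> uint8
back = u8 / 127.5 - 1.0                                   # np2pt: uint8 -> [-1,1]
print(u8)  # [  0 127 255]
```

Note the mid-range value lands on 127 rather than 128 because the `uint8` cast truncates 127.5, so the round trip is lossy by up to one quantization step.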
class AddBW(object):
    def __init__(self, image_key="jpg"):
        self.image_key = image_key

    def pt2np(self, x):
        x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        x = torch.from_numpy(x) / 127.5 - 1.0
        return x

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = sample[self.image_key]
        w = torch.rand(3, device=x.device)
        w /= w.sum()
        out = torch.einsum('hwc,c->hw', x, w)

        # keep as 3ch so we can pass to encoder, also we might want to add hints
        sample['lr'] = out.unsqueeze(-1).tile(1, 1, 3)
        return sample


class AddMask(PRNGMixin):
    def __init__(self, mode="512train", p_drop=0.):
        super().__init__()
        assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
        self.make_mask = MASK_MODES[mode]
        self.p_drop = p_drop

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = sample['jpg']
        mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
        if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
            mask = np.ones_like(mask)
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        mask = torch.from_numpy(mask[..., None])
        sample['mask'] = mask
        sample['masked_image'] = x * (mask < 0.5)
        return sample
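The masking convention in `AddMask` is worth making explicit: `mask == 1` marks the region to inpaint, and `x * (mask < 0.5)` zeros exactly those pixels while keeping the rest. A small numpy sketch of the same binarize-and-mask step:

```python
import numpy as np

prng = np.random.RandomState(0)
x = prng.uniform(-1, 1, size=(8, 8, 3))              # stand-in for sample['jpg'] in [-1, 1]
mask = prng.uniform(size=(8, 8))                     # stand-in for make_mask output
mask[mask < 0.5] = 0                                 # binarize, as in AddMask.__call__
mask[mask > 0.5] = 1
masked_image = x * (mask[..., None] < 0.5)           # keep pixels only where mask == 0
print(abs(masked_image[mask == 1]).sum())            # 0.0: masked region is zeroed
```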
class AddEdge(PRNGMixin):
    def __init__(self, mode="512train", mask_edges=True):
        super().__init__()
        assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
        self.make_mask = MASK_MODES[mode]
        self.n_down_choices = [0]
        self.sigma_choices = [1, 2]
        self.mask_edges = mask_edges

    @torch.no_grad()
    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = sample['jpg']

        mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        mask = torch.from_numpy(mask[..., None])
        sample['mask'] = mask

        n_down_idx = self.prng.choice(len(self.n_down_choices))
        sigma_idx = self.prng.choice(len(self.sigma_choices))

        n_choices = len(self.n_down_choices) * len(self.sigma_choices)
        raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx),
                                           (len(self.n_down_choices), len(self.sigma_choices)))
        normalized_idx = raveled_idx / max(1, n_choices - 1)

        n_down = self.n_down_choices[n_down_idx]
        sigma = self.sigma_choices[sigma_idx]

        kernel_size = 4 * sigma + 1
        kernel_size = (kernel_size, kernel_size)
        sigma = (sigma, sigma)
        canny = kornia.filters.Canny(
            low_threshold=0.1,
            high_threshold=0.2,
            kernel_size=kernel_size,
            sigma=sigma,
            hysteresis=True,
        )
        y = (x + 1.0) / 2.0  # in [0, 1]
        y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous()

        # down
        for i_down in range(n_down):
            size = min(y.shape[-2], y.shape[-1]) // 2
            y = kornia.geometry.transform.resize(y, size, antialias=True)

        # edge
        _, y = canny(y)

        if n_down > 0:
            size = x.shape[0], x.shape[1]
            y = kornia.geometry.transform.resize(y, size, interpolation="nearest")

        y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous()
        y = y * 2.0 - 1.0

        if self.mask_edges:
            sample['masked_image'] = y * (mask < 0.5)
        else:
            sample['masked_image'] = y
            sample['mask'] = torch.zeros_like(sample['mask'])

        # concat normalized idx
        sample['smoothing_strength'] = torch.ones_like(sample['mask']) * normalized_idx

        return sample
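`AddEdge` encodes its sampled `(n_down, sigma)` choice as a single scalar in `[0, 1]` via `np.ravel_multi_index` followed by normalization. With the class's default choice lists (`n_down_choices=[0]`, `sigma_choices=[1, 2]`) that yields exactly two conditioning values:

```python
import numpy as np

n_down_choices, sigma_choices = [0], [1, 2]      # defaults in AddEdge.__init__
shape = (len(n_down_choices), len(sigma_choices))
n_choices = shape[0] * shape[1]

for n_down_idx in range(shape[0]):
    for sigma_idx in range(shape[1]):
        raveled = np.ravel_multi_index((n_down_idx, sigma_idx), shape)
        normalized = raveled / max(1, n_choices - 1)
        print((n_down_idx, sigma_idx), "->", normalized)
# (0, 0) -> 0.0
# (0, 1) -> 1.0
```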
def example00():
    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
    dataset = wds.WebDataset(url)
    example = next(iter(dataset))
    for k in example:
        print(k, type(example[k]))

    print(example["__key__"])
    for k in ["json", "txt"]:
        print(example[k].decode())

    image = Image.open(io.BytesIO(example["jpg"]))
    outdir = "tmp"
    os.makedirs(outdir, exist_ok=True)
    image.save(os.path.join(outdir, example["__key__"] + ".png"))

    def load_example(example):
        return {
            "key": example["__key__"],
            "image": Image.open(io.BytesIO(example["jpg"])),
            "text": example["txt"].decode(),
        }

    for i, example in tqdm(enumerate(dataset)):
        ex = load_example(example)
        print(ex["image"].size, ex["text"])
        if i >= 100:
            break


def example01():
    # the first laion shards contain ~10k examples each
    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"

    batch_size = 3
    shuffle_buffer = 10000
    dset = wds.WebDataset(
        url,
        nodesplitter=wds.shardlists.split_by_node,
        shardshuffle=True,
    )
    dset = (dset
            .shuffle(shuffle_buffer, initial=shuffle_buffer)
            .decode('pil', handler=warn_and_continue)
            .batched(batch_size, partial=False,
                     collation_fn=dict_collation_fn)
            )

    num_workers = 2
    loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers)

    batch_sizes = list()
    keys_per_epoch = list()
    for epoch in range(5):
        keys = list()
        for batch in tqdm(loader):
            batch_sizes.append(len(batch["__key__"]))
            keys.append(batch["__key__"])

        for bs in batch_sizes:
            assert bs == batch_size
        print(f"{len(batch_sizes)} batches of size {batch_size}.")
        batch_sizes = list()

        keys_per_epoch.append(keys)
        for i_batch in [0, 1, -1]:
            print(f"Batch {i_batch} of epoch {epoch}:")
            print(keys[i_batch])
        print("next epoch.")


def example02():
    from omegaconf import OmegaConf
    from torch.utils.data.distributed import DistributedSampler
    from torch.utils.data import IterableDataset
    from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
    #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
    config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
    datamod = WebDataModuleFromConfig(**config["data"]["params"])
    dataloader = datamod.train_dataloader()

    for batch in dataloader:
        print(batch.keys())
        print(batch["jpg"].shape)
        break


def example03():
    # improved aesthetics
    tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
    dataset = wds.WebDataset(tars)

    def filter_keys(x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def filter_size(x):
        try:
            return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
        except Exception:
            return False

    def filter_watermark(x):
        try:
            return x['json']['pwatermark'] < 0.5
        except Exception:
            return False

    dataset = (dataset
               .select(filter_keys)
               .decode('pil', handler=wds.warn_and_continue))
    n_save = 20
    n_total = 0
    n_large = 0
    n_large_nowm = 0
    for i, example in enumerate(dataset):
        n_total += 1
        if filter_size(example):
            n_large += 1
            if filter_watermark(example):
                n_large_nowm += 1
                if n_large_nowm < n_save + 1:
                    image = example["jpg"]
                    image.save(os.path.join("tmp", f"{n_large_nowm - 1:06}.png"))

        if i % 500 == 0:
            print(i)
            print(f"Large: {n_large}/{n_total} | {n_large / n_total * 100:.2f}%")
            if n_large > 0:
                print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm / n_large * 100:.2f}%")


def example04():
    # improved aesthetics
    for i_shard in range(60208)[::-1]:
        print(i_shard)
        tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
        dataset = wds.WebDataset(tars)

        def filter_keys(x):
            try:
                return ("jpg" in x) and ("txt" in x)
            except Exception:
                return False

        def filter_size(x):
            try:
                return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
            except Exception:
                return False

        dataset = (dataset
                   .select(filter_keys)
                   .decode('pil', handler=wds.warn_and_continue))
        try:
            example = next(iter(dataset))
        except Exception:
            print(f"Error @ {i_shard}")


if __name__ == "__main__":
    #example01()
    #example02()
    example03()
    #example04()
92
ldm/data/lsun.py
Normal file
@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LSUNBase(Dataset):
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing: center-crop to a square, then resize
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
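The center-crop arithmetic in `LSUNBase.__getitem__` takes the largest square that fits, centered on the image. Standalone, on a tall 300x200 image:

```python
import numpy as np

img = np.zeros((300, 200, 3), dtype=np.uint8)  # h=300, w=200
crop = min(img.shape[0], img.shape[1])          # 200: largest centered square
h, w = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,      # rows 50:250
          (w - crop) // 2:(w + crop) // 2]      # cols 0:200
print(img.shape)  # (200, 200, 3)
```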
class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)


class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)


class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
393
ldm/data/simple.py
Normal file
@ -0,0 +1,393 @@
from typing import Dict
import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset
import copy
import csv
import cv2
import random  # used by VideoDataset._load_sample below

# Some hacky things to make experimentation easier
def make_transform_multi_folder_data(paths, caption_files=None, **kwargs):
    ds = make_multi_folder_data(paths, caption_files, **kwargs)
    return TransformDataset(ds)

def make_nfp_data(base_path):
    dirs = list(Path(base_path).glob("*/"))
    print(f"Found {len(dirs)} folders")
    print(dirs)
    tforms = [transforms.Resize(512), transforms.CenterCrop(512)]
    datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs]
    return torch.utils.data.ConcatDataset(datasets)


class VideoDataset(Dataset):
    def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2):
        self.root_dir = Path(root_dir)
        self.caption_file = caption_file
        self.n = n
        ext = "mp4"
        self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
        self.offset = offset

        if isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms
        with open(self.caption_file) as f:
            reader = csv.reader(f)
            rows = [row for row in reader]
        self.captions = dict(rows)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        for i in range(10):
            try:
                return self._load_sample(index)
            except Exception:
                # Not really good enough but...
                print("uh oh")

    def _load_sample(self, index):
        n = self.n
        filename = self.paths[index]
        min_frame = 2 * self.offset + 2
        vid = cv2.VideoCapture(str(filename))
        max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        curr_frame_n = random.randint(min_frame, max_frames)
        vid.set(cv2.CAP_PROP_POS_FRAMES, curr_frame_n)
        _, curr_frame = vid.read()

        prev_frames = []
        for i in range(n):
            prev_frame_n = curr_frame_n - (i + 1) * self.offset
            vid.set(cv2.CAP_PROP_POS_FRAMES, prev_frame_n)
            _, prev_frame = vid.read()
            prev_frame = self.tform(Image.fromarray(prev_frame[..., ::-1]))
            prev_frames.append(prev_frame)

        vid.release()
        caption = self.captions[filename.name]
        data = {
            "image": self.tform(Image.fromarray(curr_frame[..., ::-1])),
            "prev": torch.cat(prev_frames, dim=-1),
            "txt": caption
        }
        return data

# end hacky things


def make_tranforms(image_transforms):
    if isinstance(image_transforms, ListConfig):
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    image_transforms = transforms.Compose(image_transforms)
    return image_transforms


def make_multi_folder_data(paths, caption_files=None, **kwargs):
    """Make a concat dataset from multiple folders.
    Doesn't support captions yet.

    If paths is a list, that's ok; if it's a Dict, interpret it as:
    k=folder, v=number of times to repeat that folder
    """
    list_of_paths = []
    if isinstance(paths, (Dict, DictConfig)):
        assert caption_files is None, \
            "Caption files not yet supported for repeats"
        for folder_path, repeats in paths.items():
            list_of_paths.extend([folder_path] * repeats)
        paths = list_of_paths

    if caption_files is not None:
        datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
    else:
        datasets = [FolderData(p, **kwargs) for p in paths]
    return torch.utils.data.ConcatDataset(datasets)
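The dict form of `paths` in `make_multi_folder_data` is a simple oversampling mechanism: each folder is repeated `v` times before the per-folder datasets are concatenated. The expansion step in isolation:

```python
paths = {"folder_a": 2, "folder_b": 1}   # folder -> number of repeats (illustrative names)

list_of_paths = []
for folder_path, repeats in paths.items():
    list_of_paths.extend([folder_path] * repeats)
print(list_of_paths)  # ['folder_a', 'folder_a', 'folder_b']
```

Since dicts preserve insertion order in Python 3.7+, the repeats stay grouped per folder in the resulting `ConcatDataset`.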
class NfpDataset(Dataset):
    def __init__(self,
                 root_dir,
                 image_transforms=[],
                 ext="jpg",
                 default_caption="",
                 ) -> None:
        """assume sequential frames and a deterministic transform"""

        self.root_dir = Path(root_dir)
        self.default_caption = default_caption

        self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
        self.tform = make_tranforms(image_transforms)

    def __len__(self):
        return len(self.paths) - 1

    def __getitem__(self, index):
        prev = self.paths[index]
        curr = self.paths[index + 1]
        data = {}
        data["image"] = self._load_im(curr)
        data["prev"] = self._load_im(prev)
        data["txt"] = self.default_caption
        return data

    def _load_im(self, filename):
        im = Image.open(filename).convert("RGB")
        return self.tform(im)


class FolderData(Dataset):
    def __init__(self,
                 root_dir,
                 caption_file=None,
                 image_transforms=[],
                 ext="jpg",
                 default_caption="",
                 postprocess=None,
                 return_paths=False,
                 ) -> None:
        """Create a dataset from a folder of images.
        If you pass in a root directory it will be searched for images
        ending in ext (ext can be a list)
        """
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        if caption_file is not None:
            with open(caption_file, "rt") as f:
                ext = Path(caption_file).suffix.lower()
                if ext == ".json":
                    captions = json.load(f)
                elif ext == ".jsonl":
                    lines = f.readlines()
                    lines = [json.loads(x) for x in lines]
                    captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
                else:
                    raise ValueError(f"Unrecognised format: {ext}")
            self.captions = captions
        else:
            self.captions = None

        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]

        # Only used if there is no caption file
        self.paths = []
        for e in ext:
            self.paths.extend(sorted(list(self.root_dir.rglob(f"*.{e}"))))
        self.tform = make_tranforms(image_transforms)

    def __len__(self):
        if self.captions is not None:
            return len(self.captions.keys())
        else:
            return len(self.paths)

    def __getitem__(self, index):
        data = {}
        if self.captions is not None:
            chosen = list(self.captions.keys())[index]
            caption = self.captions.get(chosen, None)
|
||||||
|
if caption is None:
|
||||||
|
caption = self.default_caption
|
||||||
|
filename = self.root_dir/chosen
|
||||||
|
else:
|
||||||
|
filename = self.paths[index]
|
||||||
|
|
||||||
|
if self.return_paths:
|
||||||
|
data["path"] = str(filename)
|
||||||
|
|
||||||
|
im = Image.open(filename).convert("RGB")
|
||||||
|
im = self.process_im(im)
|
||||||
|
data["image"] = im
|
||||||
|
|
||||||
|
if self.captions is not None:
|
||||||
|
data["txt"] = caption
|
||||||
|
else:
|
||||||
|
data["txt"] = self.default_caption
|
||||||
|
|
||||||
|
if self.postprocess is not None:
|
||||||
|
data = self.postprocess(data)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def process_im(self, im):
|
||||||
|
im = im.convert("RGB")
|
||||||
|
return self.tform(im)
|
||||||
|
import random
|
||||||
|
|
||||||
|
class TransformDataset():
|
||||||
|
def __init__(self, ds, extra_label="sksbspic"):
|
||||||
|
self.ds = ds
|
||||||
|
self.extra_label = extra_label
|
||||||
|
self.transforms = {
|
||||||
|
"align": transforms.Resize(768),
|
||||||
|
"centerzoom": transforms.CenterCrop(768),
|
||||||
|
"randzoom": transforms.RandomCrop(768),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
data = self.ds[index]
|
||||||
|
|
||||||
|
im = data['image']
|
||||||
|
im = im.permute(2,0,1)
|
||||||
|
# In case data is smaller than expected
|
||||||
|
im = transforms.Resize(1024)(im)
|
||||||
|
|
||||||
|
tform_name = random.choice(list(self.transforms.keys()))
|
||||||
|
im = self.transforms[tform_name](im)
|
||||||
|
|
||||||
|
im = im.permute(1,2,0)
|
||||||
|
|
||||||
|
data['image'] = im
|
||||||
|
data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.ds)
|
||||||
|
|
||||||
|
def hf_dataset(
|
||||||
|
name,
|
||||||
|
image_transforms=[],
|
||||||
|
image_column="image",
|
||||||
|
text_column="text",
|
||||||
|
split='train',
|
||||||
|
image_key='image',
|
||||||
|
caption_key='txt',
|
||||||
|
):
|
||||||
|
"""Make huggingface dataset with appropriate list of transforms applied
|
||||||
|
"""
|
||||||
|
ds = load_dataset(name, split=split)
|
||||||
|
tform = make_tranforms(image_transforms)
|
||||||
|
|
||||||
|
assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
|
||||||
|
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
|
||||||
|
|
||||||
|
def pre_process(examples):
|
||||||
|
processed = {}
|
||||||
|
processed[image_key] = [tform(im) for im in examples[image_column]]
|
||||||
|
processed[caption_key] = examples[text_column]
|
||||||
|
return processed
|
||||||
|
|
||||||
|
ds.set_transform(pre_process)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def hf_dataset_RSITMD(
|
||||||
|
name,
|
||||||
|
image_transforms=[],
|
||||||
|
image_column="image",
|
||||||
|
text_column="text",
|
||||||
|
split='train',
|
||||||
|
image_key='image',
|
||||||
|
caption_key='txt',
|
||||||
|
):
|
||||||
|
"""Make huggingface dataset with appropriate list of transforms applied
|
||||||
|
"""
|
||||||
|
|
||||||
|
data_files = {
|
||||||
|
"train": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_train.json",
|
||||||
|
"validation": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_val.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
ds = load_dataset("json", data_files=data_files, field="data")
|
||||||
|
|
||||||
|
if split == 'train':
|
||||||
|
ds = ds['train']
|
||||||
|
else:
|
||||||
|
ds = ds['validation']
|
||||||
|
|
||||||
|
tform = make_tranforms(image_transforms)
|
||||||
|
|
||||||
|
assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
|
||||||
|
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
|
||||||
|
|
||||||
|
def pre_process(examples):
|
||||||
|
examples['image'] = [Image.open(x) for x in examples['image']]
|
||||||
|
|
||||||
|
processed = {}
|
||||||
|
processed[image_key] = [tform(im) for im in examples[image_column]]
|
||||||
|
processed[caption_key] = examples[text_column]
|
||||||
|
return processed
|
||||||
|
|
||||||
|
ds.set_transform(pre_process)
|
||||||
|
|
||||||
|
return ds
|
||||||
|
|
||||||
|
class TextOnly(Dataset):
|
||||||
|
def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
|
||||||
|
"""Returns only captions with dummy images"""
|
||||||
|
self.output_size = output_size
|
||||||
|
self.image_key = image_key
|
||||||
|
self.caption_key = caption_key
|
||||||
|
if isinstance(captions, Path):
|
||||||
|
self.captions = self._load_caption_file(captions)
|
||||||
|
else:
|
||||||
|
self.captions = captions
|
||||||
|
|
||||||
|
if n_gpus > 1:
|
||||||
|
# hack to make sure that all the captions appear on each gpu
|
||||||
|
repeated = [n_gpus*[x] for x in self.captions]
|
||||||
|
self.captions = []
|
||||||
|
[self.captions.extend(x) for x in repeated]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.captions)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
dummy_im = torch.zeros(3, self.output_size, self.output_size)
|
||||||
|
dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
|
||||||
|
return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
|
||||||
|
|
||||||
|
def _load_caption_file(self, filename):
|
||||||
|
with open(filename, 'rt') as f:
|
||||||
|
captions = f.readlines()
|
||||||
|
return [x.strip('\n') for x in captions]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
class IdRetreivalDataset(FolderData):
|
||||||
|
def __init__(self, ret_file, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
with open(ret_file, "rt") as f:
|
||||||
|
self.ret = json.load(f)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
data = super().__getitem__(index)
|
||||||
|
key = self.paths[index].name
|
||||||
|
matches = self.ret[key]
|
||||||
|
if len(matches) > 0:
|
||||||
|
retreived = random.choice(matches)
|
||||||
|
else:
|
||||||
|
retreived = key
|
||||||
|
filename = self.root_dir/retreived
|
||||||
|
im = Image.open(filename).convert("RGB")
|
||||||
|
im = self.process_im(im)
|
||||||
|
# data["match"] = im
|
||||||
|
data["match"] = torch.cat((data["image"], im), dim=-1)
|
||||||
|
return data
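A minimal, standalone sketch of the repeat-expansion step that `make_multi_folder_data` performs when `paths` is a dict mapping folder to repeat count (the folder names below are hypothetical, and `FolderData` construction is omitted so the sketch runs without any image data):

```python
# Standalone sketch of make_multi_folder_data's dict handling:
# {folder: n_repeats} is flattened into a list where each folder
# appears n_repeats times before one FolderData is built per entry.
def expand_repeats(paths):
    if isinstance(paths, dict):
        list_of_paths = []
        for folder_path, repeats in paths.items():
            list_of_paths.extend([folder_path] * repeats)
        return list_of_paths
    return list(paths)

expanded = expand_repeats({"data/rs_a": 2, "data/rs_b": 1})
print(expanded)  # each folder appears `repeats` times
```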
77
ldm/extras.py
Normal file
@@ -0,0 +1,77 @@
from pathlib import Path
from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config
import logging
from contextlib import contextmanager


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that will prevent any logging messages
    triggered during the body from being processed.

    :param highest_level: the maximum logging level in use.
      This would only need to be changed if a custom level greater than CRITICAL
      is defined.

    https://gist.github.com/simon-weber/7853144
    """
    # two kind-of hacks here:
    # * can't get the highest logging level in effect => delegate to the user
    # * can't get the current module-level override => use an undocumented
    #   (but non-private!) interface

    previous_level = logging.root.manager.disable

    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(previous_level)

def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from a training directory."""
    train_dir = Path(train_dir)
    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
    config = list(train_dir.rglob("*-project.yaml"))
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)

def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from config and a ckpt

    if config is a path will use omegaconf to load
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
        model.to(device)
        model.eval()
        model.cond_stage_model.device = device
        return model
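A self-contained usage sketch of the `all_logging_disabled` context manager (re-defined here so the snippet runs on its own): messages emitted inside the block are suppressed, and the previous logging state is restored afterwards.

```python
import logging
from contextlib import contextmanager

@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    # save the current module-level override, silence everything, then restore
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)

# capture what actually reaches a handler
records = []
handler = logging.Handler()
handler.emit = lambda record: records.append(record.getMessage())
log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.INFO)

with all_logging_disabled():
    log.warning("suppressed")   # never reaches the handler
log.warning("visible")          # logged normally again

print(records)  # only the message emitted outside the block survives
```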
96
ldm/guidance.py
Normal file
@@ -0,0 +1,96 @@
from typing import List, Tuple
from scipy import interpolate
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
import abc


class GuideModel(torch.nn.Module, abc.ABC):
    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        pass

    @abc.abstractmethod
    def compute_loss(self, inp):
        pass


class Guider(torch.nn.Module):
    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance

        Specify a guidance scale as either a scalar
        or a schedule as a list of tuples of t = 0->1 and scale, e.g.
        [(0, 10), (0.5, 20), (1, 50)]
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []

        if isinstance(scale, (Tuple, List)):
            times = np.array([x[0] for x in scale])
            values = np.array([x[1] for x in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps)*[self.scale_schedule]

        interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
        fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):

        # TODO look up index by t
        scale = self.get_scales()[self.index]

        if (scale == 0):
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale

            if self.show:
                clear_output(wait=True)
                print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
                self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
                plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
                plt.axis('off')
                plt.show()
                plt.imshow(correction[0][0].detach().cpu())
                plt.axis('off')
                plt.show()


        e_t_mod = e_t - sqrt_1ma*correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod
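A sketch of how `Guider.get_scales` maps each DDIM step to a guidance scale: the `(t, scale)` schedule is interpolated at each step's fractional position `t/T`. Here `np.interp` stands in for `scipy.interpolate.interp1d` (they agree for piecewise-linear interpolation inside the schedule range), and the step counts are illustrative:

```python
import numpy as np

# schedule [(0, 10), (0.5, 20), (1, 50)] split into times and values,
# as Guider.__init__ does for a list-of-tuples scale
times = np.array([0.0, 0.5, 1.0])
values = np.array([10.0, 20.0, 50.0])

ddim_timesteps = np.array([0, 250, 500, 750])  # illustrative DDIM step indices
ddpm_num_timesteps = 1000

# each DDIM step's position in [0, 1], then linear interpolation of the schedule
fractional_steps = ddim_timesteps / ddpm_num_timesteps
scales = np.interp(fractional_steps, times, values)
print(scales)  # [10. 15. 20. 35.]
```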
98
ldm/lr_scheduler.py
Normal file
@@ -0,0 +1,98 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
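A minimal sketch of the single-cycle warmup/cosine curve that `LambdaWarmUpCosineScheduler.schedule` computes (it is an LR *multiplier*, meant for a base_lr of 1.0); the step counts and min/max values here are illustrative defaults, not the project's configuration:

```python
import numpy as np

def warmup_cosine(n, warm_up_steps=100, lr_start=0.0, lr_min=0.1, lr_max=1.0,
                  max_decay_steps=1000):
    if n < warm_up_steps:
        # linear warmup from lr_start up to lr_max
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    # cosine decay from lr_max down to lr_min over the remaining steps
    t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(t * np.pi))

print(warmup_cosine(0))     # 0.0 -> start of warmup
print(warmup_cosine(100))   # 1.0 -> peak right after warmup
print(warmup_cosine(1000))  # 0.1 -> floor after the full decay
```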
443
ldm/models/autoencoder.py
Normal file
@@ -0,0 +1,443 @@
import torch
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from packaging import version
from torch.optim.lr_scheduler import LambdaLR

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma

from ldm.util import instantiate_from_config


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# train the discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
|
||||||
|
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="val")
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="val")
|
||||||
|
|
||||||
|
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr = self.learning_rate
|
||||||
|
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||||
|
list(self.decoder.parameters())+
|
||||||
|
list(self.quant_conv.parameters())+
|
||||||
|
list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr, betas=(0.5, 0.9))
|
||||||
|
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||||
|
lr=lr, betas=(0.5, 0.9))
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if not only_inputs:
|
||||||
|
xrec, posterior = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||||
|
log["reconstructions"] = xrec
|
||||||
|
log["inputs"] = x
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == "segmentation"
|
||||||
|
if not hasattr(self, "colorize"):
|
||||||
|
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFirstStage(torch.nn.Module):
|
||||||
|
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||||
|
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def encode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def quantize(self, x, *args, **kwargs):
|
||||||
|
if self.vq_interface:
|
||||||
|
return x, None, [None, None, None]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
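`AutoencoderKL.encode` wraps the `2*embed_dim` output channels of `quant_conv` in a `DiagonalGaussianDistribution` and samples from it with the reparameterization trick. As a minimal, self-contained NumPy sketch of what that sampling step does (assuming the mean/log-variance channel split and the [-30, 20] log-variance clamp used in the LDM codebase):

```python
import numpy as np

def sample_diagonal_gaussian(moments, rng):
    """Reparameterized sample from N(mean, diag(var)).

    `moments` has 2*C channels: the first C are the mean, the last C the
    log-variance -- the split DiagonalGaussianDistribution performs on the
    output of quant_conv.
    """
    mean, logvar = np.split(moments, 2, axis=1)
    logvar = np.clip(logvar, -30.0, 20.0)  # clamp as in the LDM implementation
    std = np.exp(0.5 * logvar)
    return mean + std * rng.standard_normal(mean.shape)

rng = np.random.default_rng(0)
moments = np.zeros((1, 8, 4, 4))  # mean 0, log-variance 0 -> unit std
z = sample_diagonal_gaussian(moments, rng)  # latent with shape (1, 4, 4, 4)
```

Passing `sample_posterior=False` in `forward` corresponds to returning `mean` directly (the mode of the diagonal Gaussian) instead of a sample.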
0
ldm/models/diffusion/__init__.py
Normal file
267
ldm/models/diffusion/classifier.py
Normal file
@ -0,0 +1,267 @@
import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted

from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config

__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel
}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class NoisyLatentImageClassifier(pl.LightningModule):

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f'inputs@t{current_time}'] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, 'b h w c -> b c h w')

                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log
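`compute_top_k` above counts how often the true label appears among the `k` highest logits per sample. A small NumPy equivalent (a hypothetical helper for illustration, not part of the repo) makes the metric explicit:

```python
import numpy as np

def top_k_accuracy(logits, labels, k):
    # indices of the k largest logits per row, mirroring torch.topk(logits, k, dim=1)
    top_ks = np.argsort(logits, axis=1)[:, ::-1][:, :k]
    # a sample is a hit if its true label is among those k indices
    hits = (top_ks == labels[:, None]).any(axis=1)
    return hits.mean()

logits = np.array([[0.1, 0.9, 0.0],
                   [0.8, 0.15, 0.05]])
labels = np.array([1, 2])
acc1 = top_k_accuracy(logits, labels, 1)  # only the first row's argmax matches
```

With `reduction="mean"` the torch version returns exactly this scalar; with `reduction="none"` it keeps the per-sample hit indicators, which `validation_step` accumulates per noise level in `self.noisy_acc`.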
324
ldm/models/diffusion/ddim.py
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||||
|
from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
"""Same as to in torch module
|
||||||
|
Don't really underestand why this isn't a module in the first place"""
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
new_v = getattr(self, k).to(device)
|
||||||
|
setattr(self, k, new_v)
|
||||||
|
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device("cuda"):
|
||||||
|
attr = attr.to(torch.device("cuda"))
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
dynamic_threshold=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||||
|
|
||||||
|
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||||
|
t_start=-1):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
timesteps = timesteps[:t_start]
|
||||||
|
|
||||||
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
|
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold)
|
||||||
|
img, pred_x0 = outs
|
||||||
|
if callback:
|
||||||
|
img = callback(i, img, pred_x0)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
return img, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||||
|
dynamic_threshold=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
if isinstance(c, dict):
|
||||||
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
|
c_in = dict()
|
||||||
|
for k in c:
|
||||||
|
if isinstance(c[k], list):
|
||||||
|
c_in[k] = [torch.cat([
|
||||||
|
unconditional_conditioning[k][i],
|
||||||
|
c[k][i]]) for i in range(len(c[k]))]
|
||||||
|
else:
|
||||||
|
c_in[k] = torch.cat([
|
||||||
|
unconditional_conditioning[k],
|
||||||
|
c[k]])
|
||||||
|
else:
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
|
||||||
|
if dynamic_threshold is not None:
|
||||||
|
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||||
|
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
||||||
|
unconditional_guidance_scale=1.0, unconditional_conditioning=None):
|
||||||
|
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
||||||
|
|
||||||
|
assert t_enc <= num_reference_steps
|
||||||
|
num_steps = t_enc
|
||||||
|
|
||||||
|
if use_original_steps:
|
||||||
|
alphas_next = self.alphas_cumprod[:num_steps]
|
||||||
|
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||||
|
else:
|
||||||
|
alphas_next = self.ddim_alphas[:num_steps]
|
||||||
|
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||||
|
|
||||||
|
x_next = x0
|
||||||
|
intermediates = []
|
||||||
|
inter_steps = []
|
||||||
|
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||||
|
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
||||||
|
if unconditional_guidance_scale == 1.:
|
||||||
|
noise_pred = self.model.apply_model(x_next, t, c)
|
||||||
|
else:
|
||||||
|
assert unconditional_conditioning is not None
|
||||||
|
e_t_uncond, noise_pred = torch.chunk(
|
||||||
|
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||||
|
torch.cat((unconditional_conditioning, c))), 2)
|
||||||
|
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
||||||
|
|
||||||
|
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||||
|
weighted_noise_pred = alphas_next[i].sqrt() * (
|
||||||
|
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
||||||
|
x_next = xt_weighted + weighted_noise_pred
|
||||||
|
if return_intermediates and i % (
|
||||||
|
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
||||||
|
intermediates.append(x_next)
|
||||||
|
inter_steps.append(i)
|
||||||
|
elif return_intermediates and i >= num_steps - 2:
|
||||||
|
intermediates.append(x_next)
|
||||||
|
inter_steps.append(i)
|
||||||
|
|
||||||
|
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||||
|
if return_intermediates:
|
||||||
|
out.update({'intermediates': intermediates})
|
||||||
|
return x_next, out
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||||
|
# fast, but does not allow for exact reconstruction
|
||||||
|
# t serves as an index to gather the correct alphas
|
||||||
|
if use_original_steps:
|
||||||
|
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||||
|
else:
|
||||||
|
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||||
|
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||||
|
|
||||||
|
    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False):

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec
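The closed form used by `stochastic_encode` above, `x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, is variance-preserving, which is easy to sanity-check in isolation. A minimal NumPy sketch, using an illustrative linear beta schedule (the schedule values and sizes here are assumptions for demonstration, not the repo's configuration):

```python
import numpy as np

# illustrative linear beta schedule (assumed values, not the repo's config)
betas = np.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
x0 = rng.standard_normal(10_000)      # unit-variance "clean" signal
noise = rng.standard_normal(10_000)   # independent unit-variance noise

t = 500
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise

# since alpha_bar + (1 - alpha_bar) = 1, the mixture keeps unit variance
print(f"alpha_bar[{t}] = {alphas_cumprod[t]:.3f}, var(x_t) = {x_t.var():.3f}")
```

Because the squared mixing weights sum to one, `var(x_t)` stays close to 1 at every `t`, which is exactly why `stochastic_encode` can jump directly to any noise level without iterating.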
1997
ldm/models/diffusion/ddpm.py
Normal file
259
ldm/models/diffusion/plms.py
Normal file
@ -0,0 +1,259 @@
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding


class PLMSSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(c, dict):
                    assert isinstance(unconditional_conditioning, dict)
                    c_in = dict()
                    for k in c:
                        if isinstance(c[k], list):
                            c_in[k] = [torch.cat([
                                unconditional_conditioning[k][i],
                                c[k][i]]) for i in range(len(c[k]))]
                        else:
                            c_in[k] = torch.cat([
                                unconditional_conditioning[k],
                                c[k]])
                else:
                    c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
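The branch ladder at the end of `p_sample_plms` is the classic Adams-Bashforth multistep family applied to the eps predictions. A standalone sketch of the same coefficient selection (the function name `plms_eps_prime` is hypothetical, introduced only for this illustration), with the basic consistency check that each stencil's weights sum to 1, so a constant eps history is reproduced exactly:

```python
def plms_eps_prime(e_t, old_eps):
    """Adams-Bashforth combination of the current and previous eps predictions,
    mirroring the multistep branches in p_sample_plms (old_eps is oldest-first)."""
    if len(old_eps) == 0:
        # the first step instead uses the 2nd-order Heun-style average with e_t_next
        raise ValueError("no history yet; PLMS bootstraps with a pseudo improved Euler step")
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

# consistency: with a constant eps history, every stencil returns that constant,
# i.e. (3-1)/2 = (23-16+5)/12 = (55-59+37-9)/24 = 1
for history in ([1.0], [1.0, 1.0], [1.0, 1.0, 1.0]):
    assert plms_eps_prime(1.0, history) == 1.0
```

This is why `plms_sampling` keeps at most three entries in `old_eps`: once the history is full, every step uses the 4th-order stencil, and older predictions carry no further weight.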