Skip to content

Commit 3c78b1a

Browse files
committedMay 31, 2022
Add ddpm code.
1 parent c8047c8 commit 3c78b1a

File tree

12 files changed

+123
-270
lines changed

12 files changed

+123
-270
lines changed
 

‎assets/ddpm/celeba.jpg

69.3 KB
Loading

‎assets/ddpm/cifar10.jpg

22.7 KB
Loading

‎assets/ddpm/mnist.jpg

18.4 KB
Loading

‎configs/experiment/ddpm/celeba.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# @package _global_
2+
defaults:
3+
- override /model: ddpm
4+
- override /datamodule: celeba
5+
6+
exp_name: ddpm/celeba
7+
8+
trainer:
9+
max_epochs: 100
10+
check_val_every_n_epoch: 10
11+
12+
model:
13+
dim_mults: [1, 2, 4, 8]
14+
timesteps: 1000

‎configs/experiment/ddpm/cifar10.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# @package _global_
2+
defaults:
3+
- override /model: ddpm
4+
- override /datamodule: cifar10
5+
6+
exp_name: ddpm/cifar10
7+
8+
trainer:
9+
max_epochs: 100
10+
check_val_every_n_epoch: 10
11+
12+
model:
13+
dim_mults: [1, 2, 4]
14+
timesteps: 1000

‎configs/experiment/ddpm/mnist.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# @package _global_
2+
defaults:
3+
- override /model: ddpm
4+
- override /datamodule: mnist
5+
6+
exp_name: ddpm/mnist
7+
8+
trainer:
9+
max_epochs: 100
10+
check_val_every_n_epoch: 10
11+
12+
model:
13+
dim_mults: [2, 4]
14+
timesteps: 1000

‎configs/model/ddpm.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
defaults:
2+
- override /callbacks@_global_: ar_models
13
_target_: src.models.ddpm.DDPM
24
hidden_dim: 64
35
lr: 0.0001

‎readme.adoc

+17-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ NeruIPS 2017. [https://arxiv.org/abs/1711.00937[PDF]]
379379

380380
=== PixelCNN
381381
*_Conditional Image Generation with PixelCNN Decoders_* +
382-
_Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves, Koray Kavukcuoglu_
382+
_Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves, Koray Kavukcuoglu_ +
383383
NeruIPS 2016. [https://arxiv.org/abs/1606.05328[PDF]]
384384

385385
[cols="3*", options="header"]
@@ -395,3 +395,19 @@ NeruIPS 2016. [https://arxiv.org/abs/1606.05328[PDF]]
395395

396396
== Diffusion Models
397397
=== DDPM
398+
*_Denoising Diffusion Probabilistic Models_* +
399+
_Jonathan Ho, Ajay Jain, Pieter Abbeel_ +
400+
NeurIPS 2020. [https://arxiv.org/abs/2006.11239[PDF]]
401+
402+
[cols="4*", options="header"]
403+
|===
404+
^| Dataset
405+
^| MNIST
406+
^| CelebA
407+
^| CIFAR10
408+
409+
^.^| Results
410+
| image:assets/ddpm/mnist.jpg[mnist_mlp, {img-size}, {img-size}]
411+
| image:assets/ddpm/celeba.jpg[cleba_conv, {img-size}, {img-size}]
412+
| image:assets/ddpm/cifar10.jpg[cifar10_conv, {img-size}, {img-size}]
413+
|===

‎requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ seaborn # used in some callbacks
3838
jupyterlab # better jupyter notebooks
3939
pudb # debugger
4040
GPUtil # get info about GPUs
41+
einops # Framework to use einstein-like notation

‎src/callbacks/visualization.py

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def on_validation_batch_end(self, trainer, pl_module, outputs: ValidationResult,
3232
trainer.logger.experiment.add_image("images/sample", fake_grid, global_step=trainer.current_epoch)
3333
torchvision.utils.save_image(fake_grid, result_path / f"{trainer.current_epoch}.jpg")
3434

35+
for key in outputs.others:
36+
grid = get_grid_images(outputs.others[key], pl_module)
37+
trainer.logger.experiment.add_image(f"images/{key}", grid, global_step=trainer.current_epoch)
38+
39+
3540
class TraverseLatentCallback(pl.Callback):
3641
def __init__(self, col=10, row=10) -> None:
3742
super().__init__()

‎src/models/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from pytorch_lightning import LightningModule
33
from src.utils.utils import get_logger
44
import torch
55
import torch.nn.functional as F
66

77
@dataclass
88
class ValidationResult():
9+
others: field(default_factory=dict)
910
real_image: torch.Tensor = None
1011
fake_image: torch.Tensor = None
1112
recon_image: torch.Tensor = None

0 commit comments

Comments
 (0)