Skip to content

Commit

Permalink
Merge pull request #202 from LeoXing1996/ddpm_demo
Browse files Browse the repository at this point in the history
[Feature] Support denoising demo
  • Loading branch information
LeoXing1996 authored Jan 11, 2022
2 parents 5cea51f + 83a2e60 commit d4fcb21
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 3 deletions.
155 changes: 155 additions & 0 deletions demo/ddpm_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import argparse
import os
import os.path as osp
import sys

import mmcv
import numpy as np
import torch
from mmcv import DictAction
from torchvision import utils

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa

from mmgen.apis import init_model, sample_ddpm_model # isort:skip # noqa
# yapf: enable


def parse_args():
parser = argparse.ArgumentParser(description='DDPM demo')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--save-path',
type=str,
default='./work_dirs/demos/ddpm_samples.png',
help='path to save uncoditional samples')
parser.add_argument(
'--device', type=str, default='cuda:0', help='CUDA device id')

# args for inference/sampling
parser.add_argument(
'--num-batches', type=int, default=4, help='Batch size in inference')
parser.add_argument(
'--num-samples',
type=int,
default=12,
help='The total number of samples')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='Which model to use for sampling')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
parser.add_argument(
'--same-noise',
action='store_true',
help='whether use same noise as input (x_T)')
parser.add_argument(
'--n-skip',
type=int,
default=25,
help=('Skip how many steps before selecting one to visualize. This is '
'helpful with denoising timestep is too much. Only work with '
'`save-path` is end with \'.gif\'.'))

# args for image grid
parser.add_argument(
'--padding', type=int, default=0, help='Padding in the image grid.')
parser.add_argument(
'--nrow',
type=int,
default=2,
help=('Number of images displayed in each row of the grid. '
'This argument would work only when label is not given.'))

# args for image channel order
parser.add_argument(
'--is-rgb',
action='store_true',
help=('If true, color channels will not be permuted, This option is '
'useful when inference model trained with rgb images.'))

args = parser.parse_args()
return args


def create_gif(results, gif_name, fps=60, n_skip=1):
"""Create gif through imageio.
Args:
frames (torch.Tensor): Image frames, shape like [bz, 3, H, W].
gif_name (str): Saved gif name.
fps (int, optional): Frames per second of the generated gif.
Defaults to 60.
n_skip (int, optional): Skip how many steps before selecting one to
visualize. Defaults to 1.
"""
try:
import imageio
except ImportError:
raise RuntimeError('imageio is not installed,'
'Please use “pip install imageio” to install')
frames_list = []
for frame in results[::n_skip]:
frames_list.append(
(frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8))

# ensure the final denoising results in frames_list
if not (len(results) % n_skip == 0):
frames_list.append((results[-1].permute(1, 2, 0).cpu().numpy() *
255.).astype(np.uint8))

imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps)


def main():
args = parse_args()
model = init_model(
args.config, checkpoint=args.checkpoint, device=args.device)

if args.sample_cfg is None:
args.sample_cfg = dict()

suffix = osp.splitext(args.save_path)[-1]
if suffix == '.gif':
args.sample_cfg['save_intermedia'] = True

results = sample_ddpm_model(model, args.num_samples, args.num_batches,
args.sample_model, args.same_noise,
**args.sample_cfg)

# save images
mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
if suffix == '.gif':
# concentrate all output of each timestep
results_timestep_list = []
for t in results.keys():
# make grid
results_timestep = utils.make_grid(
results[t], nrow=args.nrow, padding=args.padding)
# unsqueeze at 0, because make grid output is size like [H', W', 3]
results_timestep_list.append(results_timestep[None, ...])

# Concatenates to [n_timesteps, H', W', 3]
results_timestep = torch.cat(results_timestep_list, dim=0)
if not args.is_rgb:
results_timestep = results_timestep[:, [2, 1, 0]]
results_timestep = (results_timestep + 1.) / 2.
create_gif(results_timestep, args.save_path, n_skip=args.n_skip)
else:
if not args.is_rgb:
results = results[:, [2, 1, 0]]

results = (results + 1.) / 2.
utils.save_image(
results, args.save_path, nrow=args.nrow, padding=args.padding)


if __name__ == '__main__':
main()
6 changes: 4 additions & 2 deletions mmgen/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (init_model, sample_conditional_model,
sample_img2img_model, sample_uncoditional_model)
sample_ddpm_model, sample_img2img_model,
sample_uncoditional_model)
from .train import set_random_seed, train_model

__all__ = [
'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model',
'sample_uncoditional_model', 'sample_conditional_model'
'sample_uncoditional_model', 'sample_conditional_model',
'sample_ddpm_model'
]
68 changes: 68 additions & 0 deletions mmgen/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,71 @@ def sample_img2img_model(model, image_path, target_domain=None, **kwargs):
**kwargs)
output = results['target']
return output


@torch.no_grad()
def sample_ddpm_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
same_noise=False,
**kwargs):
"""Sampling from ddpm models.
Args:
model (nn.Module): DDPM models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The number of batch size for inference.
Defaults to 4.
sample_model (str, optional): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
noise_batch (torch.Tensor): Noise batch used as denoising starting up.
Defaults to None.
Returns:
list[Tensor | dict]: Generated image tensor.
"""
model.eval()

n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat

if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)

noise_batch = torch.randn(model.image_shape) if same_noise else None

res_list = []
# inference
for idx, batches in enumerate(batches_list):
mmcv.print_log(
f'Start to sample batch [{idx+1} / '
f'{len(batches_list)}]', 'mmgen')
noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \
if same_noise else None

res = model.sample_from_noise(
noise_batch_,
num_batches=batches,
sample_model=sample_model,
show_pbar=True,
**kwargs)
if isinstance(res, dict):
res = {k: v.cpu() for k, v in res.items()}
elif isinstance(res, torch.Tensor):
res = res.cpu()
else:
raise ValueError('Sample results should be \'dict\' or '
f'\'torch.Tensor\', but receive \'{type(res)}\'')
res_list.append(res)

# gather the res_list
if isinstance(res_list[0], dict):
res_dict = dict()
for t in res_list[0].keys():
# num_samples x 3 x H x W
res_dict[t] = torch.cat([res[t] for res in res_list], dim=0)
return res_dict
else:
return torch.cat(res_list, dim=0)
51 changes: 50 additions & 1 deletion tests/test_apis/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from mmgen.apis import (init_model, sample_img2img_model,
from mmgen.apis import (init_model, sample_ddpm_model, sample_img2img_model,
sample_uncoditional_model)


Expand Down Expand Up @@ -78,3 +78,52 @@ def test_translation_model_cuda(self):
res = sample_img2img_model(
self.cyclegan.cuda(), self.img_path, target_domain='photo')
assert res.shape == (1, 3, 256, 256)


class TestDiffusionModel:

@classmethod
def setup_class(cls):
project_dir = os.path.abspath(os.path.join(__file__, '../../..'))
ddpm_config = mmcv.Config.fromfile(
os.path.join(
project_dir, 'configs/improved_ddpm/'
'ddpm_cosine_hybird_timestep-4k_drop0.3_'
'cifar10_32x32_b8x16_500k.py'))
# change timesteps to speed up test process
ddpm_config.model['num_timesteps'] = 10
cls.model = init_model(ddpm_config, checkpoint=None, device='cpu')

def test_diffusion_model_cpu(self):
# save_intermedia is False
res = sample_ddpm_model(
self.model, num_samples=3, num_batches=2, same_noise=True)
assert res.shape == (3, 3, 32, 32)

# save_intermedia is True
res = sample_ddpm_model(
self.model,
num_samples=2,
num_batches=2,
same_noise=True,
save_intermedia=True)
assert isinstance(res, dict)
assert all([i in res for i in range(10)])

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_diffusion_model_cuda(self):
model = self.model.cuda()
# save_intermedia is False
res = sample_ddpm_model(
model, num_samples=3, num_batches=2, same_noise=True)
assert res.shape == (3, 3, 32, 32)

# save_intermedia is True
res = sample_ddpm_model(
model,
num_samples=2,
num_batches=2,
same_noise=True,
save_intermedia=True)
assert isinstance(res, dict)
assert all([i in res for i in range(10)])

0 comments on commit d4fcb21

Please # to comment.