From c7bb127ef6c05fe79b358a0040e7f56a38748d5a Mon Sep 17 00:00:00 2001 From: LeoXing Date: Mon, 3 Jan 2022 13:12:23 +0800 Subject: [PATCH 1/7] add denoising demo --- demo/ddpm_demo.py | 151 ++++++++++++++++++++++++++++++++++++++++ mmgen/apis/inference.py | 56 +++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 demo/ddpm_demo.py diff --git a/demo/ddpm_demo.py b/demo/ddpm_demo.py new file mode 100644 index 000000000..c58cae601 --- /dev/null +++ b/demo/ddpm_demo.py @@ -0,0 +1,151 @@ +import argparse +import os +import os.path as osp +import sys + +import mmcv +import numpy as np +import torch +from mmcv import DictAction +from torchvision import utils + +# yapf: disable +sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa + +from mmgen.apis import init_model, sample_ddpm_model # isort:skip # noqa +# yapf: enable + + +def parse_args(): + parser = argparse.ArgumentParser(description='DDPM demo') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--save-path', + type=str, + default='./work_dirs/demos/ddpm_samples.png', + help='path to save uncoditional samples') + parser.add_argument( + '--device', type=str, default='cuda:0', help='CUDA device id') + + # args for inference/sampling + parser.add_argument( + '--num-batches', type=int, default=4, help='Batch size in inference') + parser.add_argument( + '--num-samples', + type=int, + default=12, + help='The total number of samples') + parser.add_argument( + '--sample-model', + type=str, + default='ema', + help='Which model to use for sampling') + parser.add_argument( + '--sample-cfg', + nargs='+', + action=DictAction, + help='Other customized kwargs for sampling function') + parser.add_argument( + '--same-noise', + action='store_true', + help='whether use same noise as input (x_T)') + parser.add_argument( + '--n-skip', + type=int, + default=25, + help=('Skip how many steps 
before selecting one to visualize. This is ' + 'helpful with denoising timestep is too much. Only work with ' + '`save-path` is end with \'.gif\'.')) + + # args for image grid + parser.add_argument( + '--padding', type=int, default=0, help='Padding in the image grid.') + parser.add_argument( + '--nrow', + type=int, + default=2, + help=('Number of images displayed in each row of the grid. ' + 'This argument would work only when label is not given.')) + + # args for image channel order + parser.add_argument( + '--is-rgb', + action='store_true', + help=('If true, color channels will not be permuted, This option is ' + 'useful when inference model trained with rgb images.')) + + args = parser.parse_args() + return args + + +def create_gif(results, gif_name, fps=60, n_skip=1): + """Create gif through imageio. + + Args: + frames (torch.Tensor): Image frames, shape like [bz, 3, H, W]. + gif_name (str): Saved gif name. + fps (int, optional): Frames per second of the generated gif. + Defaults to 60. + n_skip (int, optional): Skip how many steps before selecting one to + visualize. Defaults to 1. 
+ """ + import imageio + frames_list = [] + for frame in results[::n_skip]: + frames_list.append( + (frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)) + if imageio is None: + raise RuntimeError('imageio is not installed,' + 'Please use “pip install imageio” to install') + imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps) + + +def main(): + args = parse_args() + model = init_model( + args.config, checkpoint=args.checkpoint, device=args.device) + + if args.sample_cfg is None: + args.sample_cfg = dict() + + # noise_batch = None # set default noise batch as None + suffix = osp.splitext(args.save_path)[-1] + if suffix == '.gif': + args.sample_cfg['save_intermedia'] = True + + results = sample_ddpm_model(model, args.num_samples, args.num_batches, + args.sample_model, args.same_noise, + **args.sample_cfg) + + # save images + mmcv.mkdir_or_exist(os.path.dirname(args.save_path)) + if suffix == '.gif': + # concentrate all output of each timestep + results_timestep_list = [] + for t in results[0].keys(): + # num_samples x 3 x H x W + results_timestep_ = torch.cat([res[t] for res in results], dim=0) + # make grid + results_timestep_ = utils.make_grid( + results_timestep_, nrow=args.nrow, padding=args.padding) + # unsqueeze at 0, because make grid output is size like [H', W', 3] + results_timestep_list.append(results_timestep_[None, ...]) + + # Concatenates to [n_timesteps, H', W', 3] + results_timestep = torch.cat(results_timestep_list, dim=0) + if not args.is_rgb: + results_timestep = results_timestep[:, [2, 1, 0]] + results_timestep = ((results_timestep + 1.) / 2.).clamp(0, 1) + create_gif(results_timestep, args.save_path, n_skip=args.n_skip) + else: + if not args.is_rgb: + results = results[:, [2, 1, 0]] + + results = (results + 1.) / 2. 
+ utils.save_image( + results, args.save_path, nrow=args.nrow, padding=args.padding) + + +if __name__ == '__main__': + main() diff --git a/mmgen/apis/inference.py b/mmgen/apis/inference.py index 4d8754df2..4e895c8c2 100644 --- a/mmgen/apis/inference.py +++ b/mmgen/apis/inference.py @@ -227,3 +227,59 @@ def sample_img2img_model(model, image_path, target_domain=None, **kwargs): **kwargs) output = results['target'] return output + + +@torch.no_grad() +def sample_ddpm_model(model, + num_samples=16, + num_batches=4, + sample_model='ema', + same_noise=False, + **kwargs): + """Sampling from ddpm models. + + Args: + model (nn.Module): DDPM models in MMGeneration. + num_samples (int, optional): The total number of samples. + Defaults to 16. + num_batches (int, optional): The number of batch size for inference. + Defaults to 4. + sample_model (str, optional): Which model you want to use. ['ema', + 'orig']. Defaults to 'ema'. + noise_batch (torch.Tensor): Noise batch used as denoising starting up. + Defaults to None. + + Returns: + Tensor: Generated image tensor. 
+ """ + model.eval() + + n_repeat = num_samples // num_batches + batches_list = [num_batches] * n_repeat + + if num_samples % num_batches > 0: + batches_list.append(num_samples % num_batches) + + noise_batch = torch.randn(model.image_shape) if same_noise else None + + res_list = [] + # inference + for idx, batches in enumerate(batches_list): + mmcv.print_log( + f'Start to sample batch [{idx+1} / ' + f'{len(batches_list)}]', 'mmgen') + noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \ + if same_noise else None + + res = model.sample_from_noise( + noise_batch_, + num_batches=batches, + sample_model=sample_model, + show_pbar=True, + **kwargs) + if isinstance(res, dict): + res = {k: v.cpu() for k, v in res.items()} + else: + res = res.cpu() + res_list.append(res) + return res_list From 3f65853e25b829c739ae7c677518f69b05cbc3a0 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Mon, 3 Jan 2022 23:36:47 +0800 Subject: [PATCH 2/7] solve import bug --- mmgen/apis/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mmgen/apis/__init__.py b/mmgen/apis/__init__.py index bf2d66074..57cbebba7 100644 --- a/mmgen/apis/__init__.py +++ b/mmgen/apis/__init__.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .inference import (init_model, sample_conditional_model, - sample_img2img_model, sample_uncoditional_model) + sample_ddpm_model, sample_img2img_model, + sample_uncoditional_model) from .train import set_random_seed, train_model __all__ = [ 'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model', - 'sample_uncoditional_model', 'sample_conditional_model' + 'sample_uncoditional_model', 'sample_conditional_model', + 'sample_ddpm_model' ] From 0b83cd1d1f773ad646a10f7edc68c403da98e99f Mon Sep 17 00:00:00 2001 From: LeoXing Date: Tue, 4 Jan 2022 10:31:41 +0800 Subject: [PATCH 3/7] add the final denoising results to gif when num_timesteps cannot be devided by n_skip --- demo/ddpm_demo.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/demo/ddpm_demo.py b/demo/ddpm_demo.py index c58cae601..6df884ec6 100644 --- a/demo/ddpm_demo.py +++ b/demo/ddpm_demo.py @@ -95,6 +95,12 @@ def create_gif(results, gif_name, fps=60, n_skip=1): for frame in results[::n_skip]: frames_list.append( (frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)) + + # append the final frame unless results[::n_skip] already ended on it + if len(results) > 0 and (len(results) - 1) % n_skip != 0: + frames_list.append((results[-1].permute(1, 2, 0).cpu().numpy() * + 255.).astype(np.uint8)) + if imageio is None: raise RuntimeError('imageio is not installed,' 'Please use “pip install imageio” to install') From c4ae18110c79ebcd7c1194f12ee7e549ba2d7dcd Mon Sep 17 00:00:00 2001 From: LeoXing Date: Tue, 4 Jan 2022 10:35:59 +0800 Subject: [PATCH 4/7] remove comment --- demo/ddpm_demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/demo/ddpm_demo.py b/demo/ddpm_demo.py index 6df884ec6..33da1c4ee 100644 --- a/demo/ddpm_demo.py +++ b/demo/ddpm_demo.py @@ -115,7 +115,6 @@ def main(): if args.sample_cfg is None: args.sample_cfg = dict() - # noise_batch = None # set default noise batch as None suffix = osp.splitext(args.save_path)[-1] if suffix == '.gif': args.sample_cfg['save_intermedia'] = True From 
d5935f80386a4ea1bd5ab6ed03f3b5b66935043e Mon Sep 17 00:00:00 2001 From: LeoXing Date: Mon, 10 Jan 2022 19:38:13 +0800 Subject: [PATCH 5/7] revise known issue --- demo/ddpm_demo.py | 9 +++++---- mmgen/apis/inference.py | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/demo/ddpm_demo.py b/demo/ddpm_demo.py index 33da1c4ee..141fbf09c 100644 --- a/demo/ddpm_demo.py +++ b/demo/ddpm_demo.py @@ -90,7 +90,11 @@ def create_gif(results, gif_name, fps=60, n_skip=1): n_skip (int, optional): Skip how many steps before selecting one to visualize. Defaults to 1. """ - import imageio + try: + import imageio + except ImportError: + raise RuntimeError('imageio is not installed,' + 'Please use “pip install imageio” to install') frames_list = [] for frame in results[::n_skip]: frames_list.append( @@ -101,9 +105,6 @@ def create_gif(results, gif_name, fps=60, n_skip=1): frames_list.append((results[-1].permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)) - if imageio is None: - raise RuntimeError('imageio is not installed,' - 'Please use “pip install imageio” to install') imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps) diff --git a/mmgen/apis/inference.py b/mmgen/apis/inference.py index 4e895c8c2..293730a65 100644 --- a/mmgen/apis/inference.py +++ b/mmgen/apis/inference.py @@ -279,7 +279,10 @@ def sample_ddpm_model(model, **kwargs) if isinstance(res, dict): res = {k: v.cpu() for k, v in res.items()} - else: + elif isinstance(res, torch.Tensor): res = res.cpu() + else: + raise ValueError('Sample results should be \'dict\' or ' + f'\'torch.Tensor\', but receive \'{type(res)}\'') res_list.append(res) return res_list From acb0665d27395a36ca2e3523faf590acbe9fb121 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Tue, 11 Jan 2022 13:13:50 +0800 Subject: [PATCH 6/7] add unit test --- demo/ddpm_demo.py | 10 +++---- mmgen/apis/inference.py | 13 ++++++-- tests/test_apis/test_inference.py | 50 ++++++++++++++++++++++++++++++- 3 files changed, 64 insertions(+), 9 
deletions(-) diff --git a/demo/ddpm_demo.py b/demo/ddpm_demo.py index 141fbf09c..015999a41 100644 --- a/demo/ddpm_demo.py +++ b/demo/ddpm_demo.py @@ -129,14 +129,12 @@ def main(): if suffix == '.gif': # concentrate all output of each timestep results_timestep_list = [] - for t in results[0].keys(): - # num_samples x 3 x H x W - results_timestep_ = torch.cat([res[t] for res in results], dim=0) + for t in results.keys(): # make grid - results_timestep_ = utils.make_grid( - results_timestep_, nrow=args.nrow, padding=args.padding) + results_timestep = utils.make_grid( + results[t], nrow=args.nrow, padding=args.padding) # unsqueeze at 0, because make grid output is size like [H', W', 3] - results_timestep_list.append(results_timestep_[None, ...]) + results_timestep_list.append(results_timestep[None, ...]) # Concatenates to [n_timesteps, H', W', 3] results_timestep = torch.cat(results_timestep_list, dim=0) diff --git a/mmgen/apis/inference.py b/mmgen/apis/inference.py index 293730a65..68f150f5e 100644 --- a/mmgen/apis/inference.py +++ b/mmgen/apis/inference.py @@ -250,7 +250,7 @@ def sample_ddpm_model(model, Defaults to None. Returns: - Tensor: Generated image tensor. + Tensor | dict: Generated images, or per-timestep images as a dict. 
""" model.eval() @@ -285,4 +285,13 @@ def sample_ddpm_model(model, raise ValueError('Sample results should be \'dict\' or ' f'\'torch.Tensor\', but receive \'{type(res)}\'') res_list.append(res) - return res_list + + # gather the res_list + if isinstance(res_list[0], dict): + res_dict = dict() + for t in res_list[0].keys(): + # num_samples x 3 x H x W + res_dict[t] = torch.cat([res[t] for res in res_list], dim=0) + return res_dict + else: + return torch.cat(res_list, dim=0) diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py index 424f05b87..d99cd4d31 100644 --- a/tests/test_apis/test_inference.py +++ b/tests/test_apis/test_inference.py @@ -4,7 +4,7 @@ import pytest import torch -from mmgen.apis import (init_model, sample_img2img_model, +from mmgen.apis import (init_model, sample_ddpm_model, sample_img2img_model, sample_uncoditional_model) @@ -78,3 +78,51 @@ def test_translation_model_cuda(self): res = sample_img2img_model( self.cyclegan.cuda(), self.img_path, target_domain='photo') assert res.shape == (1, 3, 256, 256) + + +class TestDiffusionModel: + + @classmethod + def setup_class(cls): + project_dir = os.path.abspath(os.path.join(__file__, '../../..')) + ddpm_config = mmcv.Config.fromfile( + os.path.join( + project_dir, 'configs/improved_ddpm/' + 'ddpm_cosine_hybird_timestep-4k_drop0.3_' + 'cifar10_32x32_b8x16_500k.py')) + # change timesteps to speed up test process + ddpm_config.model['num_timesteps'] = 10 + cls.model = init_model(ddpm_config, checkpoint=None, device='cpu') + + def test_diffusion_model_cpu(self): + # save_intermedia is False + res = sample_ddpm_model( + self.model, num_samples=3, num_batches=2, same_noise=True) + assert res.shape == (3, 3, 32, 32) + + # save_intermedia is True + res = sample_ddpm_model( + self.model, + num_samples=2, + num_batches=2, + same_noise=True, + save_intermedia=True) + assert isinstance(res, dict) + assert all([i in res for i in range(10)]) + + def test_diffusion_model_cuda(self): + 
model = self.model.cuda() + # save_intermedia is False + res = sample_ddpm_model( + model, num_samples=3, num_batches=2, same_noise=True) + assert res.shape == (3, 3, 32, 32) + + # save_intermedia is True + res = sample_ddpm_model( + model, + num_samples=2, + num_batches=2, + same_noise=True, + save_intermedia=True) + assert isinstance(res, dict) + assert all([i in res for i in range(10)]) From 83a2e605bc0742d2e8d4fcbf473ae99ec643e4d5 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Tue, 11 Jan 2022 13:36:40 +0800 Subject: [PATCH 7/7] add skip cuda in unit test --- tests/test_apis/test_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py index d99cd4d31..3e6d5066e 100644 --- a/tests/test_apis/test_inference.py +++ b/tests/test_apis/test_inference.py @@ -110,6 +110,7 @@ def test_diffusion_model_cpu(self): assert isinstance(res, dict) assert all([i in res for i in range(10)]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') def test_diffusion_model_cuda(self): model = self.model.cuda() # save_intermedia is False