diff --git a/demo/ddpm_demo.py b/demo/ddpm_demo.py
new file mode 100644
index 000000000..015999a41
--- /dev/null
+++ b/demo/ddpm_demo.py
@@ -0,0 +1,155 @@
+import argparse
+import os
+import os.path as osp
+import sys
+
+import mmcv
+import numpy as np
+import torch
+from mmcv import DictAction
+from torchvision import utils
+
+# yapf: disable
+sys.path.append(os.path.abspath(os.path.join(__file__, '../..')))  # isort:skip  # noqa
+
+from mmgen.apis import init_model, sample_ddpm_model  # isort:skip  # noqa
+# yapf: enable
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='DDPM demo')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--save-path',
+        type=str,
+        default='./work_dirs/demos/ddpm_samples.png',
+        help='path to save unconditional samples')
+    parser.add_argument(
+        '--device', type=str, default='cuda:0', help='CUDA device id')
+
+    # args for inference/sampling
+    parser.add_argument(
+        '--num-batches', type=int, default=4, help='Batch size in inference')
+    parser.add_argument(
+        '--num-samples',
+        type=int,
+        default=12,
+        help='The total number of samples')
+    parser.add_argument(
+        '--sample-model',
+        type=str,
+        default='ema',
+        help='Which model to use for sampling')
+    parser.add_argument(
+        '--sample-cfg',
+        nargs='+',
+        action=DictAction,
+        help='Other customized kwargs for the sampling function')
+    parser.add_argument(
+        '--same-noise',
+        action='store_true',
+        help='whether to use the same noise as input (x_T)')
+    parser.add_argument(
+        '--n-skip',
+        type=int,
+        default=25,
+        help=('How many steps to skip before selecting one to visualize. '
+              'This is helpful when there are many denoising timesteps. '
+              'Only takes effect when `save-path` ends with \'.gif\'.'))
+
+    # args for image grid
+    parser.add_argument(
+        '--padding', type=int, default=0, help='Padding in the image grid.')
+    parser.add_argument(
+        '--nrow',
+        type=int,
+        default=2,
+        help=('Number of images displayed in each row of the grid. '
+              'This argument only takes effect when no label is given.'))
+
+    # args for image channel order
+    parser.add_argument(
+        '--is-rgb',
+        action='store_true',
+        help=('If true, color channels will not be permuted. This option is '
+              'useful when the model was trained with RGB images.'))
+
+    args = parser.parse_args()
+    return args
+
+
+def create_gif(results, gif_name, fps=60, n_skip=1):
+    """Create gif through imageio.
+
+    Args:
+        results (torch.Tensor): Image frames, shape like [bz, 3, H, W].
+        gif_name (str): Saved gif name.
+        fps (int, optional): Frames per second of the generated gif.
+            Defaults to 60.
+        n_skip (int, optional): How many steps to skip before selecting one
+            to visualize. Defaults to 1.
+    """
+    try:
+        import imageio
+    except ImportError:
+        raise RuntimeError('imageio is not installed, please use '
+                           '"pip install imageio" to install it.')
+    frames_list = []
+    for frame in results[::n_skip]:
+        frames_list.append(
+            (frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8))
+
+    # ensure the final denoising result is included in frames_list
+    if not (len(results) % n_skip == 0):
+        frames_list.append((results[-1].permute(1, 2, 0).cpu().numpy() *
+                            255.).astype(np.uint8))
+
+    imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps)
+
+
+def main():
+    args = parse_args()
+    model = init_model(
+        args.config, checkpoint=args.checkpoint, device=args.device)
+
+    if args.sample_cfg is None:
+        args.sample_cfg = dict()
+
+    suffix = osp.splitext(args.save_path)[-1]
+    if suffix == '.gif':
+        args.sample_cfg['save_intermedia'] = True
+
+    results = sample_ddpm_model(model, args.num_samples, args.num_batches,
+                                args.sample_model, args.same_noise,
+                                **args.sample_cfg)
+
+    # save images
+    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
+    if suffix == '.gif':
+        # concatenate the outputs of all timesteps
+        results_timestep_list = []
+        for t in results.keys():
+            # make grid
+            results_timestep = utils.make_grid(
+                results[t], nrow=args.nrow, padding=args.padding)
+            # unsqueeze at 0, because make_grid outputs a [3, H', W'] tensor
+            results_timestep_list.append(results_timestep[None, ...])
+
+        # concatenate to [n_timesteps, 3, H', W']
+        results_timestep = torch.cat(results_timestep_list, dim=0)
+        if not args.is_rgb:
+            results_timestep = results_timestep[:, [2, 1, 0]]
+        results_timestep = (results_timestep + 1.) / 2.
+        create_gif(results_timestep, args.save_path, n_skip=args.n_skip)
+    else:
+        if not args.is_rgb:
+            results = results[:, [2, 1, 0]]
+
+        results = (results + 1.) / 2.
+        utils.save_image(
+            results, args.save_path, nrow=args.nrow, padding=args.padding)
+
+
+if __name__ == '__main__':
+    main()
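As a usage reference for the demo above (a sketch, not part of the patch): the script is invoked as `python demo/ddpm_demo.py CONFIG CKPT --save-path xxx.gif --n-skip 25`, and `create_gif` expects frames already mapped into [0, 1]. A minimal sketch of the expected input, using hypothetical dummy frames and a hypothetical output name; note that the `fps` keyword of `imageio.mimsave` is accepted by older imageio releases (newer ones use `duration`):

    import torch

    # 8 dummy frames in [0, 1], shaped [n_frames, 3, H, W]
    frames = torch.rand(8, 3, 32, 32)
    # keep every 2nd frame and write the gif at 10 frames per second
    create_gif(frames, 'toy.gif', fps=10, n_skip=2)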
+ """ + try: + import imageio + except ImportError: + raise RuntimeError('imageio is not installed,' + 'Please use “pip install imageio” to install') + frames_list = [] + for frame in results[::n_skip]: + frames_list.append( + (frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)) + + # ensure the final denoising results in frames_list + if not (len(results) % n_skip == 0): + frames_list.append((results[-1].permute(1, 2, 0).cpu().numpy() * + 255.).astype(np.uint8)) + + imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps) + + +def main(): + args = parse_args() + model = init_model( + args.config, checkpoint=args.checkpoint, device=args.device) + + if args.sample_cfg is None: + args.sample_cfg = dict() + + suffix = osp.splitext(args.save_path)[-1] + if suffix == '.gif': + args.sample_cfg['save_intermedia'] = True + + results = sample_ddpm_model(model, args.num_samples, args.num_batches, + args.sample_model, args.same_noise, + **args.sample_cfg) + + # save images + mmcv.mkdir_or_exist(os.path.dirname(args.save_path)) + if suffix == '.gif': + # concentrate all output of each timestep + results_timestep_list = [] + for t in results.keys(): + # make grid + results_timestep = utils.make_grid( + results[t], nrow=args.nrow, padding=args.padding) + # unsqueeze at 0, because make grid output is size like [H', W', 3] + results_timestep_list.append(results_timestep[None, ...]) + + # Concatenates to [n_timesteps, H', W', 3] + results_timestep = torch.cat(results_timestep_list, dim=0) + if not args.is_rgb: + results_timestep = results_timestep[:, [2, 1, 0]] + results_timestep = (results_timestep + 1.) / 2. + create_gif(results_timestep, args.save_path, n_skip=args.n_skip) + else: + if not args.is_rgb: + results = results[:, [2, 1, 0]] + + results = (results + 1.) / 2. + utils.save_image( + results, args.save_path, nrow=args.nrow, padding=args.padding) + + +if __name__ == '__main__': + main() diff --git a/mmgen/apis/__init__.py b/mmgen/apis/__init__.py index bf2d66074..57cbebba7 100644 --- a/mmgen/apis/__init__.py +++ b/mmgen/apis/__init__.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import (init_model, sample_conditional_model, - sample_img2img_model, sample_uncoditional_model) + sample_ddpm_model, sample_img2img_model, + sample_uncoditional_model) from .train import set_random_seed, train_model __all__ = [ 'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model', - 'sample_uncoditional_model', 'sample_conditional_model' + 'sample_uncoditional_model', 'sample_conditional_model', + 'sample_ddpm_model' ] diff --git a/mmgen/apis/inference.py b/mmgen/apis/inference.py index 4d8754df2..68f150f5e 100644 --- a/mmgen/apis/inference.py +++ b/mmgen/apis/inference.py @@ -227,3 +227,71 @@ def sample_img2img_model(model, image_path, target_domain=None, **kwargs): **kwargs) output = results['target'] return output + + +@torch.no_grad() +def sample_ddpm_model(model, + num_samples=16, + num_batches=4, + sample_model='ema', + same_noise=False, + **kwargs): + """Sampling from ddpm models. + + Args: + model (nn.Module): DDPM models in MMGeneration. + num_samples (int, optional): The total number of samples. + Defaults to 16. + num_batches (int, optional): The number of batch size for inference. + Defaults to 4. + sample_model (str, optional): Which model you want to use. ['ema', + 'orig']. Defaults to 'ema'. + noise_batch (torch.Tensor): Noise batch used as denoising starting up. + Defaults to None. + + Returns: + list[Tensor | dict]: Generated image tensor. 
+ """ + model.eval() + + n_repeat = num_samples // num_batches + batches_list = [num_batches] * n_repeat + + if num_samples % num_batches > 0: + batches_list.append(num_samples % num_batches) + + noise_batch = torch.randn(model.image_shape) if same_noise else None + + res_list = [] + # inference + for idx, batches in enumerate(batches_list): + mmcv.print_log( + f'Start to sample batch [{idx+1} / ' + f'{len(batches_list)}]', 'mmgen') + noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \ + if same_noise else None + + res = model.sample_from_noise( + noise_batch_, + num_batches=batches, + sample_model=sample_model, + show_pbar=True, + **kwargs) + if isinstance(res, dict): + res = {k: v.cpu() for k, v in res.items()} + elif isinstance(res, torch.Tensor): + res = res.cpu() + else: + raise ValueError('Sample results should be \'dict\' or ' + f'\'torch.Tensor\', but receive \'{type(res)}\'') + res_list.append(res) + + # gather the res_list + if isinstance(res_list[0], dict): + res_dict = dict() + for t in res_list[0].keys(): + # num_samples x 3 x H x W + res_dict[t] = torch.cat([res[t] for res in res_list], dim=0) + return res_dict + else: + return torch.cat(res_list, dim=0) diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py index 424f05b87..3e6d5066e 100644 --- a/tests/test_apis/test_inference.py +++ b/tests/test_apis/test_inference.py @@ -4,7 +4,7 @@ import pytest import torch -from mmgen.apis import (init_model, sample_img2img_model, +from mmgen.apis import (init_model, sample_ddpm_model, sample_img2img_model, sample_uncoditional_model) @@ -78,3 +78,52 @@ def test_translation_model_cuda(self): res = sample_img2img_model( self.cyclegan.cuda(), self.img_path, target_domain='photo') assert res.shape == (1, 3, 256, 256) + + +class TestDiffusionModel: + + @classmethod + def setup_class(cls): + project_dir = os.path.abspath(os.path.join(__file__, '../../..')) + ddpm_config = mmcv.Config.fromfile( + os.path.join( + project_dir, 'configs/improved_ddpm/' + 'ddpm_cosine_hybird_timestep-4k_drop0.3_' + 'cifar10_32x32_b8x16_500k.py')) + # change timesteps to speed up test process + ddpm_config.model['num_timesteps'] = 10 + cls.model = init_model(ddpm_config, checkpoint=None, device='cpu') + + def test_diffusion_model_cpu(self): + # save_intermedia is False + res = sample_ddpm_model( + self.model, num_samples=3, num_batches=2, same_noise=True) + assert res.shape == (3, 3, 32, 32) + + # save_intermedia is True + res = sample_ddpm_model( + self.model, + num_samples=2, + num_batches=2, + same_noise=True, + save_intermedia=True) + assert isinstance(res, dict) + assert all([i in res for i in range(10)]) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') + def test_diffusion_model_cuda(self): + model = self.model.cuda() + # save_intermedia is False + res = sample_ddpm_model( + model, num_samples=3, num_batches=2, same_noise=True) + assert res.shape == (3, 3, 32, 32) + + # save_intermedia is True + res = sample_ddpm_model( + model, + num_samples=2, + num_batches=2, + same_noise=True, + save_intermedia=True) + assert isinstance(res, dict) + assert all([i in res for i in range(10)])