From 57cd50114a897da12778f8de4d20c5d37ab27cd6 Mon Sep 17 00:00:00 2001
From: LeoXing
Date: Wed, 23 Mar 2022 19:25:08 +0800
Subject: [PATCH] support random seed for distributed sampler

---
 mmgen/datasets/builder.py                    |  3 +-
 .../datasets/samplers/distributed_sampler.py | 18 ++++++++-
 mmgen/utils/__init__.py                      |  4 +-
 mmgen/utils/dist_util.py                     | 39 +++++++++++++++++++
 4 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/mmgen/datasets/builder.py b/mmgen/datasets/builder.py
index 701909950..5de459ff8 100644
--- a/mmgen/datasets/builder.py
+++ b/mmgen/datasets/builder.py
@@ -103,7 +103,8 @@ def build_dataloader(dataset,
             world_size,
             rank,
             shuffle=shuffle,
-            samples_per_gpu=samples_per_gpu)
+            samples_per_gpu=samples_per_gpu,
+            seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
diff --git a/mmgen/datasets/samplers/distributed_sampler.py b/mmgen/datasets/samplers/distributed_sampler.py
index 0641b9ae9..1aa3d77d0 100644
--- a/mmgen/datasets/samplers/distributed_sampler.py
+++ b/mmgen/datasets/samplers/distributed_sampler.py
@@ -5,6 +5,8 @@
 import torch
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
+from mmgen.utils import sync_random_seed
+
 
 class DistributedSampler(_DistributedSampler):
     """DistributedSampler inheriting from
@@ -19,7 +21,8 @@ def __init__(self,
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 samples_per_gpu=1):
+                 samples_per_gpu=1,
+                 seed=None):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
         self.shuffle = shuffle
         self.samples_per_gpu = samples_per_gpu
@@ -39,6 +42,13 @@ def __init__(self,
             'You may use too small dataset and our distributed '
             'sampler cannot pad your dataset correctly. We highly '
             'recommend you to use fewer GPUs to finish your work')
+        # In distributed sampling, different ranks should sample
+        # non-overlapped data in the dataset. Therefore, this function
+        # is used to make sure that each rank shuffles the data indices
+        # in the same order based on the same seed. Then different ranks
+        # could use different indices to select non-overlapped data from the
+        # same data list.
+        self.seed = sync_random_seed(seed)
 
     def update_sampler(self, dataset, samples_per_gpu=None):
         self.dataset = dataset
@@ -64,7 +74,11 @@ def __iter__(self):
         # deterministically shuffle based on epoch
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.seed + self.epoch)
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
             indices = torch.arange(len(self.dataset)).tolist()
diff --git a/mmgen/utils/__init__.py b/mmgen/utils/__init__.py
index 7dd77a1a7..a277770fd 100644
--- a/mmgen/utils/__init__.py
+++ b/mmgen/utils/__init__.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
-from .dist_util import check_dist_init
+from .dist_util import check_dist_init, sync_random_seed
 from .io_utils import MMGEN_CACHE_DIR, download_from_url
 from .logger import get_root_logger
 
 __all__ = [
     'collect_env', 'get_root_logger', 'download_from_url', 'check_dist_init',
-    'MMGEN_CACHE_DIR'
+    'MMGEN_CACHE_DIR', 'sync_random_seed'
 ]
diff --git a/mmgen/utils/dist_util.py b/mmgen/utils/dist_util.py
index d3a47edb9..569132328 100644
--- a/mmgen/utils/dist_util.py
+++ b/mmgen/utils/dist_util.py
@@ -1,6 +1,45 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
 import torch.distributed as dist
+from mmcv.runner import get_dist_info
 
 
 def check_dist_init():
     return dist.is_available() and dist.is_initialized()
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed.
+
+    All workers must call
+    this function, otherwise it will deadlock. This method is generally used in
+    `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
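
A minimal, self-contained sketch (not part of the patch) of the behaviour the change relies on: every rank seeds its `torch.Generator` with the same shared seed plus the current epoch, so all replicas compute an identical permutation for a given epoch while the ordering still changes between epochs. The `epoch_indices` helper and the constants below are hypothetical; in the real sampler the shared seed is the value returned by `sync_random_seed`, which broadcasts rank 0's seed over `torch.distributed`.

```python
# Hypothetical sketch of the seeding scheme introduced by this patch.
import torch


def epoch_indices(shared_seed, epoch, dataset_len):
    """Mirror the shuffle logic of DistributedSampler.__iter__."""
    g = torch.Generator()
    # Same `shared_seed` on every rank -> same permutation for this epoch;
    # adding `epoch` changes the permutation between epochs.
    g.manual_seed(shared_seed + epoch)
    return torch.randperm(dataset_len, generator=g).tolist()


if __name__ == '__main__':
    # In real training `shared_seed` would be the value returned by
    # `sync_random_seed(seed)`, broadcast from rank 0 to all ranks.
    shared_seed, dataset_len = 42, 8

    # Two "ranks" computing their indices independently agree exactly,
    # so each rank can safely take its own non-overlapping slice.
    assert epoch_indices(shared_seed, 0, dataset_len) == \
        epoch_indices(shared_seed, 0, dataset_len)

    # The ordering still changes from one epoch to the next.
    print(epoch_indices(shared_seed, 0, dataset_len))
    print(epoch_indices(shared_seed, 1, dataset_len))
```

Without this synchronization, ranks started with `seed=None` would each draw an independent random seed, shuffle in different orders, and their per-rank slices could overlap or skip samples; broadcasting rank 0's seed before the first epoch is what rules that out.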