Merge pull request #271 from LeoXing1996/dist_seed
[Enhance] Support random seed for distributed sampler
plyfager authored Mar 30, 2022
2 parents ac081e3 + 57cd501 commit b951741
Showing 4 changed files with 59 additions and 5 deletions.
mmgen/datasets/builder.py (2 additions, 1 deletion)
@@ -103,7 +103,8 @@ def build_dataloader(dataset,
             world_size,
             rank,
             shuffle=shuffle,
-            samples_per_gpu=samples_per_gpu)
+            samples_per_gpu=samples_per_gpu,
+            seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
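For context, here is a usage sketch that is not part of the commit: it shows how a training script could pass a seed through `build_dataloader` so it reaches the distributed sampler. The toy dataset and every argument except `seed` are illustrative assumptions; only the `seed` forwarding is what this diff adds.

```python
# Illustrative only; assumes mmgen is installed and importable.
import torch
from torch.utils.data import Dataset

from mmgen.datasets import build_dataloader


class ToyDataset(Dataset):
    """Minimal stand-in dataset so the sketch is self-contained."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.tensor(idx)


if __name__ == '__main__':
    loader = build_dataloader(
        ToyDataset(),
        samples_per_gpu=4,
        workers_per_gpu=2,
        dist=True,   # takes the branch patched above, which builds DistributedSampler
        shuffle=True,
        seed=42)     # new in this commit: forwarded to DistributedSampler
    print(next(iter(loader)))
```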
mmgen/datasets/samplers/distributed_sampler.py (16 additions, 2 deletions)
@@ -5,6 +5,8 @@
 import torch
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
+from mmgen.utils import sync_random_seed
+
 
 class DistributedSampler(_DistributedSampler):
     """DistributedSampler inheriting from
@@ -19,7 +21,8 @@ def __init__(self,
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 samples_per_gpu=1):
+                 samples_per_gpu=1,
+                 seed=None):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
 
         self.shuffle = shuffle
@@ -39,6 +42,13 @@ def __init__(self,
                 'You may use too small dataset and our distributed '
                 'sampler cannot pad your dataset correctly. We highly '
                 'recommend you to use fewer GPUs to finish your work')
+        # In distributed sampling, different ranks should sample
+        # non-overlapped data in the dataset. Therefore, this function
+        # is used to make sure that each rank shuffles the data indices
+        # in the same order based on the same seed. Then different ranks
+        # could use different indices to select non-overlapped data from the
+        # same data list.
+        self.seed = sync_random_seed(seed)
 
     def update_sampler(self, dataset, samples_per_gpu=None):
         self.dataset = dataset
@@ -64,7 +74,11 @@ def __iter__(self):
         # deterministically shuffle based on epoch
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.seed + self.epoch)
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
             indices = torch.arange(len(self.dataset)).tolist()
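To make the seeding logic concrete, the following standalone sketch (plain PyTorch, not mmgen code, using the hypothetical helper name `epoch_indices`) mimics what the sampler does: every rank builds the same permutation from the shared `seed + epoch` and then keeps its own rank-strided slice, so the per-rank index sets are disjoint and still change from epoch to epoch.

```python
# Standalone illustration of "same seed, same permutation, different slices".
import torch


def epoch_indices(dataset_len, rank, world_size, seed, epoch):
    g = torch.Generator()
    # The shared seed plus the epoch gives every rank the same permutation,
    # and a different one for every epoch.
    g.manual_seed(seed + epoch)
    indices = torch.randperm(dataset_len, generator=g).tolist()
    # Each rank keeps a strided slice, so the slices never overlap.
    return indices[rank::world_size]


seed, epoch, world_size, n = 42, 0, 2, 8
parts = [epoch_indices(n, r, world_size, seed, epoch) for r in range(world_size)]
print(parts)
assert set(parts[0]).isdisjoint(parts[1])             # non-overlapping ranks
assert sorted(parts[0] + parts[1]) == list(range(n))  # together they cover all data
```

Without a shared seed, each rank would shuffle in its own order and the slices could overlap or miss samples; with it, changing only the epoch reshuffles every rank in lockstep, which is what `g.manual_seed(self.seed + self.epoch)` achieves in the diff above.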
mmgen/utils/__init__.py (2 additions, 2 deletions)
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
-from .dist_util import check_dist_init
+from .dist_util import check_dist_init, sync_random_seed
 from .io_utils import MMGEN_CACHE_DIR, download_from_url
 from .logger import get_root_logger
 
 __all__ = [
     'collect_env', 'get_root_logger', 'download_from_url', 'check_dist_init',
-    'MMGEN_CACHE_DIR'
+    'MMGEN_CACHE_DIR', 'sync_random_seed'
 ]
mmgen/utils/dist_util.py (39 additions, 0 deletions)
@@ -1,6 +1,45 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
 import torch.distributed as dist
+from mmcv.runner import get_dist_info
 
 
 def check_dist_init():
     return dist.is_available() and dist.is_initialized()
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed. All workers must call
+    this function, otherwise it will deadlock. This method is generally used in
+    `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
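As a sanity check on the broadcast pattern in `sync_random_seed`, here is a self-contained sketch that is not part of the commit: it spawns two CPU processes with the `gloo` backend and synchronizes a seed the same way the function does, with every rank joining the broadcast and only rank 0's draw surviving. The port number and process count are arbitrary assumptions.

```python
# Standalone demo of the rank-0 seed broadcast; CPU-only 'gloo' backend assumed.
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'  # arbitrary free port
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Every rank draws its own candidate seed, mirroring the seed=None path.
    seed = int(np.random.randint(2**31 - 1))
    tensor = torch.tensor(seed if rank == 0 else 0, dtype=torch.int32)
    # All ranks must reach this call, otherwise the group deadlocks;
    # broadcast overwrites the tensor on every rank except the source.
    dist.broadcast(tensor, src=0)
    print(f'rank {rank}: shared seed = {tensor.item()}')

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(demo, args=(2,), nprocs=2)
```

On a GPU machine the real function places the tensor on `device='cuda'` so NCCL process groups can broadcast it; the gloo backend is used here only so the sketch runs without GPUs.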
