-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #271 from LeoXing1996/dist_seed
[Enchance] Support random seed for distributed sampler
- Loading branch information
Showing
4 changed files
with
59 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .collect_env import collect_env | ||
from .dist_util import check_dist_init | ||
from .dist_util import check_dist_init, sync_random_seed | ||
from .io_utils import MMGEN_CACHE_DIR, download_from_url | ||
from .logger import get_root_logger | ||
|
||
__all__ = [ | ||
'collect_env', 'get_root_logger', 'download_from_url', 'check_dist_init', | ||
'MMGEN_CACHE_DIR' | ||
'MMGEN_CACHE_DIR', 'sync_random_seed' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,45 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
import torch | ||
import torch.distributed as dist | ||
from mmcv.runner import get_dist_info | ||
|
||
|
||
def check_dist_init(): | ||
return dist.is_available() and dist.is_initialized() | ||
|
||
|
||
def sync_random_seed(seed=None, device='cuda'): | ||
"""Make sure different ranks share the same seed. | ||
All workers must call | ||
this function, otherwise it will deadlock. This method is generally used in | ||
`DistributedSampler`, because the seed should be identical across all | ||
processes in the distributed group. | ||
In distributed sampling, different ranks should sample non-overlapped | ||
data in the dataset. Therefore, this function is used to make sure that | ||
each rank shuffles the data indices in the same order based | ||
on the same seed. Then different ranks could use different indices | ||
to select non-overlapped data from the same data list. | ||
Args: | ||
seed (int, Optional): The seed. Default to None. | ||
device (str): The device where the seed will be put on. | ||
Default to 'cuda'. | ||
Returns: | ||
int: Seed to be used. | ||
""" | ||
if seed is None: | ||
seed = np.random.randint(2**31) | ||
assert isinstance(seed, int) | ||
|
||
rank, world_size = get_dist_info() | ||
|
||
if world_size == 1: | ||
return seed | ||
|
||
if rank == 0: | ||
random_num = torch.tensor(seed, dtype=torch.int32, device=device) | ||
else: | ||
random_num = torch.tensor(0, dtype=torch.int32, device=device) | ||
dist.broadcast(random_num, src=0) | ||
return random_num.item() |