Skip to content

Commit

Permalink
[Enhancement] Set the same seed for all ranks in distributed training
Browse files Browse the repository at this point in the history
  • Loading branch information
MeowZheng authored Nov 20, 2021
1 parent 985cbee commit 500ee41
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
4 changes: 2 additions & 2 deletions mmflow/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model
from .train import init_random_seed, set_random_seed, train_model

__all__ = [
'set_random_seed', 'train_model', 'init_model', 'inference_model',
'multi_gpu_test', 'single_gpu_test'
'multi_gpu_test', 'single_gpu_test', 'init_random_seed'
]
33 changes: 32 additions & 1 deletion mmflow/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook,
build_optimizer, build_runner)
build_optimizer, build_runner, get_dist_info)
from mmcv.utils import Config, build_from_cfg

from mmflow.core import DistEvalHook, EvalHook
Expand All @@ -18,6 +19,36 @@
Dataset = torch.utils.data.Dataset


def init_random_seed(seed: Optional[int] = None, device: str = 'cuda') -> int:
    """Choose the random seed shared by all training processes.

    A seed passed in by the caller is returned unchanged. Otherwise a seed
    is drawn at random and, when more than one process is running, the
    value drawn on rank 0 is broadcast so that every rank returns the same
    number.

    Args:
        seed (int, Optional): Seed requested by the caller. Default to None.
        device (str): Device on which the tensor used for the broadcast is
            created. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    # An explicit seed needs no synchronisation: hand it straight back.
    if seed is not None:
        return seed

    # Every rank must end up with the same seed, otherwise subtle bugs can
    # appear in distributed training. See
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    drawn = np.random.randint(2**31)
    if world_size == 1:
        return drawn

    # Only the value held by rank 0 survives the broadcast; the other ranks
    # just supply a placeholder tensor to receive it.
    payload = torch.tensor(
        drawn if rank == 0 else 0, dtype=torch.int32, device=device)
    dist.broadcast(payload, src=0)
    return payload.item()


def set_random_seed(seed: int, deterministic: bool = False) -> None:
"""Set random seed.
Expand Down
14 changes: 7 additions & 7 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mmcv.utils import get_git_hash

from mmflow import __version__
from mmflow.apis import set_random_seed, train_model
from mmflow.apis import init_random_seed, set_random_seed, train_model
from mmflow.datasets import build_dataset
from mmflow.models import build_flow_estimator
from mmflow.utils import collect_env, get_root_logger
Expand Down Expand Up @@ -130,12 +130,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, deterministic: '
f'{args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

model = build_flow_estimator(cfg.model)
Expand Down

0 comments on commit 500ee41

Please sign in to comment.