From c7b365435a932c6737f84961b8dcb0715143b878 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sat, 20 Nov 2021 21:53:49 +0800 Subject: [PATCH] [Enhancement] set the same seed for all rank when distributed training --- mmflow/apis/__init__.py | 4 ++-- mmflow/apis/train.py | 33 ++++++++++++++++++++++++++++++++- tools/train.py | 14 +++++++------- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/mmflow/apis/__init__.py b/mmflow/apis/__init__.py index 396ffbea..3c10c92e 100644 --- a/mmflow/apis/__init__.py +++ b/mmflow/apis/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import inference_model, init_model from .test import multi_gpu_test, single_gpu_test -from .train import set_random_seed, train_model +from .train import init_random_seed, set_random_seed, train_model __all__ = [ 'set_random_seed', 'train_model', 'init_model', 'inference_model', - 'multi_gpu_test', 'single_gpu_test' + 'multi_gpu_test', 'single_gpu_test', 'init_random_seed' ] diff --git a/mmflow/apis/train.py b/mmflow/apis/train.py index 1a193de7..44d40908 100644 --- a/mmflow/apis/train.py +++ b/mmflow/apis/train.py @@ -5,9 +5,10 @@ import numpy as np import torch +import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook, - build_optimizer, build_runner) + build_optimizer, build_runner, get_dist_info) from mmcv.utils import Config, build_from_cfg from mmflow.core import DistEvalHook, EvalHook @@ -18,6 +19,36 @@ Dataset = torch.utils.data.Dataset +def init_random_seed(seed: Optional[int] = None, device: str = 'cuda') -> int: + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to avoid some potential bugs. + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to avoid + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2**31) + if world_size == 1: + return seed + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + def set_random_seed(seed: int, deterministic: bool = False) -> None: """Set random seed. diff --git a/tools/train.py b/tools/train.py index ebc65480..0461d8dd 100644 --- a/tools/train.py +++ b/tools/train.py @@ -12,7 +12,7 @@ from mmcv.utils import get_git_hash from mmflow import __version__ -from mmflow.apis import set_random_seed, train_model +from mmflow.apis import init_random_seed, set_random_seed, train_model from mmflow.datasets import build_dataset from mmflow.models import build_flow_estimator from mmflow.utils import collect_env, get_root_logger @@ -130,12 +130,12 @@ def main(): logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds - if args.seed is not None: - logger.info(f'Set random seed to {args.seed}, deterministic: ' - f'{args.deterministic}') - set_random_seed(args.seed, deterministic=args.deterministic) - cfg.seed = args.seed - meta['seed'] = args.seed + seed = init_random_seed(args.seed) + logger.info(f'Set random seed to {seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(seed, deterministic=args.deterministic) + cfg.seed = seed + meta['seed'] = seed meta['exp_name'] = osp.basename(args.config) model = build_flow_estimator(cfg.model)