diff --git a/tools/train.py b/tools/train.py index e761e2a57..1cf9983ab 100644 --- a/tools/train.py +++ b/tools/train.py @@ -8,6 +8,7 @@ import mmcv import torch +import torch.distributed as dist from mmcv import Config, DictAction from mmcv.runner import get_dist_info, init_dist from mmcv.utils import get_git_hash @@ -47,6 +48,10 @@ def parse_args(): help='ids of gpus to use ' '(only applicable to non-distributed training)') parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--diff_seed', + action='store_true', + help='Whether or not set different seeds for different ranks') parser.add_argument( '--deterministic', action='store_true', @@ -142,6 +147,7 @@ def main(): # set random seeds seed = init_random_seed(args.seed) + seed = seed + dist.get_rank() if args.diff_seed else seed logger.info(f'Set random seed to {seed}, ' f'deterministic: {args.deterministic}') set_random_seed(seed, deterministic=args.deterministic)