Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zytx121 committed Mar 10, 2022
1 parent 077ae5d commit 2df7a47
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
Expand Down Expand Up @@ -47,6 +48,10 @@ def parse_args():
help='ids of gpus to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
Expand Down Expand Up @@ -142,6 +147,7 @@ def main():

# set random seeds
seed = init_random_seed(args.seed)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
Expand Down

0 comments on commit 2df7a47

Please # to comment.