-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patharguments_ddp.py
106 lines (85 loc) · 3.47 KB
/
arguments_ddp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import os
import torch
import numpy as np
import torch
import torch.distributed as dist
import random
import re
import yaml
import shutil
import warnings
from datetime import datetime
class Namespace(object):
    """Attribute-style wrapper around a (possibly nested) dictionary.

    Every key of the input mapping becomes an instance attribute. Values that
    are themselves ``dict`` objects are wrapped recursively; any other value
    (including lists that contain dicts) is stored unchanged.
    """

    def __init__(self, somedict):
        """Populate the instance from ``somedict``.

        Each key is asserted to be a ``str`` whose first character is an ASCII
        letter, ``_`` or ``-`` (``re.match`` only anchors at the start, so the
        rest of the key is not checked).
        """
        attributes = vars(self)
        for name, entry in somedict.items():
            assert isinstance(name, str) and re.match("[A-Za-z_-]", name)
            attributes[name] = Namespace(entry) if isinstance(entry, dict) else entry

    def __getattr__(self, attribute):
        """Raise ``AttributeError`` with a config hint for any missing attribute.

        Python only calls ``__getattr__`` after normal lookup has failed, so
        attributes stored by ``__init__`` never reach this method.
        """
        raise AttributeError(f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!")
def set_deterministic(seed):
    """Seed the random number generators and request deterministic cuDNN.

    When ``seed`` is ``None`` (the default coming from the config) nothing is
    changed. Otherwise the same seed is handed to Python's ``random``, NumPy,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``, and cuDNN is switched
    to deterministic mode with benchmarking turned off.
    """
    if seed is None:
        return

    print(f"Deterministic with seed = {seed}")
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def get_args(create_log=True):
    """Parse the command line, initialise DDP and merge in the YAML config.

    Steps, in order:
      1. Parse the CLI options (``-c/--config-file`` is required).
      2. Select this process's GPU from the local rank and initialise the
         NCCL process group.
      3. Load the YAML config and copy its top-level entries onto ``args``
         (nested mappings become ``Namespace`` objects); config entries
         overwrite CLI options with the same name.
      4. With ``--debug``, shrink batch sizes / epoch counts and disable
         dataloader workers.
      5. Build the run's log directory path and, on rank 0 only, create it,
         create the checkpoint directory and copy the config file into the
         log directory.
      6. Seed the RNGs via ``set_deterministic(args.seed)`` and attach the
         ``aug_kwargs``, ``dataset_kwargs`` and ``dataloader_kwargs`` dicts.

    Args:
        create_log: when false, rank 0 skips creating directories and copying
            the config file.

    Returns:
        Tuple ``(args, device, local_rank)`` where ``device`` is
        ``torch.device("cuda", local_rank)``.

    Raises:
        AssertionError: if ``log_dir``, ``data_dir``, ``ckpt_dir`` or ``name``
            is ``None``.
        FileExistsError: on rank 0, if the log directory already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml")
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
    parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
    parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
    parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    # DDP
    # Older `torch.distributed.launch` passes `--local_rank`, newer releases
    # pass `--local-rank`, and `torchrun` passes neither but exports
    # LOCAL_RANK. Accept both spellings and fall back to the environment so
    # every launcher works; an explicit flag still takes precedence.
    parser.add_argument(
        "--local_rank", "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", -1)),
    )
    args = parser.parse_args()

    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    device = torch.device("cuda", local_rank)

    with open(args.config_file, 'r') as f:
        # NOTE(review): FullLoader is not PyYAML's hardened loader. Fine for
        # trusted local configs; prefer yaml.safe_load if config files can
        # come from untrusted sources.
        config = Namespace(yaml.load(f, Loader=yaml.FullLoader))
    # Config entries are copied over the parsed CLI options of the same name.
    vars(args).update(vars(config))

    if args.debug:
        if args.train:
            args.train.batch_size = 2
            args.train.num_epochs = 1
            args.train.stop_at_epoch = 1
        if args.eval:
            args.eval.batch_size = 2
            args.eval.num_epochs = 1  # train only one epoch
        args.dataset.num_workers = 0

    assert None not in (args.log_dir, args.data_dir, args.ckpt_dir, args.name)

    # NOTE(review): every process evaluates datetime.now() on its own, so the
    # timestamped path may differ between ranks; only rank 0 creates it.
    args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name)

    if dist.get_rank() == 0 and create_log:
        # exist_ok=False: refuse to reuse an existing run directory.
        os.makedirs(args.log_dir, exist_ok=False)
        print(f'creating file {args.log_dir}')
        os.makedirs(args.ckpt_dir, exist_ok=True)
        # Keep a copy of the config next to the logs of this run.
        shutil.copy2(args.config_file, args.log_dir)

    set_deterministic(args.seed)

    vars(args)['aug_kwargs'] = {
        'name': args.model.name,
        'image_size': args.dataset.image_size,
    }
    vars(args)['dataset_kwargs'] = {
        'dataset': args.dataset.name,
        'data_dir': args.data_dir,
        'download': args.download,
    }
    vars(args)['dataloader_kwargs'] = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }

    return args, device, local_rank