train.py
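"""Entry point for launching training runs.

Example invocation (a sketch; <base_network> and <experiment_name> are
placeholders, and the available configs depend on the chosen dataset and
base network):

    python train.py --dataset CIFAR10 --base_network <base_network> \
        --experiment_name <experiment_name> --csv_logger
"""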
import argparse
import difflib
import os

from bcos.experiments.utils import get_configs_and_model_factory
from bcos.training import trainer


def get_parser(add_help=True):
    parser = argparse.ArgumentParser(description="Start training.", add_help=add_help)

    # specify save dir and experiment config
    parser.add_argument(
        "--base_directory",
        default="./experiments",
        help="The base directory to store experiments in.",
    )
    parser.add_argument(
        "--dataset", choices=["ImageNet", "CIFAR10"], help="The dataset."
    )
    parser.add_argument(
        "--base_network", help="The model config or base network to use."
    )
    parser.add_argument("--experiment_name", help="The name of the experiment to run.")

    # other training args
    parser.add_argument(
        "--track_grad_norm",
        default=False,
        action="store_true",
        help="Track the L_2 norm of the gradient.",
    )
    parser.add_argument(
        "--distributed",
        default=False,
        action="store_true",
        help="Use distributed mode.",
    )
    parser.add_argument(
        "--force-no-resume",
        dest="resume",
        default=True,  # so by default always resume (note the dest!)
        action="store_false",  # if given, do NOT resume
        help="Force restart/retrain the experiment instead of resuming.",
    )
    parser.add_argument(
        "--amp", default=False, action="store_true", help="Use mixed precision."
    )
    parser.add_argument(
        "--jit",
        default=False,
        action="store_true",
        help="Use torch.jit.script on the model.",
    )
    parser.add_argument(
        "--cache_dataset",
        default=None,
        choices=["onthefly", "shm"],
        help="Cache the dataset.",
    )
    parser.add_argument(
        "--refresh_rate",
        type=int,
        help="Refresh rate for the progress bar.",
    )

    # loggers
    parser.add_argument(
        "--csv_logger", action="store_true", default=False, help="Use CSV logger."
    )
    parser.add_argument(
        "--tensorboard_logger",
        action="store_true",
        default=False,
        help="Use tensorboard logger.",
    )
    parser.add_argument(
        "--wandb_logger", action="store_true", default=False, help="Use wandb logger."
    )
    parser.add_argument(
        "--wandb_project",
        # here so that custom args validation doesn't complain
        default=os.getenv("WANDB_PROJECT"),
        help="Project name of the run.",
    )
    parser.add_argument(
        "--wandb_id", default=os.getenv("WANDB_ID"), help="wandb ID of the run."
    )
    parser.add_argument(
        "--wandb_name",
        default=None,  # use args.experiment_name
        help="Override the wandb experiment name. Defaults to --experiment_name.",
    )

    # explanations logging
    parser.add_argument(
        "--explanation_logging",
        action="store_true",
        dest="explanation_logging",
        default=False,
        help="Enable explanation logging.",
    )
    parser.add_argument(
        "--explanation_logging_every_n_epochs",
        type=int,
        default=1,
        help="Log explanations every n epochs.",
    )

    # debugging stuff
    parser.add_argument(
        "--fast_dev_run",
        action="store_true",
        default=False,
        help="Use trainer's fast dev run mode.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable debugging mode.",
    )

    return parser


def _args_validation(args):
    # check that the given config exists; if not, suggest the closest match
    configs, _ = get_configs_and_model_factory(args.dataset, args.base_network)
    if args.experiment_name not in configs:
        err_msg = f"Unknown config '{args.experiment_name}'!"
        possible = difflib.get_close_matches(args.experiment_name, configs.keys())
        if possible:
            err_msg += f" Did you mean '{possible[0]}'?"
        raise RuntimeError(err_msg)

    # check for resume
    assert hasattr(args, "resume"), "no resume arg in args!"

    # stay organized: require a project name when logging to wandb
    if args.wandb_logger:
        assert args.wandb_project is not None, "Provide a project name for wandb!"


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    _args_validation(args)

    try:
        trainer.run_training(args)
    except Exception:
        import pdb

        # in debug mode, drop into the post-mortem debugger before re-raising
        if args.debug:
            pdb.post_mortem()
        raise
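# Usage note (flag names as defined above; values are placeholders): to retrain
# from scratch instead of resuming, while logging to wandb, something like
#
#   python train.py --dataset ImageNet --base_network <base_network> \
#       --experiment_name <experiment_name> --force-no-resume \
#       --wandb_logger --wandb_project <project>
#
# should work, since --force-no-resume stores False into args.resume.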