-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path config.py
157 lines (146 loc) · 5.94 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
# Single module-level parser that every argument group below is attached to.
parser = argparse.ArgumentParser(description='PoseDepth')
# Each group created through add_argument_group() is appended to this list.
arg_lists = []
def str2bool(v):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    Matching is case-insensitive and ignores surrounding whitespace:
    'true', '1', 'yes', 'y', 't', 'on'   -> True
    'false', '0', 'no', 'n', 'f', 'off'  -> False

    Args:
        v: The raw string from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: For any other value. Previously every
            unrecognised string silently became False, so a typo such as
            ``--shuffle ture`` disabled the option without any warning.
    """
    if isinstance(v, bool):
        return v
    text = v.strip().lower()
    if text in ('true', '1', 'yes', 'y', 't', 'on'):
        return True
    if text in ('false', '0', 'no', 'n', 'f', 'off'):
        return False
    # argparse turns ArgumentTypeError into a normal usage error message.
    raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % v)
def add_argument_group(name):
    """Create a titled argument group on the module-level ``parser``.

    The new group is also appended to the module-level ``arg_lists`` list.

    Args:
        name: Title of the group, shown in ``--help`` output.

    Returns:
        The newly created argparse argument group.
    """
    arg_lists.append(parser.add_argument_group(name))
    return arg_lists[-1]
# Train mode params: selects which dataset the run uses.
mode_arg = add_argument_group('Train mode Params')
mode_arg.add_argument('--dataset',
                      type=str,
                      default="kitti_raw",
                      choices=["kitti_raw", "kitti_odom"],
                      help='Dataset to train on: KITTI raw data or KITTI odometry.')
# KITTI data params
def _parse_hw(value):
    """Parse a ``--kitti_hw`` value given as 'HEIGHT,WIDTH' or 'HEIGHTxWIDTH'.

    The option used to be declared with ``type=tuple``. That splits a string
    into single characters (``tuple('192,640')`` -> ``('1', '9', '2', ...)``),
    so no command-line value could ever match ``choices``. This helper instead
    returns a tuple of two ints, which argparse compares against the tuples
    listed in ``choices``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not exactly two integers.
    """
    parts = value.lower().replace('x', ',').split(',')
    try:
        hw = tuple(int(part) for part in parts)
    except ValueError:
        hw = ()
    if len(hw) != 2:
        raise argparse.ArgumentTypeError(
            'Expected HEIGHT,WIDTH (e.g. 192,640), got %r.' % value)
    return hw


kitti_arg = add_argument_group('KITTI data Params')
kitti_arg.add_argument('--kitti_raw_txt',
                       type=str,
                       default='./splits/eigen_zhou/train_files.txt',
                       help='Text file listing the KITTI raw training samples.')
kitti_arg.add_argument('--kitti_raw_root',
                       type=str,
                       default='',
                       help='/path/to/your/kitti/raw_data/root.')
kitti_arg.add_argument('--kitti_odom_txt',
                       type=str,
                       default='./splits/odom/train_files.txt',
                       help='Text file listing the KITTI odometry training samples.')
kitti_arg.add_argument('--kitti_odom_root',
                       type=str,
                       default='',
                       help='/path/to/your/kitti/odometry/root.')
# Parsed into a (height, width) tuple of ints so it can match `choices`.
kitti_arg.add_argument('--kitti_hw',
                       type=_parse_hw,
                       default=(192, 640),
                       choices=[(192, 640), (256, 832), (320, 1024)],
                       help='Input image size as HEIGHT,WIDTH (e.g. 192,640).')
kitti_arg.add_argument('--img_ext',
                       type=str,
                       default='.jpg',
                       choices=['.png', '.jpg'],
                       help='Image file extension.')
kitti_arg.add_argument('--frame_ids',
                       nargs="+",
                       type=int,
                       default=[0, -1, 1],
                       help='')
kitti_arg.add_argument('--num_scales',
                       type=int,
                       default=4,
                       help='')
# training params
train_arg = add_argument_group('Training Params')
train_arg.add_argument('--batch_size',
                       type=int,
                       default=12,
                       help='# of images in each batch of data')
train_arg.add_argument('--num_workers',
                       type=int,
                       default=8,
                       help='# of subprocesses to use for data loading')
train_arg.add_argument('--pin_memory',
                       type=str2bool,
                       default=True,
                       help='Whether to pin memory during data loading.')
train_arg.add_argument('--shuffle',
                       type=str2bool,
                       default=True,
                       help='Whether to shuffle the train and valid indices')
train_arg.add_argument('--optim_policy',
                       type=str,
                       default='adam',
                       help='The optimization policy (adam or sgd).')
train_arg.add_argument('--start_epoch',
                       type=int,
                       default=0,
                       help='Epoch to start training from.')
train_arg.add_argument('--max_epoch',
                       type=int,
                       default=20,
                       help='Number of epochs to train for.')
train_arg.add_argument('--init_lr',
                       type=float,
                       default=1e-4,
                       help='Initial learning rate value.')
train_arg.add_argument('--lr_factor',
                       type=float,
                       default=0.1,
                       help='Reduce learning rate value.')
# `type=list` turned "15" into ['1', '5']; accept one or more ints instead,
# e.g. `--milestones 15 18`.
train_arg.add_argument('--milestones',
                       nargs='+',
                       type=int,
                       default=[15],
                       help='Epochs at which to reduce the learning rate.')
train_arg.add_argument('--display',
                       type=int,
                       default=50,
                       help='')
# Storage-related options.
storage_arg = add_argument_group('Storage')
storage_arg.add_argument(
    '--train_log', default='train', type=str, help='Training record.')
# NOTE(review): --ckpt_dir reuses --train_log's help text; confirm its
# intended meaning (possibly a sub-directory name under --ckpt_root).
storage_arg.add_argument(
    '--ckpt_dir', default='Res18_monobooster_640x192', type=str,
    help='Training record.')
# Depth network options.
depth_net_arg = add_argument_group('DepthNet Params')
# NOTE(review): presumably the ResNet depth of the encoder (the default
# --ckpt_dir starts with 'Res18'); confirm against the model code.
depth_net_arg.add_argument("--layer", default=18, type=int, help='')
# Weights of the depth-network loss terms.
depthloss_arg = add_argument_group('Depth net loss functions Params')
depthloss_arg.add_argument('--reproj_weight', default=1.0, type=float, help='')
depthloss_arg.add_argument('--smooth_weight', default=0.001, type=float, help='')
# Other params: device selection, seeding and output directories.
misc_arg = add_argument_group('Misc.')
misc_arg.add_argument('--gpu',
                      type=int,
                      default=0,
                      help="Which GPU to run on.")
misc_arg.add_argument('--seed',
                      type=int,
                      default=1001,
                      help='Seed to ensure reproducibility.')
misc_arg.add_argument('--ckpt_root',
                      type=str,
                      default='./checkpoints',
                      help='Directory in which to save model checkpoints.')
misc_arg.add_argument('--logs_dir',
                      type=str,
                      default='./logs',
                      help='Directory in which logs will be stored.')
def get_config():
    """Parse ``sys.argv`` with the module-level parser, tolerating extras.

    Arguments the parser does not recognise do not cause an error; they are
    handed back to the caller instead.

    Returns:
        A ``(config, unparsed)`` pair: the ``argparse.Namespace`` of known
        options and the list of leftover argument strings.
    """
    known, leftover = parser.parse_known_args()
    return known, leftover