-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_KITTI.py
115 lines (102 loc) · 4.31 KB
/
train_KITTI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os, logging, sys, shutil, json
from config import get_config
from easydict import EasyDict as edict
from libs.loss import TransformationLoss, ClassificationLoss, SpectralMatchingLoss
from datasets.KITTI import KITTIDataset
from datasets.dataloader import get_dataloader
from libs.trainer import Trainer
from models.VBPointDSC import VBPointDSC
from torch import optim
# Source files copied into ``config.snapshot_dir`` so a run can be inspected or
# reproduced later: (path relative to the current working directory, file name
# inside the snapshot directory).
SNAPSHOT_FILES = (
    ('train_KITTI.py', 'train.py'),
    ('libs/trainer.py', 'trainer.py'),
    ('models/VBPointDSC.py', 'model.py'),  # for the model setting.
    ('libs/loss.py', 'loss.py'),
    ('datasets/KITTI.py', 'dataset.py'),
)


def save_config(config, path):
    """Write ``config`` to ``path`` as indented JSON.

    The file is opened in a ``with`` block so it is flushed and closed before
    training starts, rather than whenever the handle is garbage-collected.
    """
    with open(path, 'w') as f:
        json.dump(config, f, indent=4)


def build_optimizer(config, params):
    """Create the optimizer named by ``config.optimizer``.

    Args:
        config: object exposing ``optimizer`` ('SGD' or 'ADAM'), ``lr`` and
            ``weight_decay``; ``momentum`` is read only when ``optimizer`` is
            'SGD'.
        params: iterable of parameters to optimize.

    Returns:
        ``optim.SGD`` or ``optim.Adam`` (the latter with betas (0.9, 0.999)).

    Raises:
        ValueError: if ``config.optimizer`` is not 'SGD' or 'ADAM'. Failing here
            gives a clear message instead of handing a string to the scheduler.
    """
    name = config.optimizer
    if name == 'SGD':
        return optim.SGD(
            params,
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
        )
    if name == 'ADAM':
        return optim.Adam(
            params,
            lr=config.lr,
            betas=(0.9, 0.999),
            weight_decay=config.weight_decay,
        )
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'SGD' or 'ADAM'.")


def make_kitti_dataset(config, split):
    """Build a ``KITTIDataset`` for ``split`` from the options in ``config``.

    Train and val are built with the same keyword arguments, including the
    augmentation settings. Only ``split`` differs, and sharing this helper
    keeps the two in sync.
    """
    return KITTIDataset(
        root=config.root,
        split=split,
        descriptor=config.descriptor,
        in_dim=config.in_dim,
        inlier_threshold=config.inlier_threshold,
        num_node=config.num_node,
        use_mutual=config.use_mutual,
        augment_axis=config.augment_axis,
        augment_rotation=config.augment_rotation,
        augment_translation=config.augment_translation,
    )


if __name__ == '__main__':
    # Parse options, echo them, and switch to an EasyDict so later code can
    # attach models, loaders, etc. as attributes.
    config = get_config()
    dconfig = vars(config)
    for k in dconfig:
        print(f" {k}: {dconfig[k]}")
    config = edict(dconfig)
    config.is_cal_upper = False
    config.is_plot_attention = False

    os.makedirs(config.snapshot_dir, exist_ok=True)
    os.makedirs(config.tboard_dir, exist_ok=True)
    os.makedirs(config.save_dir, exist_ok=True)

    # Snapshot the code and config that define this run. This happens before
    # non-serializable objects (model, optimizer, loaders) are attached to
    # ``config``.
    for src, dst in SNAPSHOT_FILES:
        shutil.copy2(os.path.join('.', src), os.path.join(config.snapshot_dir, dst))
    save_config(config, os.path.join(config.snapshot_dir, 'config.json'))

    # create model
    config.model = VBPointDSC(config)

    # create optimizer (replaces the optimizer-name string with the object)
    config.optimizer = build_optimizer(config, config.model.parameters())
    config.scheduler = optim.lr_scheduler.ExponentialLR(
        config.optimizer,
        gamma=config.scheduler_gamma,
    )

    # create dataset and dataloader
    train_set = make_kitti_dataset(config, 'train')
    val_set = make_kitti_dataset(config, 'val')
    # Load one sample from each split up front; presumably a smoke test so data
    # problems surface before training starts (assumes each split has > 10 items).
    train_set[10]
    val_set[10]
    config.train_loader = get_dataloader(
        dataset=train_set,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )
    config.val_loader = get_dataloader(
        dataset=val_set,
        batch_size=config.batch_size_val,
        num_workers=config.num_workers,
    )

    # create evaluation
    config.evaluate_metric = {
        "ClassificationLoss": ClassificationLoss(balanced=config.balanced),
        "SpectralMatchingLoss": SpectralMatchingLoss(balanced=config.balanced),
        "TransformationLoss": TransformationLoss(re_thre=config.re_thre, te_thre=config.te_thre),
    }
    config.metric_weight = {
        "ClassificationLoss": config.weight_classification,
        "SpectralMatchingLoss": config.weight_spectralmatching,
        "TransformationLoss": config.weight_transformation,
    }

    # Log to <snapshot_dir>/log.txt (appending) and mirror to stdout.
    logging.basicConfig(level=logging.INFO,
                        filename=os.path.join(config.snapshot_dir, 'log.txt'),
                        filemode='a',
                        format="")
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    config.logging = logging

    trainer = Trainer(config)
    trainer.train()