-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathparsac.py
126 lines (91 loc) · 5.58 KB
/
parsac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from utils import \
options, initialisation, sampling, backward, visualisation, evaluation, residual_functions, inlier_counting, metrics, postprocessing
import torch
import time
# Entry-point script.
# Parses options, seeds RNGs, builds logging/checkpointing, the model,
# optimizer, scheduler and datasets. It then runs `opt.epochs` epochs; each
# epoch iterates over every mode in `opt.modes`, training when
# mode == "train" and evaluating otherwise. Checkpoints are written at the
# end of each epoch according to `opt.ckpt_mode`.
opt = options.get_options()
initialisation.seeds(opt)
ckpt_dir, log = initialisation.setup_logging_and_checkpointing(opt)
model, optimizer, scheduler, device = initialisation.get_model(opt)
datasets = initialisation.get_dataset(opt)

for epoch in range(opt.epochs):
    print("epoch %d / %d" % (epoch + 1, opt.epochs))

    # Dataloaders are rebuilt every epoch (without shuffling every split).
    dataloaders = initialisation.get_dataloader(opt, datasets, shuffle_all=False)

    for mode in opt.modes:
        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        if dataloaders[mode] is None:
            raise RuntimeError("no dataloader for %s available" % mode)
        print("mode: %s" % mode)

        if mode == "train":
            model.train()
        else:
            model.eval()

        # Per-mode accumulators. "time" and "total_time" entries are in milliseconds.
        eval_metrics = {"loss": [], "time": [], "total_time": []}
        # Passed to visualisation.log_wandb below; printed afterwards when opt.eval is set.
        wandb_log_data = {}
        total_start = time.time()

        for batch_idx, (features, X, gt_labels, gt_models, image, image_size) in enumerate(dataloaders[mode]):
            # Move the batch to the device once per batch. Tensor.to() is a no-op for
            # tensors already on `device`, so repeating it per run was redundant work.
            X = X.to(device)
            features = features.to(device)
            gt_labels = gt_labels.to(device)
            gt_models = gt_models.to(device)
            image_size = image_size.to(device)

            # Each batch is processed opt.runcount times; run_idx is forwarded to the
            # validation metrics.
            for run_idx in range(opt.runcount):
                optimizer.zero_grad()
                # NOTE(review): wall-clock timing; on CUDA, kernels launch asynchronously,
                # so without a device synchronisation this may under-report GPU work.
                start_time = time.time()

                log_inlier_weights, log_sample_weights = model(features)

                # NOTE(review): the original file's indentation was lost. This no_grad scope
                # is reconstructed to cover only minimal-set sampling, hypothesis generation
                # and residuals. Inlier counting and hypothesis sampling consume
                # log_inlier_weights and feed log_p_M_S into the backward pass, so they are
                # kept outside it; confirm against the upstream repository.
                with torch.no_grad():
                    minimal_sets = sampling.sample_minimal_sets(opt, log_sample_weights)
                    hypotheses = sampling.generate_hypotheses(opt, X, minimal_sets)
                    residuals = residual_functions.compute_residuals(opt, X, hypotheses)

                weighted_inlier_ratios, inlier_scores = \
                    inlier_counting.count_inliers(opt, residuals, log_inlier_weights)
                log_p_M_S, sampled_inlier_scores, sampled_hypotheses, sampled_residuals = \
                    sampling.sample_hypotheses(opt, mode, hypotheses, weighted_inlier_ratios, inlier_scores, residuals)

                # Optional refinement is only implemented for the "vp" problem.
                if opt.refine and opt.problem == "vp":
                    sampled_hypotheses, sampled_residuals, sampled_inlier_scores = \
                        postprocessing.refinement_with_inliers(opt, X, sampled_inlier_scores)

                ranked_choices, ranked_inlier_ratios, ranked_hypotheses, ranked_scores, labels, clusters = \
                    postprocessing.ranking_and_clustering(opt, sampled_inlier_scores, sampled_hypotheses,
                                                          sampled_residuals)

                duration = (time.time() - start_time) * 1000
                eval_metrics["time"] += [duration]

                if not opt.self_supervised:
                    # Supervised losses are computed from ground truth without gradients.
                    with torch.no_grad():
                        if opt.problem == "vp":
                            exp_losses, _ = metrics.vp_loss(gt_models, ranked_hypotheses,
                                                            datasets["inverse_intrinsics"])
                        elif opt.problem in ("fundamental", "homography"):
                            exp_losses = metrics.classification_loss(opt, gt_labels, clusters)
                        else:
                            # `assert False` would be stripped under `python -O`, leaving
                            # exp_losses unbound or stale from the previous iteration.
                            raise ValueError("unsupported problem type: %s" % opt.problem)
                else:
                    cumulative_inlier_losses = inlier_counting.compute_cumulative_inliers(opt, ranked_scores)
                    sample_inlier_counts = inlier_counting.combine_hypotheses_inliers(ranked_scores)
                    exp_losses = backward.expected_self_losses(opt, sample_inlier_counts, cumulative_inlier_losses)

                if mode == "train":
                    log_p_M = backward.log_probabilities(log_sample_weights, minimal_sets, log_p_M_S)
                    backward.backward_pass(opt, exp_losses, log_p_M, optimizer)
                else:
                    # This branch only runs when mode != "train", so train is always False here.
                    eval_metrics = evaluation.compute_validation_metrics(
                        opt, eval_metrics, ranked_hypotheses, ranked_inlier_ratios, gt_models, gt_labels,
                        X, image_size, clusters, run_idx, datasets["inverse_intrinsics"], train=False)

                mean_loss = exp_losses.mean()
                eval_metrics["loss"] += [mean_loss.item()]

                # total_time spans from the previous reset. For the first run of a batch it
                # therefore also includes data loading and the device transfer.
                total_duration = (time.time() - total_start) * 1000
                eval_metrics["total_time"] += [total_duration]
                total_start = time.time()

                if opt.visualise:
                    visualisation.save_visualisation_plots(opt, X, ranked_choices, log_inlier_weights,
                                                           log_sample_weights, ranked_scores, clusters,
                                                           labels, gt_models, gt_labels, image,
                                                           dataloaders[mode].dataset, metrics=eval_metrics)

        visualisation.log_wandb(wandb_log_data, eval_metrics, mode, epoch)
        if opt.eval:
            for key, val in wandb_log_data.items():
                print(key, ":", val)

        if mode == "train":
            scheduler.step()

    # "all" keeps one numbered checkpoint per epoch; "last" overwrites a single file.
    # Any other ckpt_mode value saves nothing.
    if opt.ckpt_mode == "all":
        torch.save(model.state_dict(), '%s/model_weights_%06d.net' % (ckpt_dir, epoch))
        torch.save(optimizer.state_dict(), '%s/optimizer_%06d.net' % (ckpt_dir, epoch))
    elif opt.ckpt_mode == "last":
        torch.save(model.state_dict(), '%s/model_weights.net' % (ckpt_dir))
        torch.save(optimizer.state_dict(), '%s/optimizer.net' % (ckpt_dir))