diff --git a/ml3d/torch/models/point_pillars.py b/ml3d/torch/models/point_pillars.py index a519acc7..c6453454 100644 --- a/ml3d/torch/models/point_pillars.py +++ b/ml3d/torch/models/point_pillars.py @@ -907,7 +907,7 @@ def flatten_idx(idx, j): # for each anchor the gt with max IoU max_overlaps, argmax_overlaps = overlaps.max(dim=0) # for each gt the anchor with max IoU - gt_max_overlaps, _ = overlaps.max(dim=1) + gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1) pos_idx = max_overlaps >= pos_th neg_idx = (max_overlaps >= 0) & (max_overlaps < neg_th) @@ -916,6 +916,7 @@ def flatten_idx(idx, j): for k in range(len(target_bboxes[i])): if gt_max_overlaps[k] >= neg_th: pos_idx[overlaps[k, :] == gt_max_overlaps[k]] = True + argmax_overlaps[gt_argmax_overlaps[k]] = k # encode bbox for positive matches assigned_bboxes.append(