From 8559defdfadced7da661ed7442be2e91529ab962 Mon Sep 17 00:00:00 2001 From: Galina Zalesskaya Date: Wed, 8 Nov 2023 11:07:49 +0200 Subject: [PATCH] Fix XAI algorithm for Detection (#2609) * Impove saliency maps algorithm for Detection * Remove extra changes * Update unit tests * Changes for 1 class * Fix pre-commit * Update CHANGELOG --- CHANGELOG.md | 1 + .../hooks/det_class_probability_map_hook.py | 14 ++++++++------ .../detection/test_xai_detection_validity.py | 16 ++++++++-------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0eabab2e816..f009610b0e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file. - Fix mmcls bug not wrapping model in DataParallel on CPUs () - Fix h-label loss normalization issue w/ exclusive label group of singe label () - Fix division by zero in class incremental learning for classification () +- Fix saliency maps calculation issue for detection models () ## \[v1.4.3\] diff --git a/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py b/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py index 7931e234091..2847f1c573a 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py +++ b/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py @@ -60,12 +60,9 @@ def func( else: cls_scores = self._get_cls_scores_from_feature_map(feature_map) - # Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects, - # it would highlight one of the class maps as a background class - if self.use_cls_softmax and self._num_cls_out_channels > 1: - cls_scores = [torch.softmax(t, dim=1) for t in cls_scores] - - batch_size, _, height, width = cls_scores[-1].size() + middle_idx = len(cls_scores) // 2 + # resize to the middle feature map + batch_size, _, height, width = cls_scores[middle_idx].size() saliency_maps = torch.empty(batch_size, self._num_cls_out_channels, height, width) for batch_idx in range(batch_size): cls_scores_anchorless = [] @@ -82,6 +79,11 @@ def func( ) saliency_maps[batch_idx] = torch.cat(cls_scores_anchorless_resized, dim=0).mean(dim=0) + # Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects, + # it would highlight one of the class maps as a background class + if self.use_cls_softmax: + saliency_maps[0] = torch.stack([torch.softmax(t, dim=1) for t in saliency_maps[0]]) + if self._norm_saliency_maps: saliency_maps = saliency_maps.reshape((batch_size, self._num_cls_out_channels, -1)) saliency_maps = self._normalize_map(saliency_maps) diff --git a/tests/unit/algorithms/detection/test_xai_detection_validity.py b/tests/unit/algorithms/detection/test_xai_detection_validity.py index 6f684376064..b24b690e3ba 100644 --- a/tests/unit/algorithms/detection/test_xai_detection_validity.py +++ b/tests/unit/algorithms/detection/test_xai_detection_validity.py @@ -24,19 +24,19 @@ class TestExplainMethods: ref_saliency_shapes = { - "MobileNetV2-ATSS": (2, 4, 4), + "MobileNetV2-ATSS": (2, 13, 13), "SSD": (81, 13, 13), - "YOLOX": (80, 13, 13), + "YOLOX": (80, 26, 26), } ref_saliency_vals_det = { - "MobileNetV2-ATSS": np.array([67, 216, 255, 57], dtype=np.uint8), - "YOLOX": np.array([80, 28, 42, 53, 49, 68, 72, 75, 69, 57, 65, 6, 157], dtype=np.uint8), - "SSD": np.array([119, 72, 118, 35, 39, 30, 31, 31, 36, 28, 44, 23, 61], dtype=np.uint8), + "MobileNetV2-ATSS": np.array([34, 67, 148, 132, 172, 147, 146, 155, 167, 159], dtype=np.uint8), + "YOLOX": np.array([177, 94, 147, 147, 161, 162, 164, 164, 163, 166], dtype=np.uint8), + "SSD": np.array([255, 178, 212, 90, 93, 79, 79, 80, 87, 83], dtype=np.uint8), } ref_saliency_vals_det_wo_postprocess = { - "MobileNetV2-ATSS": -0.10465062, + "MobileNetV2-ATSS": -0.014513552, "YOLOX": 0.04948914, "SSD": 0.6629989, } @@ -80,8 +80,8 @@ def test_saliency_map_det(self, template): assert len(saliency_maps) == 2 assert saliency_maps[0].ndim == 3 assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name] - actual_sal_vals = saliency_maps[0][0][0].astype(np.int8) - ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.int8) + actual_sal_vals = saliency_maps[0][0][0][:10].astype(np.int16) + ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.uint8) assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) @e2e_pytest_unit