From 2f0b54b4c01a13cfaf442270883445f9a1c111a9 Mon Sep 17 00:00:00 2001
From: Eugene Liu
Date: Thu, 23 Jan 2025 08:36:14 +0000
Subject: [PATCH] DETR XAI (#4184)

* Implement explainability features in DFine and RTDETR models
---
 CHANGELOG.md                                   |  2 +
 docs/source/guide/tutorials/base/explain.rst   |  4 +-
 src/otx/algo/detection/d_fine.py               | 56 ++++++++++++++++++
 .../detectors/detection_transformer.py         | 40 ++++++++++---
 src/otx/algo/detection/heads/dfine_decoder.py  | 31 ++++++++--
 .../algo/detection/heads/rtdetr_decoder.py     | 35 ++++++++---
 src/otx/algo/detection/rtdetr.py               | 58 ++++++++++++++++++-
 tests/integration/api/test_xai.py              | 18 +-----
 tests/integration/cli/test_cli.py              | 13 +----
 9 files changed, 205 insertions(+), 52 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1a5dbbb26b6..e0578a1549b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
   ()
 - Add D-Fine Detection Algorithm
   ()
+- Add DETR XAI Explain Mode
+  (<https://github.com/openvinotoolkit/training_extensions/pull/4184>)
 
 ### Enhancements
 
diff --git a/docs/source/guide/tutorials/base/explain.rst b/docs/source/guide/tutorials/base/explain.rst
index bf2af135783..cb195b9a914 100644
--- a/docs/source/guide/tutorials/base/explain.rst
+++ b/docs/source/guide/tutorials/base/explain.rst
@@ -32,6 +32,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho
 
         (otx) ...$ otx explain --work_dir otx-workspace \
                                --dump True # Whether to save saliency map images or not
+                               --explain_config.postprocess True # Resizes and applies colormap to the saliency map
 
     .. tab-item:: CLI (with config)
 
@@ -41,6 +42,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho
                                --data_root data/wgisd \
                                --checkpoint otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt \
                                --dump True # Whether to save saliency map images or not
+                               --explain_config.postprocess True # Resizes and applies colormap to the saliency map
 
     .. tab-item:: API
 
@@ -49,7 +51,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho
         engine.explain(
             checkpoint="<checkpoint-path>",
             datamodule=OTXDataModule(...), # The data module to use for predictions
-            explain_config=ExplainConfig(postprocess=True),
+            explain_config=ExplainConfig(postprocess=True), # Resizes and applies colormap to the saliency map
             dump=True # Whether to save saliency map images or not
         )
 
diff --git a/src/otx/algo/detection/d_fine.py b/src/otx/algo/detection/d_fine.py
index 5e16aa9c3c7..717ea9d6b23 100644
--- a/src/otx/algo/detection/d_fine.py
+++ b/src/otx/algo/detection/d_fine.py
@@ -157,6 +157,9 @@ def _customize_inputs(
             )
             targets.append({"boxes": scaled_bboxes, "labels": ll})
 
+        if self.explain_mode:
+            return {"entity": entity}
+
         return {
             "images": entity.images,
             "targets": targets,
@@ -185,6 +188,33 @@ def _customize_outputs(
         original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
         scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)
 
+        if self.explain_mode:
+            if not isinstance(outputs, dict):
+                msg = f"Model output should be a dict, but got {type(outputs)}."
+                raise ValueError(msg)
+
+            if "feature_vector" not in outputs:
+                msg = "No feature vector in the model output."
+                raise ValueError(msg)
+
+            if "saliency_map" not in outputs:
+                msg = "No saliency maps in the model output."
+                raise ValueError(msg)
+
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
+
+            return DetBatchPredEntity(
+                batch_size=len(outputs),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=scores,
+                bboxes=bboxes,
+                labels=labels,
+                feature_vector=feature_vector,
+                saliency_map=saliency_map,
+            )
+
         return DetBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
@@ -306,3 +336,29 @@ def _optimization_config(self) -> dict[str, Any]:
             },
         },
     }
+
+    @staticmethod
+    def _forward_explain_detection(
+        self,  # noqa: ANN001
+        entity: DetBatchDataEntity,
+        mode: str = "tensor",  # noqa: ARG004
+    ) -> dict[str, torch.Tensor]:
+        """Forward function for explainable detection model."""
+        backbone_feats = self.encoder(self.backbone(entity.images))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
+
+        raw_logits = DETR.split_and_reshape_logits(
+            backbone_feats,
+            predictions["raw_logits"],
+        )
+
+        saliency_map = self.explain_fn(raw_logits)
+        feature_vector = self.feature_vector_fn(backbone_feats)
+        predictions.update(
+            {
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
+            },
+        )
+
+        return predictions
diff --git a/src/otx/algo/detection/detectors/detection_transformer.py b/src/otx/algo/detection/detectors/detection_transformer.py
index d6798f1d426..f3cda5b7417 100644
--- a/src/otx/algo/detection/detectors/detection_transformer.py
+++ b/src/otx/algo/detection/detectors/detection_transformer.py
@@ -1,11 +1,10 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
 """Base DETR model implementations."""
 
 from __future__ import annotations
 
-import warnings
 from typing import Any
 
 import numpy as np
@@ -96,22 +95,47 @@ def export(
         explain_mode: bool = False,
     ) -> dict[str, Any] | tuple[list[Any], list[Any], list[Any]]:
         """Exports the model."""
+        backbone_feats = self.encoder(self.backbone(batch_inputs))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
         results = self.postprocess(
-            self._forward_features(batch_inputs),
+            predictions,
             [meta["img_shape"] for meta in batch_img_metas],
             deploy_mode=True,
         )
 
         if explain_mode:
-            # TODO(Eugene): Implement explain mode for DETR model.
-            warnings.warn("Explain mode is not supported for DETR model. Return dummy values.", stacklevel=2)
+            raw_logits = self.split_and_reshape_logits(backbone_feats, predictions["raw_logits"])
+            feature_vector = self.feature_vector_fn(backbone_feats)
+            saliency_map = self.explain_fn(raw_logits)
             xai_output = {
-                "feature_vector": torch.zeros(1, 1),
-                "saliency_map": torch.zeros(1),
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
             }
             results.update(xai_output)  # type: ignore[union-attr]
         return results
 
+    @staticmethod
+    def split_and_reshape_logits(
+        backbone_feats: tuple[Tensor, ...],
+        raw_logits: Tensor,
+    ) -> tuple[Tensor, ...]:
+        """Splits and reshapes raw logits for explain mode.
+
+        Args:
+            backbone_feats (tuple[Tensor, ...]): Tuple of backbone features.
+            raw_logits (Tensor): Raw logits.
+
+        Returns:
+            tuple[Tensor, ...]: The reshaped logits.
+ """ + splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats] + # Permute and split logits in one line + raw_logits = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1) + + # Reshape each split in a list comprehension + return tuple( + logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats) + ) + def postprocess( self, outputs: dict[str, Tensor], diff --git a/src/otx/algo/detection/heads/dfine_decoder.py b/src/otx/algo/detection/heads/dfine_decoder.py index d28e0cf3864..e2d8f9dd663 100644 --- a/src/otx/algo/detection/heads/dfine_decoder.py +++ b/src/otx/algo/detection/heads/dfine_decoder.py @@ -723,7 +723,7 @@ def _get_decoder_input( enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) content = torch.concat([denoising_logits, content], dim=1) - return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list + return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits def _select_topk( self, @@ -762,8 +762,22 @@ def _select_topk( return topk_memory, topk_logits, topk_anchors - def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]: - """Forward pass of the DFine Transformer module.""" + def forward( + self, + feats: Tensor, + targets: list[dict[str, Tensor]] | None = None, + explain_mode: bool = False, + ) -> dict[str, Tensor]: + """Forward function of the D-FINE Decoder Transformer Module. + + Args: + feats (Tensor): Feature maps. + targets (list[dict[str, Tensor]] | None, optional): target annotations. Defaults to None. + explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False. + + Returns: + dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes. 
+ """ # input projection and embedding memory, spatial_shapes = self._get_encoder_input(feats) @@ -781,7 +795,13 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) else: denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None - init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input( + ( + init_ref_contents, + init_ref_points_unact, + enc_topk_bboxes_list, + enc_topk_logits_list, + raw_logits, + ) = self._get_decoder_input( memory, spatial_shapes, denoising_logits, @@ -858,6 +878,9 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) "pred_boxes": out_bboxes[-1], } + if explain_mode: + out["raw_logits"] = raw_logits + return out @torch.jit.unused diff --git a/src/otx/algo/detection/heads/rtdetr_decoder.py b/src/otx/algo/detection/heads/rtdetr_decoder.py index dd5cf2f1991..bf140675ef7 100644 --- a/src/otx/algo/detection/heads/rtdetr_decoder.py +++ b/src/otx/algo/detection/heads/rtdetr_decoder.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # """RTDETR decoder, modified from https://github.com/lyuwenyu/RT-DETR.""" @@ -546,10 +546,10 @@ def _get_decoder_input( output_memory = self.enc_output(memory) - enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_logits = self.enc_score_head(output_memory) enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors - _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1) + _, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1) reference_points_unact = enc_outputs_coord_unact.gather( dim=1, @@ -560,9 +560,9 @@ def _get_decoder_input( if denoising_bbox_unact is not None: reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) - enc_topk_logits = enc_outputs_class.gather( + enc_topk_logits = enc_outputs_logits.gather( dim=1, - index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]), + index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]), ) # extract region features @@ -575,10 +575,24 @@ def _get_decoder_input( if denoising_class is not None: target = torch.concat([denoising_class, target], 1) - return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits + return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits - def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor: - """Forward pass of the RTDETRTransformer module.""" + def forward( + self, + feats: torch.Tensor, + targets: list[dict[str, torch.Tensor]] | None = None, + explain_mode: bool = False, + ) -> dict[str, torch.Tensor]: + """Forward function of RTDETRTransformer. + + Args: + feats (Tensor): Input features. + targets (List[Dict[str, Tensor]]): List of target dictionaries. + explain_mode (bool): Whether to return raw logits for explanation. + + Returns: + dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes. 
+ """ # input projection and embedding (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats) @@ -596,7 +610,7 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | else: denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None - target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input( + target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input( memory, spatial_shapes, denoising_class, @@ -630,6 +644,9 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) out["dn_meta"] = dn_meta + if explain_mode: + out["raw_logits"] = raw_logits + return out @torch.jit.unused diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py index 87784dadd7a..fcbf6330c2e 100644 --- a/src/otx/algo/detection/rtdetr.py +++ b/src/otx/algo/detection/rtdetr.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # """RTDetr model implementations.""" @@ -128,6 +128,9 @@ def _customize_inputs( ) targets.append({"boxes": scaled_bboxes, "labels": ll}) + if self.explain_mode: + return {"entity": entity} + return { "images": entity.images, "targets": targets, @@ -156,6 +159,33 @@ def _customize_outputs( original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info] scores, bboxes, labels = self.model.postprocess(outputs, original_sizes) + if self.explain_mode: + if not isinstance(outputs, dict): + msg = f"Model output should be a dict, but got {type(outputs)}." + raise ValueError(msg) + + if "feature_vector" not in outputs: + msg = "No feature vector in the model output." + raise ValueError(msg) + + if "saliency_map" not in outputs: + msg = "No saliency maps in the model output." 
+                raise ValueError(msg)
+
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
+
+            return DetBatchPredEntity(
+                batch_size=len(outputs),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=scores,
+                bboxes=bboxes,
+                labels=labels,
+                feature_vector=feature_vector,
+                saliency_map=saliency_map,
+            )
+
         return DetBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
@@ -271,3 +301,29 @@ def _exporter(self) -> OTXModelExporter:
     def _optimization_config(self) -> dict[str, Any]:
         """PTQ config for RT-DETR."""
         return {"model_type": "transformer"}
+
+    @staticmethod
+    def _forward_explain_detection(
+        self,  # noqa: ANN001
+        entity: DetBatchDataEntity,
+        mode: str = "tensor",  # noqa: ARG004
+    ) -> dict[str, torch.Tensor]:
+        """Forward function for explainable detection model."""
+        backbone_feats = self.encoder(self.backbone(entity.images))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
+
+        raw_logits = DETR.split_and_reshape_logits(
+            backbone_feats,
+            predictions["raw_logits"],
+        )
+
+        saliency_map = self.explain_fn(raw_logits)
+        feature_vector = self.feature_vector_fn(backbone_feats)
+        predictions.update(
+            {
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
+            },
+        )
+
+        return predictions
diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py
index 5408cd5049e..d82723470ec 100644
--- a/tests/integration/api/test_xai.py
+++ b/tests/integration/api/test_xai.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 from pathlib import Path
@@ -51,11 +51,6 @@ def test_forward_explain(
         # TODO(Eugene): maskdino not support yet.
         pytest.skip(f"There's issue with inst-seg: {model_name}. Skip for now.")
 
-    if "dfine" in model_name:
-        # TODO(Eugene): dfine not support yet.
-        # https://jira.devtools.intel.com/browse/CVS-160781
-        pytest.skip(f"There's issue with dfine: {model_name}. Skip for now.")
-
     if "dino" in model_name:
         pytest.skip("DINO is not supported.")
 
@@ -63,9 +58,6 @@ def test_forward_explain(
         # TODO (sungchul): enable xai for rtmdet_tiny (CVS-142651)
         pytest.skip("rtmdet_tiny on detection is not supported yet.")
 
-    if "rtdetr" in recipe:
-        pytest.skip("rtdetr on detection is not supported yet.")
-
     if "yolov9" in recipe:
         pytest.skip("yolov9 on detection is not supported yet.")
 
@@ -123,11 +115,6 @@ def test_predict_with_explain(
         # TODO(Eugene): maskdino not support yet.
         pytest.skip(f"There's issue with inst-seg: {model_name}. Skip for now.")
 
-    if "dfine" in model_name:
-        # TODO(Eugene): dfine not support yet.
-        # https://jira.devtools.intel.com/browse/CVS-160781
-        pytest.skip(f"There's issue with dfine: {model_name}. Skip for now.")
-
     if "rtmdet_tiny" in recipe:
         # TODO (sungchul): enable xai for rtmdet_tiny (CVS-142651)
         pytest.skip("rtmdet_tiny on detection is not supported yet.")
 
@@ -136,9 +123,6 @@ def test_predict_with_explain(
         # TODO (Galina): required to update model-api to 2.1
         pytest.skip("yolox_tiny_tile on detection requires model-api update")
 
-    if "rtdetr" in recipe:
-        pytest.skip("rtdetr on detection is not supported yet.")
-
     if "yolov9" in recipe:
         pytest.skip("yolov9 on detection is not supported yet.")
 
diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py
index 0de2a490929..3c11993ddab 100644
--- a/tests/integration/cli/test_cli.py
+++ b/tests/integration/cli/test_cli.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 from __future__ import annotations
@@ -252,12 +252,6 @@ def test_otx_e2e(
     if "dino" in model_name:
         return  # DINO is not supported.
 
-    if "dfine" in model_name:
-        return  # DFine is not supported.
-
-    if "rtdetr" in model_name:
-        return  # RT-DETR currently is not supported.
-
     if "yolov9" in model_name:
         return  # YOLOv9 is not supported.
 
@@ -334,13 +328,8 @@ def test_otx_explain_e2e(
     if "dino" in model_name:
         pytest.skip("DINO is not supported.")
 
-    if "dfine" in model_name:
-        pytest.skip("DFine is not supported.")
-
     if "maskrcnn_r50_tv" in model_name:
         pytest.skip("MaskRCNN R50 Torchvision model doesn't support explain.")
-    elif "rtdetr" in recipe:
-        pytest.skip("rtdetr model is not supported yet with explain.")
     elif "keypoint" in recipe:
         pytest.skip("keypoint detection models don't support explain.")
     elif "yolov9" in recipe:
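
Usage sketch: with this change, `engine.explain()` produces real saliency maps for the
DETR family (RT-DETR, D-FINE) instead of dummy values. A minimal Python sketch, assuming
the tutorial's `data/wgisd` dataset layout and an `rtdetr_18` recipe name; the
`Engine`/`ExplainConfig` import paths follow the OTX 2.x package layout and are not part
of this patch:

    from otx.core.config.explain import ExplainConfig
    from otx.engine import Engine

    # Recipe name, data root, and work dir here are illustrative assumptions.
    engine = Engine(model="rtdetr_18", data_root="data/wgisd", work_dir="otx-workspace")
    engine.train(max_epochs=1)

    # Each prediction now carries `saliency_map` and `feature_vector`, produced by
    # the decoder's new explain_mode=True path.
    predictions = engine.explain(
        explain_config=ExplainConfig(postprocess=True),  # resize + colormap the maps
        dump=True,  # save saliency map images under the work dir
    )

The saliency maps are derived via `split_and_reshape_logits`, which turns the encoder's
flattened per-anchor class logits back into one spatial map per pyramid level. A
self-contained shape check of that transform (dummy sizes, chosen only for illustration):

    import torch

    # Three pyramid levels: (batch=2, channels=256, H, W).
    feats = [torch.randn(2, 256, 80, 80), torch.randn(2, 256, 40, 40), torch.randn(2, 256, 20, 20)]
    # Logits over all flattened anchors: (batch, sum(H*W), num_classes=3).
    raw = torch.randn(2, 80 * 80 + 40 * 40 + 20 * 20, 3)

    splits = [f.shape[-2] * f.shape[-1] for f in feats]
    per_level = torch.split(raw.permute(0, 2, 1), splits, dim=-1)
    maps = tuple(x.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for x, f in zip(per_level, feats))
    print([tuple(m.shape) for m in maps])  # [(2, 3, 80, 80), (2, 3, 40, 40), (2, 3, 20, 20)]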