DETR XAI (#4184)
* Implement explainability features in DFine and RTDETR models
eugene123tw authored Jan 23, 2025
1 parent 4416ac4 commit 2f0b54b
Showing 9 changed files with 205 additions and 52 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4017>)
- Add D-Fine Detection Algorithm
(<https://github.com/openvinotoolkit/training_extensions/pull/4142>)
- Add DETR XAI Explain Mode
(<https://github.com/openvinotoolkit/training_extensions/pull/4184>)

### Enhancements

4 changes: 3 additions & 1 deletion docs/source/guide/tutorials/base/explain.rst
@@ -32,6 +32,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example how
(otx) ...$ otx explain --work_dir otx-workspace \
--dump True # Whether to save saliency map images or not
--explain_config.postprocess True # Resizes and applies colormap to the saliency map
.. tab-item:: CLI (with config)

@@ -41,6 +42,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example how
--data_root data/wgisd \
--checkpoint otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt \
--dump True # Whether to save saliency map images or not
--explain_config.postprocess True # Resizes and applies colormap to the saliency map
.. tab-item:: API

@@ -49,7 +51,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example how
engine.explain(
checkpoint="<checkpoint-path>",
datamodule=OTXDataModule(...), # The data module to use for predictions
explain_config=ExplainConfig(postprocess=True),
explain_config=ExplainConfig(postprocess=True), # Resizes and applies colormap to the saliency map
dump=True # Whether to save saliency map images or not
)
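For readers following the API path, here is a minimal end-to-end sketch of the documented call. The import paths and the Engine constructor arguments are assumptions for illustration, not part of this diff.

from otx.engine import Engine
from otx.core.config.explain import ExplainConfig  # assumed import path
from otx.core.data.module import OTXDataModule  # assumed import path

engine = Engine(work_dir="otx-workspace")  # assumed constructor usage

predictions = engine.explain(
    checkpoint="otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt",
    datamodule=OTXDataModule(...),  # the data module to use for predictions
    explain_config=ExplainConfig(postprocess=True),  # resizes and applies colormap to the saliency map
    dump=True,  # whether to save saliency map images or not
)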
56 changes: 56 additions & 0 deletions src/otx/algo/detection/d_fine.py
@@ -157,6 +157,9 @@ def _customize_inputs(
)
targets.append({"boxes": scaled_bboxes, "labels": ll})

if self.explain_mode:
return {"entity": entity}

return {
"images": entity.images,
"targets": targets,
@@ -185,6 +188,33 @@ def _customize_outputs(
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

if self.explain_mode:
if not isinstance(outputs, dict):
msg = f"Model output should be a dict, but got {type(outputs)}."
raise ValueError(msg)

if "feature_vector" not in outputs:
msg = "No feature vector in the model output."
raise ValueError(msg)

if "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)

saliency_map = outputs["saliency_map"].detach().cpu().numpy()
feature_vector = outputs["feature_vector"].detach().cpu().numpy()

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
labels=labels,
feature_vector=feature_vector,
saliency_map=saliency_map,
)

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
@@ -306,3 +336,29 @@ def _optimization_config(self) -> dict[str, Any]:
},
},
}

@staticmethod
def _forward_explain_detection(
self, # noqa: ANN001
entity: DetBatchDataEntity,
mode: str = "tensor", # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward function for explainable detection model."""
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

raw_logits = DETR.split_and_reshape_logits(
backbone_feats,
predictions["raw_logits"],
)

saliency_map = self.explain_fn(raw_logits)
feature_vector = self.feature_vector_fn(backbone_feats)
predictions.update(
{
"feature_vector": feature_vector,
"saliency_map": saliency_map,
},
)

return predictions
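Note the staticmethod-with-self signature above: _forward_explain_detection is written to be bound onto the wrapped detector at explain time, so self resolves to the inner DETR module that owns backbone, encoder, and decoder. A sketch of that binding pattern, assuming the surrounding OTX code patches the inner model's forward (the helper name below is hypothetical):

import types

def enable_explain_forward(otx_model):  # hypothetical helper, for illustration only
    inner = otx_model.model  # the DETR detector wrapped by the OTX model class
    # Bind the static function so that `self` inside it is the inner detector.
    inner.forward = types.MethodType(otx_model._forward_explain_detection, inner)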
40 changes: 32 additions & 8 deletions src/otx/algo/detection/detectors/detection_transformer.py
@@ -1,11 +1,10 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Base DETR model implementations."""

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
@@ -96,22 +95,47 @@ def export(
explain_mode: bool = False,
) -> dict[str, Any] | tuple[list[Any], list[Any], list[Any]]:
"""Exports the model."""
backbone_feats = self.encoder(self.backbone(batch_inputs))
predictions = self.decoder(backbone_feats, explain_mode=True)
results = self.postprocess(
self._forward_features(batch_inputs),
predictions,
[meta["img_shape"] for meta in batch_img_metas],
deploy_mode=True,
)

if explain_mode:
# TODO(Eugene): Implement explain mode for DETR model.
warnings.warn("Explain mode is not supported for DETR model. Return dummy values.", stacklevel=2)
raw_logits = self.split_and_reshape_logits(backbone_feats, predictions["raw_logits"])
feature_vector = self.feature_vector_fn(backbone_feats)
saliency_map = self.explain_fn(raw_logits)
xai_output = {
"feature_vector": torch.zeros(1, 1),
"saliency_map": torch.zeros(1),
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}
results.update(xai_output) # type: ignore[union-attr]
return results

@staticmethod
def split_and_reshape_logits(
backbone_feats: tuple[Tensor, ...],
raw_logits: Tensor,
) -> tuple[Tensor, ...]:
"""Splits and reshapes raw logits for explain mode.
Args:
backbone_feats (tuple[Tensor,...]): Tuple of backbone features.
raw_logits (Tensor): Raw logits.
Returns:
tuple[Tensor,...]: The reshaped logits.
"""
splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
# Permute and split logits in one line
raw_logits = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
return tuple(
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
)

def postprocess(
self,
outputs: dict[str, Tensor],
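As a quick shape check for split_and_reshape_logits above, the following self-contained example traces the permute/split/reshape; the feature-map sizes and class count are illustrative only.

import torch

batch, num_classes = 2, 80
feats = [torch.zeros(batch, 256, h, w) for h, w in [(80, 80), (40, 40), (20, 20)]]
splits = [f.shape[-2] * f.shape[-1] for f in feats]  # [6400, 1600, 400]
raw_logits = torch.zeros(batch, sum(splits), num_classes)  # (B, num_anchors, num_classes)

# Permute to (B, num_classes, num_anchors), then split per level.
per_level = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)

# Reshape each split back onto its level's spatial grid.
maps = tuple(
    logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1])
    for logits, f in zip(per_level, feats)
)
assert [tuple(m.shape) for m in maps] == [(2, 80, 80, 80), (2, 80, 40, 40), (2, 80, 20, 20)]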
31 changes: 27 additions & 4 deletions src/otx/algo/detection/heads/dfine_decoder.py
@@ -723,7 +723,7 @@ def _get_decoder_input(
enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
content = torch.concat([denoising_logits, content], dim=1)

return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits

def _select_topk(
self,
@@ -762,8 +762,22 @@ def _select_topk(

return topk_memory, topk_logits, topk_anchors

def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]:
"""Forward pass of the DFine Transformer module."""
def forward(
self,
feats: Tensor,
targets: list[dict[str, Tensor]] | None = None,
explain_mode: bool = False,
) -> dict[str, Tensor]:
"""Forward function of the D-FINE Decoder Transformer Module.
Args:
feats (Tensor): Feature maps.
targets (list[dict[str, Tensor]] | None, optional): target annotations. Defaults to None.
explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
Returns:
dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
"""
# input projection and embedding
memory, spatial_shapes = self._get_encoder_input(feats)

@@ -781,7 +795,13 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
else:
denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input(
(
init_ref_contents,
init_ref_points_unact,
enc_topk_bboxes_list,
enc_topk_logits_list,
raw_logits,
) = self._get_decoder_input(
memory,
spatial_shapes,
denoising_logits,
@@ -858,6 +878,9 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
"pred_boxes": out_bboxes[-1],
}

if explain_mode:
out["raw_logits"] = raw_logits

return out

@torch.jit.unused
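Downstream, explain_fn consumes these per-level raw logits to build class saliency maps. The sketch below shows one plausible reduction, assuming per-level logits of shape (B, num_classes, H, W) as produced by split_and_reshape_logits; it illustrates the idea and is not the actual OTX explain_fn.

import torch
import torch.nn.functional as F

def saliency_from_raw_logits(raw_logits: tuple[torch.Tensor, ...]) -> torch.Tensor:
    # Upsample every level to the finest grid, squash logits to [0, 1],
    # and average the levels into one (B, num_classes, H, W) saliency map.
    target_hw = raw_logits[0].shape[-2:]
    resized = [
        F.interpolate(level.sigmoid(), size=target_hw, mode="bilinear", align_corners=False)
        for level in raw_logits
    ]
    return torch.stack(resized).mean(dim=0)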
35 changes: 26 additions & 9 deletions src/otx/algo/detection/heads/rtdetr_decoder.py
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""RTDETR decoder, modified from https://github.com/lyuwenyu/RT-DETR."""
@@ -546,10 +546,10 @@ def _get_decoder_input(

output_memory = self.enc_output(memory)

enc_outputs_class = self.enc_score_head(output_memory)
enc_outputs_logits = self.enc_score_head(output_memory)
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
_, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1)

reference_points_unact = enc_outputs_coord_unact.gather(
dim=1,
@@ -560,9 +560,9 @@
if denoising_bbox_unact is not None:
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)

enc_topk_logits = enc_outputs_class.gather(
enc_topk_logits = enc_outputs_logits.gather(
dim=1,
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]),
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
)

# extract region features
@@ -575,10 +575,24 @@
if denoising_class is not None:
target = torch.concat([denoising_class, target], 1)

return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits

def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor:
"""Forward pass of the RTDETRTransformer module."""
def forward(
self,
feats: torch.Tensor,
targets: list[dict[str, torch.Tensor]] | None = None,
explain_mode: bool = False,
) -> dict[str, torch.Tensor]:
"""Forward function of RTDETRTransformer.
Args:
feats (Tensor): Input features.
targets (List[Dict[str, Tensor]]): List of target dictionaries.
explain_mode (bool): Whether to return raw logits for explanation.
Returns:
dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
"""
# input projection and embedding
(memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)

@@ -596,7 +610,7 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
else:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input(
memory,
spatial_shapes,
denoising_class,
@@ -630,6 +644,9 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
out["dn_meta"] = dn_meta

if explain_mode:
out["raw_logits"] = raw_logits

return out

@torch.jit.unused
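The topk/gather pair in _get_decoder_input selects the num_queries anchors whose best class logit is highest; here is a toy, self-contained version of that selection (all sizes are illustrative).

import torch

batch, num_anchors, num_classes, num_queries = 2, 10, 4, 3
enc_outputs_logits = torch.randn(batch, num_anchors, num_classes)

# Rank anchors by their highest per-class logit and keep the top num_queries.
_, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, num_queries, dim=1)

# Gather the class logits of the selected anchors: (B, num_queries, num_classes).
enc_topk_logits = enc_outputs_logits.gather(
    dim=1,
    index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
)
assert enc_topk_logits.shape == (batch, num_queries, num_classes)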
58 changes: 57 additions & 1 deletion src/otx/algo/detection/rtdetr.py
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""RTDetr model implementations."""
@@ -128,6 +128,9 @@ def _customize_inputs(
)
targets.append({"boxes": scaled_bboxes, "labels": ll})

if self.explain_mode:
return {"entity": entity}

return {
"images": entity.images,
"targets": targets,
@@ -156,6 +159,33 @@ def _customize_outputs(
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

if self.explain_mode:
if not isinstance(outputs, dict):
msg = f"Model output should be a dict, but got {type(outputs)}."
raise ValueError(msg)

if "feature_vector" not in outputs:
msg = "No feature vector in the model output."
raise ValueError(msg)

if "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)

saliency_map = outputs["saliency_map"].detach().cpu().numpy()
feature_vector = outputs["feature_vector"].detach().cpu().numpy()

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
labels=labels,
feature_vector=feature_vector,
saliency_map=saliency_map,
)

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
@@ -271,3 +301,29 @@ def _exporter(self) -> OTXModelExporter:
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for RT-DETR."""
return {"model_type": "transformer"}

@staticmethod
def _forward_explain_detection(
self, # noqa: ANN001
entity: DetBatchDataEntity,
mode: str = "tensor", # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward function for explainable detection model."""
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

raw_logits = DETR.split_and_reshape_logits(
backbone_feats,
predictions["raw_logits"],
)

saliency_map = self.explain_fn(raw_logits)
feature_vector = self.feature_vector_fn(backbone_feats)
predictions.update(
{
"feature_vector": feature_vector,
"saliency_map": saliency_map,
},
)

return predictions
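feature_vector_fn is called here but not defined in this diff. For intuition, a common way to reduce multi-level backbone features to one fixed-size vector is per-level global average pooling followed by concatenation; the sketch below follows that pattern and is not necessarily OTX's actual helper.

import torch
import torch.nn.functional as F

def feature_vector_fn(backbone_feats: tuple[torch.Tensor, ...]) -> torch.Tensor:
    # Pool each (B, C, H, W) level to (B, C), then concatenate along channels.
    pooled = [F.adaptive_avg_pool2d(f, output_size=1).flatten(1) for f in backbone_feats]
    return torch.cat(pooled, dim=1)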