Skip to content

Commit

Permalink
New CAM Method: ShapleyCAM (#550)
Browse files Browse the repository at this point in the history
* ShapleyCAM

Weighting the activation maps using Gradient and Hessian-Vector Product.

* name

* ReST example

* comments

* Update README.md

* Update README.md

* Update README.md

* update a simpler version

* comments

* forward function in shapely_cam.py still needed

This is because the calculation of the Hessian-vector product (HVP) requires the computation graph to be retained, see comments in line 37 or 38.

* delete forward function in shapley_cam.py

* comments
  • Loading branch information
cai2-huaiguang authored Jan 19, 2025
1 parent b1cab2d commit fd4b5c8
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 18 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
| Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations |
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.|
## Visual Examples

| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
Expand Down Expand Up @@ -362,4 +363,8 @@ Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Shey
https://hal.science/hal-02963298/document <br>
`Features Understanding in 3D CNNs for Actions Recognition in Video
Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain
Bourqui, Jenny Benois-Pineau, Akka Zemmar`
Bourqui, Jenny Benois-Pineau, Akka Zemmar`

https://arxiv.org/abs/2501.06261 <br>
`CAMs as Shapley Value-based Explainers
Huaiguang Cai`
11 changes: 6 additions & 5 deletions cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from pytorch_grad_cam import (
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
)
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
show_cam_on_image, deprocess_image, preprocess_image
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST


def get_args():
Expand All @@ -37,7 +37,7 @@ def get_args():
'gradcam', 'fem', 'hirescam', 'gradcam++',
'scorecam', 'xgradcam', 'ablationcam',
'eigencam', 'eigengradcam', 'layercam',
'fullgrad', 'gradcamelementwise', 'kpcacam'
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
],
help='CAM method')

Expand Down Expand Up @@ -75,7 +75,8 @@ def get_args():
"fullgrad": FullGrad,
"fem": FEM,
"gradcamelementwise": GradCAMElementWise,
'kpcacam': KPCA_CAM
'kpcacam': KPCA_CAM,
'shapleycam': ShapleyCAM
}

if args.device=='hpu':
Expand Down Expand Up @@ -109,7 +110,7 @@ def get_args():
# If targets is None, the highest scoring category (for every member in the batch) will be used.
# You can target specific categories by
# targets = [ClassifierOutputTarget(281)]
# targets = [ClassifierOutputTarget(281)]
# targets = [ClassifierOutputReST(281)]
targets = None

# Using the with statement ensures the context is freed, and you can
Expand Down
1 change: 1 addition & 0 deletions pytorch_grad_cam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.shapley_cam import ShapleyCAM
from pytorch_grad_cam.fem import FEM
from pytorch_grad_cam.hirescam import HiResCAM
from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
Expand Down
22 changes: 14 additions & 8 deletions pytorch_grad_cam/activations_and_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ class ActivationsAndGradients:
""" Class for extracting activations and
registering gradients from targetted intermediate layers """

def __init__(self, model, target_layers, reshape_transform):
def __init__(self, model, target_layers, reshape_transform, detach=True):
self.model = model
self.gradients = []
self.activations = []
self.reshape_transform = reshape_transform
self.detach = detach
self.handles = []
for target_layer in target_layers:
self.handles.append(
Expand All @@ -18,10 +19,12 @@ def __init__(self, model, target_layers, reshape_transform):

def save_activation(self, module, input, output):
activation = output

if self.reshape_transform is not None:
activation = self.reshape_transform(activation)
self.activations.append(activation.cpu().detach())
if self.detach:
if self.reshape_transform is not None:
activation = self.reshape_transform(activation)
self.activations.append(activation.cpu().detach())
else:
self.activations.append(activation)

def save_gradient(self, module, input, output):
if not hasattr(output, "requires_grad") or not output.requires_grad:
Expand All @@ -30,9 +33,12 @@ def save_gradient(self, module, input, output):

# Gradients are computed in reverse order
def _store_grad(grad):
if self.reshape_transform is not None:
grad = self.reshape_transform(grad)
self.gradients = [grad.cpu().detach()] + self.gradients
if self.detach:
if self.reshape_transform is not None:
grad = self.reshape_transform(grad)
self.gradients = [grad.cpu().detach()] + self.gradients
else:
self.gradients = [grad] + self.gradients

output.register_hook(_store_grad)

Expand Down
22 changes: 18 additions & 4 deletions pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
compute_input_gradient: bool = False,
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None,
detach: bool = True,
) -> None:
self.model = model.eval()
self.target_layers = target_layers
Expand All @@ -45,7 +46,8 @@ def __init__(
else:
self.tta_transforms = tta_transforms

self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
self.detach = detach
self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)

""" Get a vector of weights for every channel in the target layer.
Methods that return weights channels,
Expand All @@ -71,6 +73,8 @@ def get_cam_image(
eigen_smooth: bool = False,
) -> np.ndarray:
weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
if isinstance(activations, torch.Tensor):
activations = activations.cpu().detach().numpy()
# 2D conv
if len(activations.shape) == 4:
weighted_activations = weights[:, :, None, None] * activations
Expand Down Expand Up @@ -103,7 +107,13 @@ def forward(
if self.uses_gradients:
self.model.zero_grad()
loss = sum([target(output) for target, output in zip(targets, outputs)])
loss.backward(retain_graph=True)
if self.detach:
loss.backward(retain_graph=True)
else:
# keep the computational graph, create_graph = True is needed for hvp
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
# When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
# loss.backward(retain_graph=True, create_graph=True)
if 'hpu' in str(self.device):
self.__htcore.mark_step()

Expand Down Expand Up @@ -132,8 +142,12 @@ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]
def compute_cam_per_layer(
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
) -> np.ndarray:
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
if self.detach:
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
else:
activations_list = [a for a in self.activations_and_grads.activations]
grads_list = [g for g in self.activations_and_grads.gradients]
target_size = self.get_target_width_height(input_tensor)

cam_per_target_layer = []
Expand Down
60 changes: 60 additions & 0 deletions pytorch_grad_cam/shapley_cam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Callable, List, Optional, Tuple
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch
import numpy as np

"""
Weights the activation maps using the gradient and Hessian-Vector product.
This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
"""
class ShapleyCAM(BaseCAM):
def __init__(self, model, target_layers,
reshape_transform=None):
super(
ShapleyCAM,
self).__init__(
model = model,
target_layers = target_layers,
reshape_transform = reshape_transform,
compute_input_gradient = True,
uses_gradients = True,
detach = False)

def get_cam_weights(self,
input_tensor,
target_layer,
target_category,
activations,
grads):

hvp = torch.autograd.grad(
outputs=grads,
inputs=activations,
grad_outputs=activations,
retain_graph=False,
allow_unused=True
)[0]
# print(torch.max(hvp[0]).item()) # check if hvp is not all zeros
if hvp is None:
hvp = torch.tensor(0).to(self.device)
else:
if self.activations_and_grads.reshape_transform is not None:
hvp = self.activations_and_grads.reshape_transform(hvp)

if self.activations_and_grads.reshape_transform is not None:
activations = self.activations_and_grads.reshape_transform(activations)
grads = self.activations_and_grads.reshape_transform(grads)

weight = (grads - 0.5 * hvp).detach().cpu().numpy()
# 2D image
if len(activations.shape) == 4:
weight = np.mean(weight, axis=(2, 3))
return weight
# 3D image
elif len(activations.shape) == 5:
weight = np.mean(weight, axis=(2, 3, 4))
return weight
else:
raise ValueError("Invalid grads shape."
"Shape of grads should be 4 (2D image) or 5 (3D image).")
16 changes: 16 additions & 0 deletions pytorch_grad_cam/utils/model_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@ def __call__(self, model_output):
return torch.softmax(model_output, dim=-1)[:, self.category]


class ClassifierOutputReST:
"""
Using both pre-softmax and post-softmax, proposed in https://arxiv.org/abs/2501.06261
"""
def __init__(self, category):
self.category = category
def __call__(self, model_output):
if len(model_output.shape) == 1:
target = torch.tensor([self.category], device=model_output.device)
model_output = model_output.unsqueeze(0)
return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
else:
target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
return model_output[:,self.category] - torch.nn.functional.cross_entropy(model_output, target)


class BinaryClassifierOutputTarget:
def __init__(self, category):
self.category = category
Expand Down

0 comments on commit fd4b5c8

Please # to comment.