diff --git a/captum/attr/_core/layer/grad_cam.py b/captum/attr/_core/layer/grad_cam.py
index df839a811..90c57c87e 100644
--- a/captum/attr/_core/layer/grad_cam.py
+++ b/captum/attr/_core/layer/grad_cam.py
@@ -82,6 +82,7 @@ def attribute(
         additional_forward_args: Any = None,
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
+        attr_dim_summation: bool = True,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         r"""
         Args:
@@ -149,6 +150,10 @@
                         otherwise, by default, both positive and negative
                         attributions are returned.
                         Default: False
+            attr_dim_summation (bool, optional): Indicates whether to
+                        sum attributions along dimension 1 (usually channel).
+                        The default (True) means to sum along dimension 1.
+                        Default: True
 
         Returns:
             *Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -208,10 +213,17 @@
                 for layer_grad in layer_gradients
             )
 
-        scaled_acts = tuple(
-            torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
-            for summed_grad, layer_eval in zip(summed_grads, layer_evals)
-        )
+        if attr_dim_summation:
+            scaled_acts = tuple(
+                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
+                for summed_grad, layer_eval in zip(summed_grads, layer_evals)
+            )
+        else:
+            scaled_acts = tuple(
+                summed_grad * layer_eval
+                for summed_grad, layer_eval in zip(summed_grads, layer_evals)
+            )
+
         if relu_attributions:
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
         return _format_output(len(scaled_acts) > 1, scaled_acts)
diff --git a/tests/attr/layer/test_grad_cam.py b/tests/attr/layer/test_grad_cam.py
index 6f0229a76..a8cafbf44 100644
--- a/tests/attr/layer/test_grad_cam.py
+++ b/tests/attr/layer/test_grad_cam.py
@@ -33,6 +33,23 @@ def test_simple_input_conv(self) -> None:
             net, net.conv1, inp, [[[[11.25, 13.5], [20.25, 22.5]]]]
         )
 
+    def test_simple_input_conv_split_channels(self) -> None:
+        net = BasicModel_ConvNet_One_Conv()
+        inp = torch.arange(16).view(1, 1, 4, 4).float()
+        expected_result = [
+            [
+                [[-3.7500, 3.0000], [23.2500, 30.0000]],
+                [[15.0000, 10.5000], [-3.0000, -7.5000]],
+            ]
+        ]
+        self._grad_cam_test_assert(
+            net,
+            net.conv1,
+            inp,
+            expected_activation=expected_result,
+            attr_dim_summation=False,
+        )
+
     def test_simple_input_conv_no_grad(self) -> None:
         net = BasicModel_ConvNet_One_Conv()
 
@@ -100,6 +117,7 @@ def _grad_cam_test_assert(
         additional_input: Any = None,
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
+        attr_dim_summation: bool = True,
     ):
         layer_gc = LayerGradCam(model, target_layer)
         self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -109,6 +127,7 @@
             additional_forward_args=additional_input,
             attribute_to_layer_input=attribute_to_layer_input,
             relu_attributions=relu_attributions,
+            attr_dim_summation=attr_dim_summation,
         )
         assertTensorTuplesAlmostEqual(
             self, attributions, expected_activation, delta=0.01