Skip to content

Commit e31bf38

Browse files
NarineK authored and facebook-github-bot committed
Fix layer_gradient_x_activation and add logging for metrics (#643)
Summary: Adds logging for captum.metrics. Currently, layer_gradient_x_activation fails for integer inputs; this adds a datatype check before calling `apply_gradient_requirements`, since otherwise attribution fails when called in a `torch.no_grad()` context. Pull Request resolved: #643 Reviewed By: vivekmig Differential Revision: D27723943 Pulled By: NarineK fbshipit-source-id: 96a5fc1a6b848007e59212cbe7f469546fed5178
1 parent d630574 commit e31bf38

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

captum/metrics/_core/infidelity.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
safe_div,
1818
)
1919
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
20+
from captum.log import log_usage
2021
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
2122

2223

@@ -108,6 +109,7 @@ def default_perturb_func(
108109
return sub_infidelity_perturb_func_decorator
109110

110111

112+
@log_usage()
111113
def infidelity(
112114
forward_func: Callable,
113115
perturb_func: Callable,

captum/metrics/_core/sensitivity.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_format_tensor_into_tuples,
1717
)
1818
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
19+
from captum.log import log_usage
1920
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
2021

2122

@@ -57,6 +58,7 @@ def default_perturb_func(
5758
return perturbed_input
5859

5960

61+
@log_usage()
6062
def sensitivity_max(
6163
explanation_func: Callable,
6264
inputs: TensorOrTupleOfTensorsGeneric,

tests/attr/layer/test_layer_gradient_x_activation.py

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ def test_simple_input_gradient_activation(self) -> None:
2323
inp = torch.tensor([[0.0, 100.0, 0.0]], requires_grad=True)
2424
self._layer_activation_test_assert(net, net.linear0, inp, [0.0, 400.0, 0.0])
2525

26+
def test_simple_input_gradient_activation_no_grad(self) -> None:
27+
net = BasicModel_MultiLayer()
28+
inp = torch.tensor([[0.0, 100.0, 0.0]], requires_grad=True)
29+
with torch.no_grad():
30+
self._layer_activation_test_assert(net, net.linear0, inp, [0.0, 400.0, 0.0])
31+
2632
def test_simple_linear_gradient_activation(self) -> None:
2733
net = BasicModel_MultiLayer()
2834
inp = torch.tensor([[0.0, 100.0, 0.0]])

0 commit comments

Comments (0)