from .._utils.batching import _divide_and_aggregate_metrics


+ def infidelity_perturb_func_decorator(pertub_func):
+     r"""
+     An auxiliary decorator function that helps with computing
+     perturbations given perturbed inputs. It can be useful in cases
+     when `pertub_func` returns only perturbed inputs and we internally
+     compute the perturbations as
+     (input - perturbed_input) / (input - baseline).
+
+     If users decorate their `pertub_func` with
+     `@infidelity_perturb_func_decorator`, then their `pertub_func`
+     needs to return only the perturbed inputs.
+
+     Note that if your attribution algorithm is inherently local, such as
+     Saliency maps, you should not use the decorator, because the decorator
+     always divides by (input - baseline) and that is unnecessary for
+     local methods.
+
+     Args:
+
+         pertub_func (callable): Input perturbation function that takes
+                 inputs and optionally baselines and returns perturbed
+                 inputs.
+
+     Returns:
+
+         default_perturb_func (callable): Internal default perturbation
+                 function that computes the perturbations internally and
+                 returns perturbations and perturbed inputs.
+
+     Examples::
+         >>> @infidelity_perturb_func_decorator
+         >>> def perturb_fn(inputs):
+         >>>     noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
+         >>>     return inputs - noise
+         >>> # Computes infidelity score using `perturb_fn`
+         >>> infidelity = infidelity_attr(model, perturb_fn, input, ...)
+
+     """
+
+     def default_perturb_func(inputs, baselines=None):
+         r"""
+         Computes the perturbations as normalized differences between the
+         original and the perturbed inputs and returns them together with
+         the perturbed inputs.
+         """
+         inputs_perturbed = (
+             pertub_func(inputs, baselines)
+             if baselines is not None
+             else pertub_func(inputs)
+         )
+         inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
+         inputs = _format_tensor_into_tuples(inputs)
+         baselines = _format_tensor_into_tuples(baselines)
+         if baselines is None:
+             # Without baselines, normalize the difference by the input;
+             # `safe_div` guards against zero denominators.
+             perturbations = tuple(
+                 safe_div(
+                     input - input_perturbed,
+                     input,
+                     torch.tensor(1.0, device=input.device),
+                 )
+                 for input, input_perturbed in zip(inputs, inputs_perturbed)
+             )
+         else:
+             # With baselines, normalize by (input - baseline) instead.
+             perturbations = tuple(
+                 safe_div(
+                     input - input_perturbed,
+                     input - baseline,
+                     torch.tensor(1.0, device=input.device),
+                 )
+                 for input, input_perturbed, baseline in zip(
+                     inputs, inputs_perturbed, baselines
+                 )
+             )
+         return perturbations, inputs_perturbed
+
+     return default_perturb_func
+
+
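A minimal sketch of how a decorated function behaves end to end, assuming a hypothetical `my_perturb_fn` whose noise scale and input shape are illustrative, not taken from this commit:

    import numpy as np
    import torch

    @infidelity_perturb_func_decorator
    def my_perturb_fn(inputs):
        # Return only the perturbed inputs; the decorator derives the
        # perturbations from them.
        noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        return inputs - noise

    # The wrapped function returns two tuples: the perturbations, computed
    # here as (input - perturbed_input) / input because no baselines are
    # passed, and the perturbed inputs themselves.
    perturbations, perturbed_inputs = my_perturb_fn(torch.ones(4, 3))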
def infidelity(
    forward_func,
    perturb_func,
@@ -26,7 +99,6 @@ def infidelity(
    target=None,
    n_samples=10,
    max_examples_per_batch=None,
-     perturb_func_custom=False,
):
    r"""
    Explanation infidelity represents the expected mean-squared error
@@ -62,35 +134,52 @@ def infidelity(
            The perturbation function of model inputs. This function takes
            model inputs and optionally baselines as input arguments and returns
            either a tuple of perturbations and perturbed inputs or just
-             perturbed inputs. If `perturb_func` returns only perturbed inputs
-             then the users have to set `perturb_func_custom=True`; this will
-             allow us to compute the perturbations internally, both for local
-             and global infidelity, and it makes sense to use only if the
-             input attributions are global attributions.
+             perturbed inputs. For example:
+
+                 def my_perturb_func(inputs):
+                     <MY-LOGIC-HERE>
+                     return perturbations, perturbed_inputs
+
+             If we want to return only perturbed inputs and compute the
+             perturbations internally, then we can wrap `perturb_func` with the
+             `infidelity_perturb_func_decorator` decorator, such as:
+
+                 from captum.metrics import infidelity_perturb_func_decorator
+                 @infidelity_perturb_func_decorator
+                 def my_perturb_func(inputs):
+                     <MY-LOGIC-HERE>
+                     return perturbed_inputs
+
+             In this case we compute the perturbations by dividing
+             (input - perturbed_input) by (input - baselines), and the user
+             needs to return only the perturbed inputs from `perturb_func`,
+             as described above.
+
+             `infidelity_perturb_func_decorator` makes sense to use only for
+             global attribution algorithms such as integrated gradients,
+             deeplift, etc. If the user has a local attribution algorithm, or
+             decides to compute both the perturbations and the perturbed
+             inputs in `perturb_func`, then they must not use
+             `infidelity_perturb_func_decorator`.

            If there is more than one input passed to the infidelity function,
            those will be passed to `perturb_func` as tuples in the same order
            as they are passed to the infidelity function.

-             In case `perturb_func_custom=False` and if inputs
+             If inputs
            - is a single tensor, the function needs to return a tuple
              of perturbations and perturbed input such as:
              perturb, perturbed_input
+
+             and only perturbed_input in case
+             `infidelity_perturb_func_decorator` is used.
            - is a tuple of tensors,
              corresponding perturbations and perturbed inputs must be computed
              and returned as tuples in the following format:
              (perturb1, perturb2, ... perturbN), (perturbed_input1,
              perturbed_input2, ... perturbed_inputN)
-
-             In case `perturb_func_custom=True` and if inputs
-             - is a single tensor, the function needs to return
-               only perturbed input
-               perturbed_input
-             - is a tuple of tensors,
-               corresponding perturbed inputs must be computed and
-               returned as tuples in the following format:
-               (perturbed_input1, perturbed_input2, ... perturbed_inputN)
-
+             Similar to the previous case, here as well we need to return only
+             the perturbed inputs in case `infidelity_perturb_func_decorator`
+             decorates our `perturb_func`.
            It is important to note that for performance reasons `perturb_func`
            isn't called for each example individually but on a batch of
            input examples that are repeated `max_examples_per_batch / batch_size`
@@ -164,9 +253,8 @@ def infidelity(
            tensor as well. If inputs is provided as a tuple of tensors,
            then the attributions will be tuples of tensors as well.

-             If `perturb_func_custom=True` then we internally divide global
-             attribution values by (input - baselines) and the user needs to
-             only return perturbed inputs in `perturb_func`.
+             For more details on when to use `infidelity_perturb_func_decorator`,
+             please read the documentation about `perturb_func`.

        additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the inputs for
@@ -220,17 +308,6 @@ def infidelity(
                examples are processed together.

                Default: None
-         perturb_func_custom (boolean, optional): A flag that indicates whether
-                 to use default perturbation logic that always divides the
-                 attributions by (input - baseline). If this flag
-                 is True then `perturb_func` needs to return only the
-                 perturbed inputs.
-                 The perturbations will be computed internally by
-                 `default_perturb_func`. This makes sense to use only with
-                 global attribution values because otherwise there is no need
-                 to divide the attributions by (input - baseline).
-
-                 Default: False

    Returns:

        infidelities (tensor): A tensor of scalar infidelity scores per
@@ -254,31 +331,6 @@ def infidelity(
        >>> infidelity = infidelity_attr(net, perturb_fn, input, attribution)
    """

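For orientation, the returned score follows the infidelity definition of Yeh et al., 2019 (arXiv:1901.09392): the expected squared difference between sum(perturbations * attributions) and forward_func(inputs) - forward_func(inputs_perturbed). Below is a minimal sketch of a custom `perturb_func` that returns both perturbations and perturbed inputs for a tuple of inputs, matching the format described in the docstring above; the function name and noise scale are illustrative assumptions:

    import torch

    def my_perturb_func(inputs):
        # `inputs` is a tuple of tensors when infidelity receives
        # multiple inputs.
        perturbations = tuple(torch.randn_like(inp) * 0.01 for inp in inputs)
        perturbed_inputs = tuple(
            inp - pert for inp, pert in zip(inputs, perturbations)
        )
        # Returned as: (perturb1, ..., perturbN),
        #              (perturbed_input1, ..., perturbed_inputN)
        return perturbations, perturbed_inputs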
-     def default_perturb_func(inputs, inputs_perturbed, baselines=None):
-         r"""
-         """
-         if baselines is None:
-             perturbations = tuple(
-                 safe_div(
-                     input - input_perturbed,
-                     input,
-                     torch.tensor(1.0, device=input.device),
-                 )
-                 for input, input_perturbed in zip(inputs, inputs_perturbed)
-             )
-         else:
-             perturbations = tuple(
-                 safe_div(
-                     input - input_perturbed,
-                     input - baseline,
-                     torch.tensor(1.0, device=input.device),
-                 )
-                 for input, input_perturbed, baseline in zip(
-                     inputs, inputs_perturbed, baselines
-                 )
-             )
-         return perturbations, inputs_perturbed
-
    def _generate_perturbations(current_n_samples):
        r"""
        The perturbations are generated for each example `current_n_samples` times.
@@ -308,6 +360,7 @@ def call_perturb_func():
        inputs_expanded = tuple(
            torch.repeat_interleave(input, current_n_samples, dim=0) for input in inputs
        )
+
        if baselines is not None:
            baselines_expanded = tuple(
                baseline.repeat_interleave(current_n_samples, dim=0)
@@ -320,16 +373,7 @@ def call_perturb_func():
        else:
            baselines_expanded = None

-         perturb_func_out = call_perturb_func()
-
-         if perturb_func_custom:
-             return default_perturb_func(
-                 inputs_expanded,
-                 _format_tensor_into_tuples(perturb_func_out),
-                 baselines=baselines_expanded,
-             )
-         else:
-             return perturb_func_out
+         return call_perturb_func()
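A small sketch of what the expansion above does, assuming a batch of two examples and `current_n_samples = 3`:

    import torch

    batch = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    expanded = torch.repeat_interleave(batch, 3, dim=0)
    # expanded has shape (6, 2): each example is repeated 3 times in a row,
    # so `perturb_func` can perturb all samples for a batch in one call.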

    def _validate_inputs_and_perturbations(inputs, inputs_perturbed, perturbations):
        # asserts the sizes of the perturbations and inputs