
Commit 3fb83ad

Merge remote-tracking branch 'upstream/master' into insights-occlusion
2 parents 545beea + 77aa93a

File tree

11 files changed: +712, -86 lines

captum/attr/_core/deep_lift.py (+3, -1)

@@ -540,7 +540,9 @@ def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
         def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
             return torch.stack(torch.chunk(outputs, 2), dim=1)
 
-        if isinstance(self.model, nn.DataParallel):
+        if isinstance(
+            self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
+        ):
             return [
                 self.model.module.register_forward_pre_hook(pre_hook),  # type: ignore
                 self.model.module.register_forward_hook(forward_hook),
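
With this change, DeepLift registers its pre- and forward hooks on the wrapped module for DistributedDataParallel models, not only DataParallel ones. A minimal sketch of the newly covered case, assuming a single-process gloo group and a toy model (both illustrative, not part of this commit):

# Hypothetical sketch: DeepLift attributions on a DistributedDataParallel model.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from captum.attr import DeepLift

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="gloo", rank=0, world_size=1)  # one-process group

model = nn.parallel.DistributedDataParallel(nn.Linear(3, 2))
attr = DeepLift(model).attribute(torch.rand(4, 3), target=0)  # hooks attach to model.module

dist.destroy_process_group()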

captum/insights/config.py (+3, -1)

@@ -73,16 +73,18 @@ def _str_to_tuple(s):
         post_process={"n_steps": int},
     ),
     FeatureAblation.get_name(): ConfigParameters(
-        params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))}
+        params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
     ),
     Occlusion.get_name(): ConfigParameters(
         params={
             "sliding_window_shapes": StrConfig(value=""),
             "strides": StrConfig(value=""),
+            "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
         },
         post_process={
             "sliding_window_shapes": _str_to_tuple,
             "strides": _str_to_tuple,
+            "perturbations_per_eval": int,
         },
     ),
 }
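
The Insights occlusion panel thereby gains the same `perturbations_per_eval` knob that `Occlusion.attribute` already accepts. A sketch of the equivalent direct call (the model and input shapes are made-up placeholders):

# Hypothetical sketch: the Occlusion argument the Insights config now exposes.
import torch
import torch.nn as nn
from captum.attr import Occlusion

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))
attr = Occlusion(model).attribute(
    torch.rand(1, 3, 8, 8),
    sliding_window_shapes=(3, 2, 2),  # Insights parses this from a string via _str_to_tuple
    strides=(3, 1, 1),
    target=0,
    perturbations_per_eval=4,  # evaluate 4 occluded variants per forward pass
)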

captum/metrics/__init__.py (+1, -1)

@@ -1,3 +1,3 @@
 #!/usr/bin/env python3
 
-from ._core.infidelity import infidelity  # noqa
+from ._core.infidelity import infidelity, infidelity_perturb_func_decorator  # noqa

captum/metrics/_core/infidelity.py (+110, -66)

@@ -16,6 +16,79 @@
 from .._utils.batching import _divide_and_aggregate_metrics
 
 
+def infidelity_perturb_func_decorator(perturb_func):
+    r"""
+    An auxiliary decorator function that helps with computing perturbations
+    given perturbed inputs. It is useful for cases when `perturb_func`
+    returns only perturbed inputs and we internally compute the
+    perturbations as (input - perturbed_input) / (input - baseline).
+
+    If users decorate their `perturb_func` with
+    `@infidelity_perturb_func_decorator` then their `perturb_func` needs
+    to return only perturbed inputs.
+
+    Note that if your attribution algorithm is inherently local, such as
+    Saliency maps, you should not use the decorator, because the decorator
+    always divides by (input - baseline) and that is unnecessary for local
+    methods.
+
+    Args:
+
+        perturb_func (callable): Input perturbation function that takes inputs
+                and optionally baselines and returns perturbed inputs.
+
+    Returns:
+
+        default_perturb_func (callable): Internal default perturbation
+                function that computes the perturbations internally and
+                returns perturbations and perturbed inputs.
+
+    Examples::
+        >>> @infidelity_perturb_func_decorator
+        >>> def perturb_fn(inputs):
+        >>>     noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
+        >>>     return inputs - noise
+        >>> # Computes infidelity score using `perturb_fn`
+        >>> infidelity = infidelity_attr(model, perturb_fn, input, ...)
+
+    """
+
+    def default_perturb_func(inputs, baselines=None):
+        inputs_perturbed = (
+            perturb_func(inputs, baselines)
+            if baselines is not None
+            else perturb_func(inputs)
+        )
+        inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
+        inputs = _format_tensor_into_tuples(inputs)
+        baselines = _format_tensor_into_tuples(baselines)
+        if baselines is None:
+            perturbations = tuple(
+                safe_div(
+                    input - input_perturbed,
+                    input,
+                    torch.tensor(1.0, device=input.device),
+                )
+                for input, input_perturbed in zip(inputs, inputs_perturbed)
+            )
+        else:
+            perturbations = tuple(
+                safe_div(
+                    input - input_perturbed,
+                    input - baseline,
+                    torch.tensor(1.0, device=input.device),
+                )
+                for input, input_perturbed, baseline in zip(
+                    inputs, inputs_perturbed, baselines
+                )
+            )
+        return perturbations, inputs_perturbed
+
+    return default_perturb_func
+
+
 def infidelity(
     forward_func,
     perturb_func,
@@ -26,7 +99,6 @@ def infidelity(
     target=None,
     n_samples=10,
     max_examples_per_batch=None,
-    perturb_func_custom=False,
 ):
     r"""
     Explanation infidelity represents the expected mean-squared error
@@ -62,35 +134,52 @@ def infidelity(
             The perturbation function of model inputs. This function takes
             model inputs and optionally baselines as input arguments and returns
             either a tuple of perturbations and perturbed inputs or just
-            perturbed inputs. If `perturb_func` returns only perturbed inputs,
-            then the users have to set `perturb_func_custom=True`; this
-            will allow us to compute the perturbations internally both for
-            local and global infidelity, and it makes sense to use only if
-            input attributions are global attributions.
+            perturbed inputs. For example:
+
+                def my_perturb_func(inputs):
+                    <MY-LOGIC-HERE>
+                    return perturbations, perturbed_inputs
+
+            If we want to return only perturbed inputs and compute the
+            perturbations internally, then we can wrap `perturb_func` with the
+            `infidelity_perturb_func_decorator` decorator, such as:
+
+                from captum.metrics import infidelity_perturb_func_decorator
+
+                @infidelity_perturb_func_decorator
+                def my_perturb_func(inputs):
+                    <MY-LOGIC-HERE>
+                    return perturbed_inputs
+
+            In this case we compute the perturbations by dividing
+            (input - perturbed_input) by (input - baselines), and the user
+            needs to return only perturbed inputs in `perturb_func`, as
+            described above.
+
+            `infidelity_perturb_func_decorator` makes sense to use only for
+            global attribution algorithms such as integrated gradients,
+            deeplift, etc. If the user has a local attribution algorithm, or
+            decides to compute perturbations and perturbed inputs in
+            `perturb_func`, then they must not use
+            `infidelity_perturb_func_decorator`.
 
             If more than one input is passed to the infidelity function, those
             inputs will be passed to `perturb_func` as tuples in the same
             order as they are passed to the infidelity function.
 
-            In case `perturb_func_custom=False` and if inputs
+            If inputs
             - is a single tensor, the function needs to return a tuple
               of perturbations and perturbed input such as:
              perturb, perturbed_input
+
+              and only perturbed_input in case
+              `infidelity_perturb_func_decorator` is used.
             - is a tuple of tensors,
              corresponding perturbations and perturbed inputs must be computed
              and returned as tuples in the following format:
              (perturb1, perturb2, ... perturbN), (perturbed_input1,
              perturbed_input2, ... perturbed_inputN)
-
-            In case `perturb_func_custom=True` and if inputs
-            - is a single tensor, the function needs to return
-              only the perturbed input:
-              perturbed_input
-            - is a tuple of tensors,
-              corresponding perturbed inputs must be computed and
-              returned as tuples in the following format:
-              (perturbed_input1, perturbed_input2, ... perturbed_inputN)
-
+              Similar to the previous case, here as well we need to return
+              only perturbed inputs in case `infidelity_perturb_func_decorator`
+              decorates our `perturb_func`.
             It is important to note that for performance reasons `perturb_func`
             isn't called for each example individually but on a batch of
             input examples that are repeated `max_examples_per_batch / batch_size`
@@ -164,9 +253,8 @@ def infidelity(
                 tensor as well. If inputs is provided as a tuple of tensors
                 then attributions will be tuples of tensors as well.
 
-                If `perturb_func_custom=True` then we internally divide global
-                attribution values by (input - baselines) and the user needs to
-                only return perturbed inputs in `perturb_func`.
+                For more details on when to use
+                `infidelity_perturb_func_decorator`, please read the
+                documentation about `perturb_func`.
 
         additional_forward_args (any, optional): If the forward function
                 requires additional arguments other than the inputs for
@@ -220,17 +308,6 @@ def infidelity(
                 examples are processed together.
 
                 Default: None
-        perturb_func_custom (boolean, optional): A flag that indicates whether
-                to use the default perturbation logic that always divides the
-                attributions by (input - baseline). If this flag
-                is True then `perturb_func` needs to return only the
-                perturbed inputs.
-                The perturbations will be computed internally by
-                `default_perturb_func`. This makes sense to use only with
-                global attribution values because otherwise there is no need
-                to divide the attributions by (input - baseline).
-
-                Default: False
     Returns:
 
         infidelities (tensor): A tensor of scalar infidelity scores per
@@ -254,31 +331,6 @@ def infidelity(
         >>> infidelity = infidelity_attr(net, perturb_fn, input, attribution)
     """
 
-    def default_perturb_func(inputs, inputs_perturbed, baselines=None):
-        r"""
-        """
-        if baselines is None:
-            perturbations = tuple(
-                safe_div(
-                    input - input_perturbed,
-                    input,
-                    torch.tensor(1.0, device=input.device),
-                )
-                for input, input_perturbed in zip(inputs, inputs_perturbed)
-            )
-        else:
-            perturbations = tuple(
-                safe_div(
-                    input - input_perturbed,
-                    input - baseline,
-                    torch.tensor(1.0, device=input.device),
-                )
-                for input, input_perturbed, baseline in zip(
-                    inputs, inputs_perturbed, baselines
-                )
-            )
-        return perturbations, inputs_perturbed
-
     def _generate_perturbations(current_n_samples):
         r"""
         The perturbations are generated for each example `current_n_samples` times.
@@ -308,6 +360,7 @@ def call_perturb_func():
         inputs_expanded = tuple(
             torch.repeat_interleave(input, current_n_samples, dim=0) for input in inputs
         )
+
         if baselines is not None:
             baselines_expanded = tuple(
                 baseline.repeat_interleave(current_n_samples, dim=0)
@@ -320,16 +373,7 @@ def call_perturb_func():
         else:
             baselines_expanded = None
 
-        perturb_func_out = call_perturb_func()
-
-        if perturb_func_custom:
-            return default_perturb_func(
-                inputs_expanded,
-                _format_tensor_into_tuples(perturb_func_out),
-                baselines=baselines_expanded,
-            )
-        else:
-            return perturb_func_out
+        return call_perturb_func()
 
     def _validate_inputs_and_perturbations(inputs, inputs_perturbed, perturbations):
         # asserts the sizes of the perturbations and inputs
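
Together with the new re-export in captum/metrics/__init__.py above, the decorator is importable from the package root. A hedged end-to-end sketch, following the docstring's own example (the toy network, noise scale, and choice of IntegratedGradients are illustrative assumptions):

# Hypothetical sketch: infidelity with a decorated perturbation function.
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients
from captum.metrics import infidelity, infidelity_perturb_func_decorator

net = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2))
inputs = torch.rand(8, 3)

@infidelity_perturb_func_decorator
def perturb_fn(inputs):
    # Return only perturbed inputs; the decorator derives the perturbations.
    return inputs - 0.003 * torch.randn_like(inputs)

# Integrated gradients is a global method, so the decorator's division
# by (input - baseline) is the intended use case.
attributions = IntegratedGradients(net).attribute(inputs, target=0)
score = infidelity(net, perturb_fn, inputs, attributions, target=0)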

docs/algorithms_comparison_matrix.md (+2, -0)

@@ -213,3 +213,5 @@ Please, scroll to the right for more details.
 **^ Including Layer Variant**
 
 **˚ Including Neuron Variant**
+
+<a href="/img/algorithms_comparison_matrix.png">Algorithm Comparison Matrix.png</a>

scripts/update_versions_html.py (+2, -2)

@@ -35,8 +35,8 @@ def prepend_url(a_tag, base_url, version):
 
     # nav
     nav_links = soup.find("nav").findAll("a")
-    for l in nav_links:
-        l.attrs["href"] = prepend_url(l, base_url, v)
+    for link in nav_links:
+        link.attrs["href"] = prepend_url(link, base_url, v)
 
     # version link
     t = soup.find("h2", {"class": "headerTitleWithLogo"}).find_next("a")

tests/attr/test_data_parallel.py (+18, -2)

@@ -1,9 +1,11 @@
 #!/usr/bin/env python3
 import copy
+import os
 from enum import Enum
 from typing import Any, Callable, Dict, Optional, Tuple, Type, cast
 
 import torch
+import torch.distributed as dist
 from torch import Tensor
 from torch.nn import Module
@@ -22,14 +24,17 @@
 from .helpers.gen_test_utils import gen_test_name, parse_test_config
 from .helpers.test_config import config
 
-
 """
 Tests in this file are dynamically generated based on the config
 defined in tests/attr/helpers/test_config.py. To add new test cases,
 read the documentation in test_config.py and add cases based on the
 schema described there.
 """
 
+# Distributed Data Parallel env setup
+os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_PORT"] = "29500"
+
 
 class DataParallelCompareMode(Enum):
     """
@@ -44,6 +49,7 @@ class DataParallelCompareMode(Enum):
     cpu_cuda = 1
     data_parallel_default = 2
     data_parallel_alt_dev_ids = 3
+    dist_data_parallel = 4
 
 
 class DataParallelMeta(type):
@@ -146,6 +152,13 @@ def data_parallel_test_assert(self) -> None:
                     ),
                 )
                 args_1, args_2 = cuda_args, cuda_args
+            elif mode is DataParallelCompareMode.dist_data_parallel:
+                dist.init_process_group(backend="gloo", rank=0, world_size=1)
+                model_1, model_2 = (
+                    cuda_model,
+                    torch.nn.parallel.DistributedDataParallel(cuda_model),
+                )
+                args_1, args_2 = cuda_args, cuda_args
             else:
                 raise AssertionError("DataParallel compare mode type is not valid.")
@@ -219,8 +232,11 @@ def data_parallel_test_assert(self) -> None:
                 self, attributions_1, attributions_2, mode="max", delta=dp_delta
             )
 
+            if mode is DataParallelCompareMode.dist_data_parallel:
+                dist.destroy_process_group()
+
             return data_parallel_test_assert
 
 
-class DataParallelTest(BaseGPUTest):
+class DataParallelTest(BaseGPUTest, metaclass=DataParallelMeta):
     pass
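
The last hunk appears to matter because without `metaclass=DataParallelMeta` the dynamically built test methods would never be attached to `DataParallelTest`. A generic, hypothetical illustration of the metaclass-driven test-generation pattern (not Captum's actual helper code):

# Hypothetical minimal illustration of metaclass-based test generation.
import unittest

class GenMeta(type):
    def __new__(cls, name, bases, attrs):
        for i in (1, 2, 3):
            def make_test(val):
                def test(self):
                    self.assertEqual(val, val)  # stand-in per-case assertion
                return test
            attrs[f"test_case_{i}"] = make_test(i)  # attach generated test methods
        return super().__new__(cls, name, bases, attrs)

class GeneratedTests(unittest.TestCase, metaclass=GenMeta):
    pass  # the metaclass injects test_case_1..test_case_3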

tests/attr/test_hook_removal.py (+1, -1)

@@ -52,7 +52,7 @@ class HookRemovalMeta(type):
     """
 
     def __new__(cls, name: str, bases: Tuple, attrs: Dict):
-        created_tests = {}
+        created_tests: Dict[Tuple[Type[Attribution], HookRemovalMode], bool] = {}
         for test_config in config:
             (algorithms, model, args, layer, noise_tunnel, _,) = parse_test_config(
                 test_config
