Skip to content

Commit 907ce81

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Enable perturbations_per_pass in DataloaderAttribution (#1158)
Summary: Pull Request resolved: #1158 Enable argument `perturbations_per_pass` in `DataloaderAttribution` to support multiple perturbation in a single traverse of the dataloader Differential Revision: D46965996 fbshipit-source-id: c50d9a0c27067f7ea27fdbaa5e6ab208aab1f76d
1 parent b1a9830 commit 907ce81

File tree

2 files changed

+177
-71
lines changed

2 files changed

+177
-71
lines changed

captum/attr/_core/dataloader_attr.py

+129-63
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
from collections import defaultdict
33
from copy import copy
4-
from typing import Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
55

66
import torch
77
from captum._utils.common import (
@@ -32,11 +32,78 @@ def _concat_tensors(accum, cur_output, _):
3232
return cur_output if accum is None else torch.cat([accum, cur_output])
3333

3434

35+
def _create_perturbation_mask(
36+
perturbed_feature_indices: Tensor, # 1D tensor of one-hot feature indices
37+
feature_mask: Tuple[Tensor, ...],
38+
feature_idx_to_mask_idx: Dict[int, List[int]],
39+
) -> Tuple[Union[Tensor, None], ...]:
40+
"""
41+
Create binary mask for inputs based on perturbed one-hot feature indices
42+
Use None if no perturbation is needed for the corresponding input
43+
"""
44+
45+
# a set of input/mask indices that need perturbation
46+
perturbation_mask_indices = set()
47+
for i, v in enumerate(perturbed_feature_indices.tolist()):
48+
# value 0 means the feature has been perturbed
49+
if not v:
50+
perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])
51+
52+
# create binary mask for inputs & set it to None if no perturbation is needed
53+
perturbation_mask = tuple(
54+
perturbed_feature_indices[mask_elem] if i in perturbation_mask_indices else None
55+
for i, mask_elem in enumerate(feature_mask)
56+
)
57+
58+
return perturbation_mask
59+
60+
61+
def _perturb_inputs(
62+
inputs: Iterable[Any],
63+
input_roles: Tuple[int],
64+
baselines: Tuple[Union[int, float, Tensor], ...],
65+
perturbation_mask: Tuple[Union[Tensor, None], ...],
66+
) -> Tuple[Any, ...]:
67+
"""
68+
Perturb inputs based on perturbation mask and baselines
69+
"""
70+
71+
perturbed_inputs = []
72+
attr_inp_count = 0
73+
74+
for inp, role in zip(inputs, input_roles):
75+
if role != InputRole.need_attr:
76+
perturbed_inputs.append(inp)
77+
continue
78+
79+
pert_mask = perturbation_mask[attr_inp_count]
80+
81+
# no perturbation is needed for this input
82+
if pert_mask is None:
83+
perturbed_inputs.append(inp)
84+
else:
85+
baseline = baselines[attr_inp_count]
86+
87+
perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
88+
perturbed_inputs.append(perturbed_inp)
89+
90+
attr_inp_count += 1
91+
92+
perturbed_inputs = tuple(perturbed_inputs)
93+
94+
return perturbed_inputs
95+
96+
3597
def _convert_output_shape(
3698
unique_attr: Tensor,
3799
attr_inputs: Tuple[Tensor, ...],
38100
feature_mask: Tuple[Tensor, ...],
39101
) -> Tuple[Tensor, ...]:
102+
"""
103+
Convert the shape of a single tensor of unique feature attributionto
104+
to match the shape of the inputs returned by dataloader
105+
"""
106+
40107
# unique_attr in shape(*output_dims, n_features)
41108
output_dims = unique_attr.shape[:-1]
42109
n_features = unique_attr.shape[-1]
@@ -107,77 +174,75 @@ def __init__(self, attr_method: Attribution) -> None:
107174

108175
def _forward_with_dataloader(
109176
self,
110-
perturbed_feature_indices,
177+
batched_perturbed_feature_indices: Tensor,
111178
dataloader: torch.utils.data.DataLoader,
112179
input_roles: Tuple[int],
113180
baselines: Tuple[Union[int, float, Tensor], ...],
114181
feature_mask: Tuple[Tensor, ...],
115182
reduce: Callable,
116183
to_metric: Optional[Callable],
117-
perturbation_per_pass: int,
118184
show_progress: bool,
119185
feature_idx_to_mask_idx: Dict[int, List[int]],
120186
):
121-
# a set of input/mask indices that need perturbation
122-
perturbation_mask_indices = set()
123-
for i, v in enumerate(perturbed_feature_indices[0].tolist()):
124-
# value 0 means the feature has been perturbed
125-
if not v:
126-
perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])
127-
128-
# create binary mask for inputs & set it to None if no perturbation is needed
129-
perturbation_mask = tuple(
130-
perturbed_feature_indices[0][mask_elem]
131-
if i in perturbation_mask_indices
132-
else None
133-
for i, mask_elem in enumerate(feature_mask)
134-
)
135-
136-
accum = None
137-
for inputs in dataloader:
138-
perturbed_inputs = []
139-
attr_inp_count = 0
140-
141-
for inp, role in zip(inputs, input_roles):
142-
if role != InputRole.need_attr:
143-
perturbed_inputs.append(inp)
144-
continue
145-
146-
pert_mask = perturbation_mask[attr_inp_count]
187+
"""
188+
Wrapper of the original given forward_func to be used in the attribution method
189+
It iterates over the dataloader with the given forward_func
190+
"""
147191

148-
# no perturbation is needed for this input
149-
if pert_mask is None:
150-
perturbed_inputs.append(inp)
151-
else:
152-
baseline = baselines[attr_inp_count]
192+
# batched_perturbed_feature_indices in shape(n_perturb, n_features)
193+
# n_perturb is not always the same as perturb_per_pass if not enough perturb
194+
perturbation_mask_list: List[Tuple[Union[Tensor, None], ...]] = [
195+
_create_perturbation_mask(
196+
perturbed_feature_indices,
197+
feature_mask,
198+
feature_idx_to_mask_idx,
199+
)
200+
for perturbed_feature_indices in batched_perturbed_feature_indices
201+
]
153202

154-
perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
155-
perturbed_inputs.append(perturbed_inp)
203+
# each perturbation needs an accum state
204+
accum_states = [None for _ in range(len(perturbation_mask_list))]
156205

157-
attr_inp_count += 1
206+
# tranverse the dataloader
207+
for inputs in dataloader:
208+
# for each batch read from the dataloader,
209+
# apply every perturbation based on perturbations_per_pass
210+
for i, perturbation_mask in enumerate(perturbation_mask_list):
211+
perturbed_inputs = _perturb_inputs(
212+
inputs, input_roles, baselines, perturbation_mask
213+
)
158214

159-
perturbed_inputs = tuple(perturbed_inputs)
215+
# due to explicitly defined roles
216+
# we can keep inputs in their original order
217+
# regardless of if they need attr
218+
# instead of using additional_forward_inputs
219+
forward_inputs = tuple(
220+
_
221+
for _, role in zip(perturbed_inputs, input_roles)
222+
if role != InputRole.no_forward
223+
)
160224

161-
# due to explicitly defined roles
162-
# we can keep inputs in their original order regardless of if they need attr
163-
# instead of using additional_forward_inputs to always appeend in the end
164-
forward_inputs = tuple(
165-
_
166-
for _, role in zip(perturbed_inputs, input_roles)
167-
if role != InputRole.no_forward
168-
)
225+
output = _run_forward(
226+
self.forward_func,
227+
forward_inputs,
228+
)
169229

170-
output = _run_forward(
171-
self.forward_func,
172-
forward_inputs,
173-
)
230+
accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)
174231

175-
accum = reduce(accum, output, perturbed_inputs)
232+
accum_results = [
233+
to_metric(accum) if to_metric else accum for accum in accum_states
234+
]
176235

177-
if to_metric is not None:
178-
return to_metric(accum)
236+
assert all(type(r) is Tensor for r in accum_results), (
237+
"Accumulated metrics for attribution must be a Tensor,"
238+
f"received: {next(r for r in accum_results if type(r) is not Tensor)}"
239+
)
179240

180-
return accum
241+
# shape(n_perturb * output_dims[0], *output_dims[1:])
242+
# the underneath attr method needs to support forward_func output's
243+
# 1st dim to grow with perturb_per_eval
244+
batched_accum = torch.stack(accum_results, dim=0)
245+
return batched_accum
181246

182247
def attribute(
183248
self,
@@ -187,7 +252,7 @@ def attribute(
187252
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
188253
reduce: Optional[Callable] = None,
189254
to_metric: Optional[Callable] = None,
190-
perturbation_per_pass: int = -1,
255+
perturbations_per_pass: int = 1,
191256
show_progress: bool = False,
192257
return_input_shape: bool = True,
193258
) -> Union[Tensor, Tuple[Tensor, ...]]:
@@ -240,16 +305,17 @@ def attribute(
240305
metric (Tensor): final result to be attributed, must be a Tensor
241306
242307
If None, will directly attribute w.r.t the reduced ``accum``
243-
perturbation_per_pass (int, optional
308+
perturbations_per_pass (int, optional) the number perturbations to execute
244309
concurrently in each traverse of the dataloader. The number of
245-
traverses is ceil(n_perturbations / perturbation_per_pass).
246-
The parameter offers a control of the trade-off between memory
310+
traverses needed is
311+
ceil(n_perturbations / perturbations_per_pass).
312+
313+
This arguement offers control of the trade-off between memory
247314
and efficiency. If the dataloader involves slow operations like
248315
remote request or file I/O, multiple traversals can be
249-
inefficient. Each perturbation needs to store its accumulated
250-
outputs of the reduce function until the end of the data
251-
traverse. If the value is -1, all perturbations are concurrent
252-
in a single traverse.
316+
inefficient. On the other hand, each perturbation needs to
317+
store its accumulated outputs of the reduce
318+
function until the end of the data traverse.
253319
return_input_shape (bool, optional): if True, returns the attribution
254320
following the input shapes given by the dataloader.
255321
Otherwise, returns a single tensor for the attributions of
@@ -352,14 +418,14 @@ def attribute(
352418
# unique_attr in shape(*output_dims, n_features)
353419
unique_attr = self.attr_method.attribute(
354420
feature_indices,
421+
perturbations_per_eval=perturbations_per_pass,
355422
additional_forward_args=(
356423
dataloader,
357424
input_roles,
358425
baselines,
359426
feature_mask,
360427
reduce,
361428
to_metric,
362-
perturbation_per_pass,
363429
show_progress,
364430
feature_idx_to_mask_idx,
365431
),

tests/attr/test_dataloader_attr.py

+48-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env fbpython
2+
import math
23
from typing import cast
4+
from unittest.mock import Mock, patch
35

46
import torch
57

@@ -12,6 +14,7 @@
1214
BaseTest,
1315
)
1416
from torch import Tensor
17+
from torch.utils.data import DataLoader, TensorDataset
1518

1619

1720
def sum_forward(*inps):
@@ -29,7 +32,7 @@ def forward(self, *inps):
2932
return self.linear(torch.cat(inps, dim=1))
3033

3134

32-
mock_dataset = torch.utils.data.TensorDataset(
35+
mock_dataset = TensorDataset(
3336
# iD feature
3437
torch.tensor(
3538
[
@@ -74,7 +77,7 @@ def test_dl_attr(self, forward) -> None:
7477
fa = FeatureAblation(forward)
7578
dl_fa = DataloaderAttribution(fa)
7679

77-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
80+
dataloader = DataLoader(mock_dataset, batch_size=2)
7881

7982
dl_attributions = dl_fa.attribute(dataloader)
8083

@@ -108,7 +111,7 @@ def test_dl_attr_with_mask(self, forward) -> None:
108111
fa = FeatureAblation(forward)
109112
dl_fa = DataloaderAttribution(fa)
110113

111-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
114+
dataloader = DataLoader(mock_dataset, batch_size=2)
112115

113116
dl_attributions = dl_fa.attribute(dataloader, feature_mask=masks)
114117

@@ -140,7 +143,7 @@ def test_dl_attr_with_baseline(self, forward) -> None:
140143
fa = FeatureAblation(forward)
141144
dl_fa = DataloaderAttribution(fa)
142145

143-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
146+
dataloader = DataLoader(mock_dataset, batch_size=2)
144147

145148
dl_attributions = dl_fa.attribute(dataloader, baselines=baselines)
146149

@@ -188,7 +191,7 @@ def to_metric(accum):
188191
dl_fa = DataloaderAttribution(fa)
189192

190193
batch_size = 2
191-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=batch_size)
194+
dataloader = DataLoader(mock_dataset, batch_size=batch_size)
192195

193196
dl_attribution = dl_fa.attribute(
194197
dataloader,
@@ -243,7 +246,7 @@ def forward(*forward_inputs):
243246
dl_fa = DataloaderAttribution(fa)
244247

245248
batch_size = 2
246-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=batch_size)
249+
dataloader = DataLoader(mock_dataset, batch_size=batch_size)
247250

248251
dl_attributions = dl_fa.attribute(
249252
dataloader,
@@ -282,7 +285,7 @@ def test_dl_attr_not_return_input_shape(self) -> None:
282285
fa = FeatureAblation(forward)
283286
dl_fa = DataloaderAttribution(fa)
284287

285-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
288+
dataloader = DataLoader(mock_dataset, batch_size=2)
286289

287290
dl_attribution = dl_fa.attribute(dataloader, return_input_shape=False)
288291

@@ -320,7 +323,7 @@ def test_dl_attr_with_mask_not_return_input_shape(self) -> None:
320323
fa = FeatureAblation(forward)
321324
dl_fa = DataloaderAttribution(fa)
322325

323-
dataloader = torch.utils.data.DataLoader(mock_dataset, batch_size=2)
326+
dataloader = DataLoader(mock_dataset, batch_size=2)
324327

325328
dl_attribution = dl_fa.attribute(
326329
dataloader, feature_mask=masks, return_input_shape=False
@@ -331,3 +334,40 @@ def test_dl_attr_with_mask_not_return_input_shape(self) -> None:
331334
self.assertEqual(type(dl_attribution), Tensor)
332335
dl_attribution = cast(Tensor, dl_attribution)
333336
self.assertEqual(dl_attribution.shape, expected_attr_shape)
337+
338+
@parameterized.expand([(2,), (3,), (4,)])
339+
def test_dl_attr_with_perturb_per_pass(self, perturb_per_pass) -> None:
340+
forward = sum_forward
341+
342+
fa = FeatureAblation(forward)
343+
dl_fa = DataloaderAttribution(fa)
344+
345+
mock_dl_iter = Mock(wraps=DataLoader.__iter__)
346+
347+
with patch.object(DataLoader, "__iter__", lambda self: mock_dl_iter(self)):
348+
dataloader = DataLoader(mock_dataset, batch_size=2)
349+
350+
dl_attributions = dl_fa.attribute(
351+
dataloader, perturbations_per_pass=perturb_per_pass
352+
)
353+
354+
n_features = 7
355+
# 2 extra iter calls: get one input for format; get unperturbed output
356+
n_iter_overhead = 2
357+
358+
self.assertEqual(
359+
mock_dl_iter.call_count,
360+
math.ceil(n_features / perturb_per_pass) + n_iter_overhead,
361+
)
362+
363+
# default reduce of DataloaderAttribution works the same as concat all batches
364+
attr_list = []
365+
for batch in dataloader:
366+
batch_attr = fa.attribute(tuple(batch))
367+
attr_list.append(batch_attr)
368+
369+
expected_attr = tuple(
370+
torch.cat(feature_attrs, dim=0) for feature_attrs in zip(*attr_list)
371+
)
372+
373+
assertAttributionComparision(self, dl_attributions, expected_attr)

0 commit comments

Comments
 (0)