1
1
#!/usr/bin/env python3
2
2
from collections import defaultdict
3
3
from copy import copy
4
- from typing import Callable , Dict , List , Optional , Tuple , Union
4
+ from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Union
5
5
6
6
import torch
7
7
from captum ._utils .common import (
@@ -32,11 +32,78 @@ def _concat_tensors(accum, cur_output, _):
32
32
return cur_output if accum is None else torch .cat ([accum , cur_output ])
33
33
34
34
35
def _create_perturbation_mask(
    perturbed_feature_indices: Tensor,
    feature_mask: Tuple[Tensor, ...],
    feature_idx_to_mask_idx: Dict[int, List[int]],
) -> Tuple[Union[Tensor, None], ...]:
    """
    Create binary masks for the inputs from perturbed one-hot feature indices.

    Args:
        perturbed_feature_indices (Tensor): 1D tensor with one entry per
            feature. A falsy entry (0) means that feature has been perturbed.
        feature_mask (tuple[Tensor, ...]): tensors whose values are used to
            index into ``perturbed_feature_indices``.
        feature_idx_to_mask_idx (dict[int, list[int]]): for each feature
            index, the positions within ``feature_mask`` the feature maps to.

    Returns:
        tuple[Tensor or None, ...]: one entry per element of ``feature_mask``.
        The entry is ``perturbed_feature_indices[mask]`` when at least one
        perturbed feature maps to that position, and None when no
        perturbation is needed for the corresponding input.
    """
    # a set of input/mask indices that need perturbation
    perturbation_mask_indices = set()
    for feature_idx, kept in enumerate(perturbed_feature_indices.tolist()):
        # value 0 means the feature has been perturbed
        if not kept:
            perturbation_mask_indices.update(
                feature_idx_to_mask_idx[feature_idx]
            )

    # indexing the one-hot vector with a mask yields a tensor shaped like the
    # mask that holds the kept/perturbed flag of the feature at each position;
    # untouched inputs get None so callers can skip them
    return tuple(
        perturbed_feature_indices[mask_elem]
        if mask_idx in perturbation_mask_indices
        else None
        for mask_idx, mask_elem in enumerate(feature_mask)
    )
59
+
60
+
61
def _perturb_inputs(
    inputs: Iterable[Any],
    input_roles: Tuple[int, ...],
    baselines: Tuple[Union[int, float, Tensor], ...],
    perturbation_mask: Tuple[Union[Tensor, None], ...],
) -> Tuple[Any, ...]:
    """
    Perturb inputs based on perturbation mask and baselines.

    Args:
        inputs (iterable): one batch of inputs, paired positionally with
            ``input_roles``.
        input_roles (tuple[int, ...]): role of each input. Only inputs whose
            role equals ``InputRole.need_attr`` are candidates for
            perturbation.
        baselines (tuple): baseline for each input that needs attribution,
            indexed by that input's position among the ``need_attr`` inputs.
        perturbation_mask (tuple[Tensor or None, ...]): mask for each input
            that needs attribution, indexed the same way as ``baselines``.
            None means the input is left untouched.

    Returns:
        tuple: the inputs in their original order, where every masked
        ``need_attr`` input is replaced by
        ``inp * mask + baseline * (1 - mask)``.
    """
    perturbed_inputs: List[Any] = []
    # position among the inputs that need attribution; baselines and
    # perturbation_mask are indexed by this counter, not by the input position
    attr_inp_count = 0

    for inp, role in zip(inputs, input_roles):
        if role != InputRole.need_attr:
            perturbed_inputs.append(inp)
            continue

        pert_mask = perturbation_mask[attr_inp_count]

        # a None mask means no perturbation is needed for this input
        if pert_mask is not None:
            baseline = baselines[attr_inp_count]
            # keep the input where the mask is 1, use the baseline where it is 0
            inp = inp * pert_mask + baseline * (1 - pert_mask)

        perturbed_inputs.append(inp)
        attr_inp_count += 1

    return tuple(perturbed_inputs)
95
+
96
+
35
97
def _convert_output_shape (
36
98
unique_attr : Tensor ,
37
99
attr_inputs : Tuple [Tensor , ...],
38
100
feature_mask : Tuple [Tensor , ...],
39
101
) -> Tuple [Tensor , ...]:
102
+ """
103
+ Convert the shape of a single tensor of unique feature attribution
104
+ to match the shape of the inputs returned by dataloader
105
+ """
106
+
40
107
# unique_attr in shape(*output_dims, n_features)
41
108
output_dims = unique_attr .shape [:- 1 ]
42
109
n_features = unique_attr .shape [- 1 ]
@@ -107,77 +174,75 @@ def __init__(self, attr_method: Attribution) -> None:
107
174
108
175
def _forward_with_dataloader(
    self,
    batched_perturbed_feature_indices: Tensor,
    dataloader: torch.utils.data.DataLoader,
    input_roles: Tuple[int, ...],
    baselines: Tuple[Union[int, float, Tensor], ...],
    feature_mask: Tuple[Tensor, ...],
    reduce: Callable,
    to_metric: Optional[Callable],
    show_progress: bool,
    feature_idx_to_mask_idx: Dict[int, List[int]],
) -> Tensor:
    """
    Wrapper of the original given forward_func to be used in the attribution
    method. It iterates over the dataloader with the given forward_func.

    Args:
        batched_perturbed_feature_indices (Tensor): one row of one-hot
            feature indices per perturbation, i.e. shape
            (n_perturb, n_features). n_perturb can be smaller than the
            requested perturbations per pass when fewer perturbations remain.
        dataloader (DataLoader): source of input batches; it is traversed
            once per call.
        input_roles (tuple[int, ...]): role of each input yielded by the
            dataloader. Inputs with role ``InputRole.no_forward`` are not
            passed to the forward function.
        baselines (tuple): baselines forwarded to ``_perturb_inputs``.
        feature_mask (tuple[Tensor, ...]): feature masks forwarded to
            ``_create_perturbation_mask``.
        reduce (callable): called as ``reduce(accum, output, perturbed_inputs)``
            for every batch; its return value becomes the next ``accum``.
            ``accum`` is None on the first batch.
        to_metric (callable, optional): if not None, applied to the final
            accumulated state of each perturbation.
        show_progress (bool): accepted for signature compatibility; not read
            by this method.
        feature_idx_to_mask_idx (dict[int, list[int]]): mapping forwarded to
            ``_create_perturbation_mask``.

    Returns:
        Tensor: the per-perturbation results stacked along a new first
        dimension of size n_perturb.

    Raises:
        AssertionError: if any accumulated result (after ``to_metric``) is
            not a Tensor, e.g. when the dataloader yields no batch.
    """
    # batched_perturbed_feature_indices in shape(n_perturb, n_features)
    # n_perturb is not always the same as perturb_per_pass if not enough perturb
    perturbation_mask_list: List[Tuple[Union[Tensor, None], ...]] = [
        _create_perturbation_mask(
            perturbed_feature_indices,
            feature_mask,
            feature_idx_to_mask_idx,
        )
        for perturbed_feature_indices in batched_perturbed_feature_indices
    ]

    # each perturbation needs its own accum state
    accum_states: List[Any] = [None] * len(perturbation_mask_list)

    # traverse the dataloader once
    for inputs in dataloader:
        # for each batch read from the dataloader,
        # apply every perturbation of this pass
        for i, perturbation_mask in enumerate(perturbation_mask_list):
            perturbed_inputs = _perturb_inputs(
                inputs, input_roles, baselines, perturbation_mask
            )

            # due to explicitly defined roles
            # we can keep inputs in their original order
            # regardless of if they need attr
            # instead of using additional_forward_inputs
            forward_inputs = tuple(
                inp
                for inp, role in zip(perturbed_inputs, input_roles)
                if role != InputRole.no_forward
            )

            output = _run_forward(
                self.forward_func,
                forward_inputs,
            )

            accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)

    accum_results = [
        to_metric(accum) if to_metric is not None else accum
        for accum in accum_states
    ]

    assert all(isinstance(r, Tensor) for r in accum_results), (
        "Accumulated metrics for attribution must be a Tensor, "
        f"received: {next(r for r in accum_results if not isinstance(r, Tensor))}"
    )

    # torch.stack adds a new leading dimension: shape(n_perturb, *result_dims)
    # the underneath attr method needs to support forward_func output's
    # 1st dim to grow with perturbations_per_eval
    batched_accum = torch.stack(accum_results, dim=0)
    return batched_accum
181
246
182
247
def attribute (
183
248
self ,
@@ -187,7 +252,7 @@ def attribute(
187
252
feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
188
253
reduce : Optional [Callable ] = None ,
189
254
to_metric : Optional [Callable ] = None ,
190
- perturbation_per_pass : int = - 1 ,
255
+ perturbations_per_pass : int = 1 ,
191
256
show_progress : bool = False ,
192
257
return_input_shape : bool = True ,
193
258
) -> Union [Tensor , Tuple [Tensor , ...]]:
@@ -240,16 +305,17 @@ def attribute(
240
305
metric (Tensor): final result to be attributed, must be a Tensor
241
306
242
307
If None, will directly attribute w.r.t the reduced ``accum``
243
- perturbation_per_pass (int, optional
308
+ perturbations_per_pass (int, optional): the number of perturbations to execute
244
309
concurrently in each traverse of the dataloader. The number of
245
- traverses is ceil(n_perturbations / perturbation_per_pass).
246
- The parameter offers a control of the trade-off between memory
310
+ traverses needed is
311
+ ceil(n_perturbations / perturbations_per_pass).
312
+
313
+ This argument offers control of the trade-off between memory
247
314
and efficiency. If the dataloader involves slow operations like
248
315
remote request or file I/O, multiple traversals can be
249
- inefficient. Each perturbation needs to store its accumulated
250
- outputs of the reduce function until the end of the data
251
- traverse. If the value is -1, all perturbations are concurrent
252
- in a single traverse.
316
+ inefficient. On the other hand, each perturbation needs to
317
+ store its accumulated outputs of the reduce
318
+ function until the end of the data traverse.
253
319
return_input_shape (bool, optional): if True, returns the attribution
254
320
following the input shapes given by the dataloader.
255
321
Otherwise, returns a single tensor for the attributions of
@@ -352,14 +418,14 @@ def attribute(
352
418
# unique_attr in shape(*output_dims, n_features)
353
419
unique_attr = self .attr_method .attribute (
354
420
feature_indices ,
421
+ perturbations_per_eval = perturbations_per_pass ,
355
422
additional_forward_args = (
356
423
dataloader ,
357
424
input_roles ,
358
425
baselines ,
359
426
feature_mask ,
360
427
reduce ,
361
428
to_metric ,
362
- perturbation_per_pass ,
363
429
show_progress ,
364
430
feature_idx_to_mask_idx ,
365
431
),
0 commit comments