1
1
#!/usr/bin/env python3
2
2
from collections import defaultdict
3
3
from copy import copy
4
- from typing import Callable , Dict , List , Optional , Tuple , Union
4
+ from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Union
5
5
6
6
import torch
7
7
from captum ._utils .common import (
@@ -32,11 +32,78 @@ def _concat_tensors(accum, cur_output, _):
32
32
return cur_output if accum is None else torch .cat ([accum , cur_output ])
33
33
34
34
35
+ def _create_perturbation_mask (
36
+ perturbed_feature_indices : Tensor , # 1D tensor of one-hot feature indices
37
+ feature_mask : Tuple [Tensor , ...],
38
+ feature_idx_to_mask_idx : Dict [int , List [int ]],
39
+ ) -> Tuple [Union [Tensor , None ], ...]:
40
+ """
41
+ Create binary mask for inputs based on perturbed one-hot feature indices
42
+ Use None if no perturbation is needed for the corresponding input
43
+ """
44
+
45
+ # a set of input/mask indices that need perturbation
46
+ perturbation_mask_indices = set ()
47
+ for i , v in enumerate (perturbed_feature_indices .tolist ()):
48
+ # value 0 means the feature has been perturbed
49
+ if not v :
50
+ perturbation_mask_indices |= set (feature_idx_to_mask_idx [i ])
51
+
52
+ # create binary mask for inputs & set it to None if no perturbation is needed
53
+ perturbation_mask = tuple (
54
+ perturbed_feature_indices [mask_elem ] if i in perturbation_mask_indices else None
55
+ for i , mask_elem in enumerate (feature_mask )
56
+ )
57
+
58
+ return perturbation_mask
59
+
60
+
61
def _perturb_inputs(
    inputs: Iterable[Any],
    input_roles: Tuple[int],
    baselines: Tuple[Union[int, float, Tensor], ...],
    perturbation_mask: Tuple[Union[Tensor, None], ...],
) -> Tuple[Any, ...]:
    """
    Apply the perturbation mask and baselines to the inputs, returning a
    tuple of (possibly) perturbed inputs in the original order.
    """

    result = []
    # index into the attribution-only tuples (baselines, perturbation_mask)
    attr_idx = 0

    for inp, role in zip(inputs, input_roles):
        if role != InputRole.need_attr:
            # inputs that don't need attribution pass through unchanged
            result.append(inp)
            continue

        mask = perturbation_mask[attr_idx]
        if mask is None:
            # a None mask means this input needs no perturbation
            result.append(inp)
        else:
            # keep the input where mask is 1, use the baseline where it is 0
            baseline = baselines[attr_idx]
            result.append(inp * mask + baseline * (1 - mask))

        attr_idx += 1

    return tuple(result)
95
+
96
+
35
97
def _convert_output_shape (
36
98
unique_attr : Tensor ,
37
99
attr_inputs : Tuple [Tensor , ...],
38
100
feature_mask : Tuple [Tensor , ...],
39
101
) -> Tuple [Tensor , ...]:
102
+ """
103
+ Convert the shape of a single tensor of unique feature attribution
104
+ to match the shape of the inputs returned by dataloader
105
+ """
106
+
40
107
# unique_attr in shape(*output_dims, n_features)
41
108
output_dims = unique_attr .shape [:- 1 ]
42
109
n_features = unique_attr .shape [- 1 ]
@@ -107,77 +174,73 @@ def __init__(self, attr_method: Attribution) -> None:
107
174
108
175
def _forward_with_dataloader(
    self,
    batched_perturbed_feature_indices: Tensor,
    dataloader: torch.utils.data.DataLoader,
    input_roles: Tuple[int],
    baselines: Tuple[Union[int, float, Tensor], ...],
    feature_mask: Tuple[Tensor, ...],
    reduce: Callable,
    to_metric: Optional[Callable],
    show_progress: bool,
    feature_idx_to_mask_idx: Dict[int, List[int]],
):
    """
    Wrapper of the original given forward_func to be used in the attribution
    method. It traverses the dataloader once per call, applying every
    perturbation in ``batched_perturbed_feature_indices`` to each batch.

    Args:
        batched_perturbed_feature_indices: one-hot tensor of shape
            (n_perturb, n_features); n_perturb may be smaller than
            perturbations_per_pass on the last pass.
        dataloader: dataloader yielding the batches to be perturbed.
        input_roles: per-input roles controlling attribution/forwarding.
        baselines: baseline values used in place of perturbed features.
        feature_mask: per-input tensors mapping elements to feature indices.
        reduce: callable accumulating forward outputs across batches.
        to_metric: optional callable converting an accumulated state to a
            metric Tensor.
        show_progress: whether to show progress (currently unused here).
        feature_idx_to_mask_idx: maps a feature index to the input/mask
            indices that contain it.

    Returns:
        Tensor stacking one accumulated (or metric) result per perturbation
        along dim 0.
    """

    # one perturbation mask per row of the one-hot perturbation batch
    perturbation_mask_list: List[Tuple[Union[Tensor, None], ...]] = [
        _create_perturbation_mask(
            perturbed_feature_indices,
            feature_mask,
            feature_idx_to_mask_idx,
        )
        for perturbed_feature_indices in batched_perturbed_feature_indices
    ]

    # each perturbation needs its own accumulated state
    accum_states = [None for _ in range(len(perturbation_mask_list))]

    # traverse the dataloader once
    for inputs in dataloader:
        # for each batch read from the dataloader,
        # apply every perturbation based on perturbations_per_pass
        for i, perturbation_mask in enumerate(perturbation_mask_list):
            perturbed_inputs = _perturb_inputs(
                inputs, input_roles, baselines, perturbation_mask
            )

            # due to explicitly defined roles
            # we can keep inputs in their original order
            # regardless of if they need attr
            # instead of using additional_forward_inputs
            forward_inputs = tuple(
                _
                for _, role in zip(perturbed_inputs, input_roles)
                if role != InputRole.no_forward
            )

            output = _run_forward(
                self.forward_func,
                forward_inputs,
            )

            accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)

    accum_results = [
        to_metric(accum) if to_metric else accum for accum in accum_states
    ]

    # fixed message: previously "Tensor," and "received:" concatenated
    # without a separating space
    assert all(type(r) is Tensor for r in accum_results), (
        "Accumulated metrics for attribution must be a Tensor, "
        f"received: {next(r for r in accum_results if type(r) is not Tensor)}"
    )

    # shape(n_perturb * output_dims[0], *output_dims[1:])
    # the underlying attr method needs to support growth of the
    # forward_func output's 1st dim with perturbations_per_eval
    batched_accum = torch.stack(accum_results, dim=0)
    return batched_accum
181
244
182
245
def attribute (
183
246
self ,
@@ -187,7 +250,7 @@ def attribute(
187
250
feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
188
251
reduce : Optional [Callable ] = None ,
189
252
to_metric : Optional [Callable ] = None ,
190
- perturbation_per_pass : int = - 1 ,
253
+ perturbations_per_pass : int = 1 ,
191
254
show_progress : bool = False ,
192
255
return_input_shape : bool = True ,
193
256
) -> Union [Tensor , Tuple [Tensor , ...]]:
@@ -240,16 +303,17 @@ def attribute(
240
303
metric (Tensor): final result to be attributed, must be a Tensor
241
304
242
305
If None, will directly attribute w.r.t the reduced ``accum``
243
- perturbation_per_pass (int, optional
306
+ perturbations_per_pass (int, optional): the number of perturbations to execute
244
307
concurrently in each traverse of the dataloader. The number of
245
- traverses is ceil(n_perturbations / perturbation_per_pass).
246
- The parameter offers a control of the trade-off between memory
308
+ traverses needed is
309
+ ceil(n_perturbations / perturbations_per_pass).
310
+
311
+ This argument offers control of the trade-off between memory
247
312
and efficiency. If the dataloader involves slow operations like
248
313
remote request or file I/O, multiple traversals can be
249
- inefficient. Each perturbation needs to store its accumulated
250
- outputs of the reduce function until the end of the data
251
- traverse. If the value is -1, all perturbations are concurrent
252
- in a single traverse.
314
+ inefficient. On the other hand, each perturbation needs to
315
+ store its accumulated outputs of the reduce
316
+ function until the end of the data traverse.
253
317
return_input_shape (bool, optional): if True, returns the attribution
254
318
following the input shapes given by the dataloader.
255
319
Otherwise, returns a single tensor for the attributions of
@@ -352,14 +416,14 @@ def attribute(
352
416
# unique_attr in shape(*output_dims, n_features)
353
417
unique_attr = self .attr_method .attribute (
354
418
feature_indices ,
419
+ perturbations_per_eval = perturbations_per_pass ,
355
420
additional_forward_args = (
356
421
dataloader ,
357
422
input_roles ,
358
423
baselines ,
359
424
feature_mask ,
360
425
reduce ,
361
426
to_metric ,
362
- perturbation_per_pass ,
363
427
show_progress ,
364
428
feature_idx_to_mask_idx ,
365
429
),
0 commit comments