
Commit c825327

vivekmig authored and facebook-github-bot committed
KernelSHAP / Lime Improvements (#619)
Summary:
* Adds support for generators as perturb functions for Lime, with corresponding tests.
* Modifies KernelSHAP to sample the number of selected features from a categorical distribution and then sample random binary vectors with that many selected features. This is theoretically equivalent to the previous approach of weighting randomly selected vectors, but it scales better computationally to large numbers of features, since the kernel weights for large numbers of features lead to arithmetic underflow (see the illustration below).

Pull Request resolved: #619

Reviewed By: NarineK

Differential Revision: D26505649

Pulled By: vivekmig

fbshipit-source-id: 596ca849208cadf3165d2c39c9eb7889f78e9b2d
1 parent 7d21f58 commit c825327
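
[Editorial note] As a quick illustration of the underflow motivation, not part of this diff (M and k below are arbitrary example values):

    import math

    M = 2000  # example: a large number of interpretable features
    k = 1000  # example: number of selected features
    # Old approach: weight each sample by the Shapley kernel
    #   k(M, k) = (M - 1) / (k * (M - k) * (M choose k)).
    # math.comb(M, k) has roughly 600 digits here, so the float weight
    # underflows to exactly 0 and the weighted regression degenerates.
    print((M - 1) / (k * (M - k) * math.comb(M, k)))  # 0.0

    # New approach: fold (M choose k) into the sampling itself by drawing k
    # with probability proportional to (M - 1) / (k * (M - k)) and then
    # sampling a random k-hot vector, so each drawn sample gets weight 1.
    probs = [(M - 1) / (k * (M - k)) for k in range(1, M)]
    print(min(probs) > 0.0)  # True: every probability stays representable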

File tree

4 files changed: +174 -67 lines

captum/attr/_core/kernel_shap.py (+79 -40)
@@ -1,50 +1,21 @@
 #!/usr/bin/env python3

-import math
-from typing import Any, Callable, Tuple, Union
+from typing import Any, Callable, Generator, Tuple, Union

 import torch
 from torch import Tensor
+from torch.distributions.categorical import Categorical

 from captum._utils.models.linear_model import SkLearnLinearRegression
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
-from captum.attr._core.lime import Lime
-from captum.attr._utils.common import lime_n_perturb_samples_deprecation_decorator
+from captum.attr._core.lime import Lime, construct_feature_mask
+from captum.attr._utils.common import (
+    _format_input_baseline,
+    lime_n_perturb_samples_deprecation_decorator,
+)
 from captum.log import log_usage


-def combination(n: int, k: int) -> int:
-    try:
-        # math.comb is only available in Python 3.8+
-        return math.comb(n, k)  # type: ignore
-    except AttributeError:
-        return math.factorial(n) // math.factorial(k) // math.factorial(n - k)
-
-
-def kernel_shap_similarity_kernel(
-    _, __, interpretable_sample: Tensor, **kwargs
-) -> Tensor:
-    assert (
-        "num_interp_features" in kwargs
-    ), "Must provide num_interp_features to use default similarity kernel"
-    num_selected_features = int(interpretable_sample.sum(dim=1).item())
-    num_features = kwargs["num_interp_features"]
-    combinations = combination(num_features, num_selected_features)
-    denom = (
-        combinations * num_selected_features * (num_features - num_selected_features)
-    )
-    if denom != 0:
-        similarities = (num_features - 1) / denom
-    else:
-        # The weight should theoretically be infinite when denom = 0,
-        # enforcing that the trained linear model satisfies the
-        # end-point criteria. In practice, it is sufficient to make
-        # this weight substantially larger, so it is set to 100
-        # (all other weights are < 1).
-        similarities = 100.0
-    return torch.tensor([similarities])
-
-
 class KernelShap(Lime):
     r"""
     Kernel SHAP is a method that uses the LIME framework to compute
@@ -68,9 +39,11 @@ def __init__(self, forward_func: Callable) -> None:
         Lime.__init__(
             self,
             forward_func,
-            SkLearnLinearRegression(),
-            kernel_shap_similarity_kernel,
+            interpretable_model=SkLearnLinearRegression(),
+            similarity_func=self.kernel_shap_similarity_kernel,
+            perturb_func=self.kernel_shap_perturb_generator,
         )
+        self.inf_weight = 1000000.0

     @log_usage()
     @lime_n_perturb_samples_deprecation_decorator
@@ -294,8 +267,15 @@ def attribute(  # type: ignore
         >>> # Computes KernelSHAP attributions with feature mask.
         >>> attr = ks.attribute(input, target=1, feature_mask=feature_mask)
         """
-        return Lime.attribute.__wrapped__(
-            self,
+        formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
+        feature_mask, num_interp_features = construct_feature_mask(
+            feature_mask, formatted_inputs
+        )
+        num_features_list = torch.arange(num_interp_features, dtype=torch.float)
+        denom = num_features_list * (num_interp_features - num_features_list)
+        probs = (num_interp_features - 1) / denom
+        probs[0] = 0.0
+        return self._attribute_kwargs(
             inputs=inputs,
             baselines=baselines,
             target=target,
@@ -304,4 +284,63 @@ def attribute(  # type: ignore
             n_samples=n_samples,
             perturbations_per_eval=perturbations_per_eval,
             return_input_shape=return_input_shape,
+            num_select_distribution=Categorical(probs),
+        )
+
+    def kernel_shap_similarity_kernel(
+        self, _, __, interpretable_sample: Tensor, **kwargs
+    ) -> Tensor:
+        assert (
+            "num_interp_features" in kwargs
+        ), "Must provide num_interp_features to use default similarity kernel"
+        num_selected_features = int(interpretable_sample.sum(dim=1).item())
+        num_features = kwargs["num_interp_features"]
+        if num_selected_features == 0 or num_selected_features == num_features:
+            # The weight should theoretically be infinite when
+            # num_selected_features = 0 or num_features, enforcing that
+            # the trained linear model satisfies the end-point criteria.
+            # In practice, it is sufficient to make this weight
+            # substantially larger, so it is set to 1000000
+            # (all other weights are 1).
+            similarities = self.inf_weight
+        else:
+            similarities = 1.0
+        return torch.tensor([similarities])
+
+    def kernel_shap_perturb_generator(
+        self, original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs
+    ) -> Generator[Tensor, None, None]:
+        r"""
+        Perturbations are sampled by the following process:
+        - Choose k (number of selected features) based on the distribution
+          p(k) = (M - 1) / (k * (M - k)),
+          where M is the total number of features in the interpretable space.
+        - Randomly select a binary vector with k ones, where each such vector
+          is equally likely. This is done by generating a random vector of
+          normal values and thresholding based on the top k elements.
+
+        Since there are (M choose k) vectors with k ones, this weighted
+        sampling is equivalent to applying the Shapley kernel as the sample
+        weight, defined as:
+        k(M, k) = (M - 1) / (k * (M - k) * (M choose k))
+        """
+        assert (
+            "num_select_distribution" in kwargs and "num_interp_features" in kwargs
+        ), (
+            "num_select_distribution and num_interp_features are necessary"
+            " to use kernel_shap_perturb_func"
        )
+        if isinstance(original_inp, Tensor):
+            device = original_inp.device
+        else:
+            device = original_inp[0].device
+        num_features = kwargs["num_interp_features"]
+        yield torch.ones(1, num_features, device=device, dtype=torch.long)
+        yield torch.zeros(1, num_features, device=device, dtype=torch.long)
+        while True:
+            num_selected_features = kwargs["num_select_distribution"].sample()
+            rand_vals = torch.randn(1, num_features)
+            threshold = torch.kthvalue(
+                rand_vals, num_features - num_selected_features
+            ).values.item()
+            yield (rand_vals > threshold).to(device=device).long()
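
[Editorial note] The sampling process in kernel_shap_perturb_generator can be sketched in isolation as follows (a minimal sketch based on the diff above; M is an arbitrary example size and the names are local to this illustration):

    import torch
    from torch.distributions.categorical import Categorical

    M = 10  # example number of interpretable features
    ks = torch.arange(M, dtype=torch.float)
    probs = (M - 1) / (ks * (M - ks))  # proportional to (M - 1) / (k * (M - k))
    probs[0] = 0.0                     # k = 0 is covered by the fixed all-zeros sample
    dist = Categorical(probs)          # Categorical normalizes the unnormalized weights

    k = int(dist.sample())             # draw the number of selected features
    rand_vals = torch.randn(1, M)
    # keep the k largest normal values: threshold at the (M - k)-th smallest entry
    threshold = torch.kthvalue(rand_vals, M - k).values.item()
    sample = (rand_vals > threshold).long()
    assert int(sample.sum()) == k      # a k-hot vector, uniform over all k-hot vectors

Since every k-hot vector is equally likely given k, drawing k this way and giving each sample a constant regression weight reproduces the Shapley kernel k(M, k) = (M - 1) / (k * (M - k) * (M choose k)) in expectation, which is why the similarity kernel above can return constant weights apart from the end-point samples.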

captum/attr/_core/lime.py (+78 -25)
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import inspect
 import math
 import typing
 import warnings
@@ -138,13 +139,18 @@ def __init__(
                     the original input space (matching type and tensor shapes
                     of original input) or in the interpretable input space,
                     which is a vector containing the interpretable features.
+                    Alternatively, this function can return a generator
+                    yielding samples to train the interpretable surrogate
+                    model, and n_samples perturbations will be sampled
+                    from this generator.

                     The expected signature of this callable is:

                     >>> perturb_func(
                     >>>    original_input: Tensor or tuple of Tensors,
                     >>>    **kwargs: Any
-                    >>> ) -> Tensor or tuple of Tensors
+                    >>> ) -> Tensor or tuple of Tensors or
+                    >>>      generator yielding tensor or tuple of Tensors

                     All kwargs passed to the attribute method are
                     provided as keyword arguments (kwargs) to this callable.
@@ -411,8 +417,22 @@ def attribute(
         curr_model_inputs = []
         expanded_additional_args = None
         expanded_target = None
+        perturb_generator = None
+        if inspect.isgeneratorfunction(self.perturb_func):
+            perturb_generator = self.perturb_func(inputs, **kwargs)
+        batch_count = 0
         for _ in range(n_samples):
-            curr_sample = self.perturb_func(inputs, **kwargs)
+            if perturb_generator:
+                try:
+                    curr_sample = next(perturb_generator)
+                except StopIteration:
+                    warnings.warn(
+                        "Generator completed prior to given n_samples iterations!"
+                    )
+                    break
+            else:
+                curr_sample = self.perturb_func(inputs, **kwargs)
+            batch_count += 1
             if self.perturb_interpretable_space:
                 interpretable_inps.append(curr_sample)
                 curr_model_inputs.append(
@@ -481,7 +501,7 @@
             dataset = TensorDataset(
                 combined_interp_inps, combined_outputs, combined_sim
             )
-            self.interpretable_model.fit(DataLoader(dataset, batch_size=n_samples))
+            self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
         return self.interpretable_model.representation()

     def _evaluate_batch(
@@ -602,6 +622,31 @@ def default_perturb_func(original_inp, **kwargs):
     return torch.bernoulli(probs).to(device=device).long()


+def construct_feature_mask(feature_mask, formatted_inputs):
+    if feature_mask is None:
+        feature_mask, num_interp_features = _construct_default_feature_mask(
+            formatted_inputs
+        )
+    else:
+        feature_mask = _format_input(feature_mask)
+        min_interp_features = int(
+            min(torch.min(single_inp).item() for single_inp in feature_mask)
+        )
+        if min_interp_features != 0:
+            warnings.warn(
+                "Minimum element in feature mask is not 0, shifting indices to"
+                " start at 0."
+            )
+            feature_mask = tuple(
+                single_inp - min_interp_features for single_inp in feature_mask
+            )
+
+        num_interp_features = int(
+            max(torch.max(single_inp).item() for single_inp in feature_mask) + 1
+        )
+    return feature_mask, num_interp_features
+
+
 class Lime(LimeBase):
     r"""
     Lime is an interpretability method that trains an interpretable surrogate model
@@ -713,7 +758,7 @@ def __init__(
                     (integer, determined from feature mask).
         perturb_func (optional, callable): Function which returns a single
                     sampled input, which is a binary vector of length
-                    num_interp_features.
+                    num_interp_features, or a generator of such tensors.

                     This function is optional, the default function returns
                     a binary vector where each element is selected
@@ -726,6 +771,7 @@ def __init__(
                     >>>    original_input: Tensor or tuple of Tensors,
                     >>>    **kwargs: Any
                     >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
+                    >>>      or generator yielding such tensors

                     kwargs includes baselines, feature_mask, num_interp_features
                     (integer, determined from feature mask).
@@ -975,31 +1021,36 @@ def attribute(  # type: ignore
         >>> # matching input shape.
         >>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
         """
+        return self._attribute_kwargs(
+            inputs=inputs,
+            baselines=baselines,
+            target=target,
+            additional_forward_args=additional_forward_args,
+            feature_mask=feature_mask,
+            n_samples=n_samples,
+            perturbations_per_eval=perturbations_per_eval,
+            return_input_shape=return_input_shape,
+        )
+
+    def _attribute_kwargs(  # type: ignore
+        self,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = None,
+        target: TargetType = None,
+        additional_forward_args: Any = None,
+        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
+        n_samples: int = 25,
+        perturbations_per_eval: int = 1,
+        return_input_shape: bool = True,
+        **kwargs
+    ) -> TensorOrTupleOfTensorsGeneric:
         is_inputs_tuple = _is_tuple(inputs)
         formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
         bsz = formatted_inputs[0].shape[0]

-        if feature_mask is None:
-            feature_mask, num_interp_features = _construct_default_feature_mask(
-                formatted_inputs
-            )
-        else:
-            feature_mask = _format_input(feature_mask)
-            min_interp_features = int(
-                min(torch.min(single_inp).item() for single_inp in feature_mask)
-            )
-            if min_interp_features != 0:
-                warnings.warn(
-                    "Minimum element in feature mask is not 0, shifting indices to"
-                    " start at 0."
-                )
-                feature_mask = tuple(
-                    single_inp + min_interp_features for single_inp in feature_mask
-                )
-
-            num_interp_features = int(
-                max(torch.max(single_inp).item() for single_inp in feature_mask) + 1
-            )
+        feature_mask, num_interp_features = construct_feature_mask(
+            feature_mask, formatted_inputs
+        )

         if num_interp_features > 10000:
             warnings.warn(
@@ -1051,6 +1102,7 @@ def attribute(  # type: ignore
                     if is_inputs_tuple
                     else curr_feature_mask[0],
                     num_interp_features=num_interp_features,
+                    **kwargs
                 )
                 if return_input_shape:
                     output_list.append(
@@ -1087,6 +1139,7 @@ def attribute(  # type: ignore
                 baselines=baselines if is_inputs_tuple else baselines[0],
                 feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
                 num_interp_features=num_interp_features,
+                **kwargs
             )
             if return_input_shape:
                 return self._convert_output_shape(
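
[Editorial note] To exercise the new generator support directly, a perturb_func can itself be written as a generator function; LimeBase detects this via inspect.isgeneratorfunction and pulls n_samples values from it. A minimal sketch, assuming a Captum build containing this commit (the toy network and generator below are illustrative only):

    import torch
    import torch.nn as nn
    from captum.attr import Lime

    def perturb_generator(original_inp, **kwargs):
        # yields binary 1 x num_interp_features samples indefinitely; Lime draws
        # n_samples of them via next() and warns if the generator stops early
        num_feats = kwargs["num_interp_features"]
        while True:
            yield torch.bernoulli(0.5 * torch.ones(1, num_feats)).long()

    net = nn.Sequential(nn.Linear(4, 2))  # illustrative model
    lime = Lime(net, perturb_func=perturb_generator)
    attr = lime.attribute(torch.rand(1, 4), target=0, n_samples=50)
    print(attr.shape)  # torch.Size([1, 4])

Because the generator above never terminates, the StopIteration branch is never hit; a finite generator would instead trigger the "Generator completed prior to given n_samples iterations!" warning and fit the surrogate model on the smaller batch_count.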

tests/attr/test_kernel_shap.py (+1 -0)
@@ -197,6 +197,7 @@ def test_multi_input_batch_kernel_shap(self) -> None:
             expected,
             additional_input=(1,),
             feature_mask=(mask1, mask2, mask3),
+            n_perturb_samples=300,
         )
         expected_with_baseline = (
             [[1040, 1040, 1040], [184, 580.0, 184]],
