Skip to content

Commit 6b71d66

Browse files
committed
Fixes
1 parent abde859 commit 6b71d66

File tree

2 files changed

+47
-55
lines changed

2 files changed

+47
-55
lines changed

captum/attr/_core/kernel_shap.py

+46-54
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python3
22

3-
import math
43
from typing import Any, Callable, Generator, Tuple, Union
54

65
import torch
@@ -17,57 +16,6 @@
1716
from captum.log import log_usage
1817

1918

20-
def combination(n: int, k: int) -> int:
21-
try:
22-
# Combination only available in Python 3.8
23-
return math.comb(n, k) # type: ignore
24-
except AttributeError:
25-
return math.factorial(n) // math.factorial(k) // math.factorial(n - k)
26-
27-
28-
def kernel_shap_similarity_kernel(
29-
_, __, interpretable_sample: Tensor, **kwargs
30-
) -> Tensor:
31-
assert (
32-
"num_interp_features" in kwargs
33-
), "Must provide num_interp_features to use default similarity kernel"
34-
num_selected_features = int(interpretable_sample.sum(dim=1).item())
35-
num_features = kwargs["num_interp_features"]
36-
if num_selected_features == 0 or num_selected_features == num_features:
37-
# weight should be theoretically infinite when denom = 0
38-
# enforcing that trained linear model must satisfy
39-
# end-point criteria. In practice, it is sufficient to
40-
# make this weight substantially larger so setting this
41-
# weight to 1000000 (all other weights are 1).
42-
similarities = 1000000.0
43-
else:
44-
similarities = 1.0
45-
return torch.tensor([similarities])
46-
47-
48-
def kernel_shap_perturb_generator(
49-
original_inp, **kwargs
50-
) -> Generator[Tensor, None, None]:
51-
assert "num_select_distribution" in kwargs and "num_interp_features" in kwargs, (
52-
"num_select_distribution and num_interp_features are necessary"
53-
" to use kernel_shap_perturb_func"
54-
)
55-
if isinstance(original_inp, Tensor):
56-
device = original_inp.device
57-
else:
58-
device = original_inp[0].device
59-
num_features = kwargs["num_interp_features"]
60-
yield torch.ones(1, num_features, device=device, dtype=torch.long)
61-
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
62-
while True:
63-
num_selected_features = kwargs["num_select_distribution"].sample()
64-
rand_vals = torch.randn(1, num_features)
65-
threshold = torch.kthvalue(
66-
rand_vals, num_features - num_selected_features
67-
).values.item()
68-
yield (rand_vals > threshold).to(device=device).long()
69-
70-
7119
class KernelShap(Lime):
7220
r"""
7321
Kernel SHAP is a method that uses the LIME framework to compute
@@ -92,9 +40,10 @@ def __init__(self, forward_func: Callable) -> None:
9240
self,
9341
forward_func,
9442
interpretable_model=SkLearnLinearRegression(),
95-
similarity_func=kernel_shap_similarity_kernel,
96-
perturb_func=kernel_shap_perturb_generator,
43+
similarity_func=self.kernel_shap_similarity_kernel,
44+
perturb_func=self.kernel_shap_perturb_generator,
9745
)
46+
self.inf_weight = 1000000.0
9847

9948
@log_usage()
10049
@lime_n_perturb_samples_deprecation_decorator
@@ -337,3 +286,46 @@ def attribute( # type: ignore
337286
return_input_shape=return_input_shape,
338287
num_select_distribution=Categorical(probs),
339288
)
289+
290+
def kernel_shap_similarity_kernel(
291+
self, _, __, interpretable_sample: Tensor, **kwargs
292+
) -> Tensor:
293+
assert (
294+
"num_interp_features" in kwargs
295+
), "Must provide num_interp_features to use default similarity kernel"
296+
num_selected_features = int(interpretable_sample.sum(dim=1).item())
297+
num_features = kwargs["num_interp_features"]
298+
if num_selected_features == 0 or num_selected_features == num_features:
299+
# weight should be theoretically infinite when denom = 0
300+
# enforcing that trained linear model must satisfy
301+
# end-point criteria. In practice, it is sufficient to
302+
# make this weight substantially larger so setting this
303+
# weight to 1000000 (all other weights are 1).
304+
similarities = self.inf_weight
305+
else:
306+
similarities = 1.0
307+
return torch.tensor([similarities])
308+
309+
def kernel_shap_perturb_generator(
310+
self, original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs
311+
) -> Generator[Tensor, None, None]:
312+
assert (
313+
"num_select_distribution" in kwargs and "num_interp_features" in kwargs
314+
), (
315+
"num_select_distribution and num_interp_features are necessary"
316+
" to use kernel_shap_perturb_func"
317+
)
318+
if isinstance(original_inp, Tensor):
319+
device = original_inp.device
320+
else:
321+
device = original_inp[0].device
322+
num_features = kwargs["num_interp_features"]
323+
yield torch.ones(1, num_features, device=device, dtype=torch.long)
324+
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
325+
while True:
326+
num_selected_features = kwargs["num_select_distribution"].sample()
327+
rand_vals = torch.randn(1, num_features)
328+
threshold = torch.kthvalue(
329+
rand_vals, num_features - num_selected_features
330+
).values.item()
331+
yield (rand_vals > threshold).to(device=device).long()

captum/attr/_core/lime.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
forward_func: Callable,
7474
interpretable_model: Model,
7575
similarity_func: Callable,
76-
perturb_func: Union[Callable],
76+
perturb_func: Callable,
7777
perturb_interpretable_space: bool,
7878
from_interp_rep_transform: Optional[Callable],
7979
to_interp_rep_transform: Optional[Callable],

0 commit comments

Comments
 (0)