1
1
#!/usr/bin/env python3
2
2
3
- import math
4
3
from typing import Any , Callable , Generator , Tuple , Union
5
4
6
5
import torch
17
16
from captum .log import log_usage
18
17
19
18
20
- def combination (n : int , k : int ) -> int :
21
- try :
22
- # Combination only available in Python 3.8
23
- return math .comb (n , k ) # type: ignore
24
- except AttributeError :
25
- return math .factorial (n ) // math .factorial (k ) // math .factorial (n - k )
26
-
27
-
28
- def kernel_shap_similarity_kernel (
29
- _ , __ , interpretable_sample : Tensor , ** kwargs
30
- ) -> Tensor :
31
- assert (
32
- "num_interp_features" in kwargs
33
- ), "Must provide num_interp_features to use default similarity kernel"
34
- num_selected_features = int (interpretable_sample .sum (dim = 1 ).item ())
35
- num_features = kwargs ["num_interp_features" ]
36
- if num_selected_features == 0 or num_selected_features == num_features :
37
- # weight should be theoretically infinite when denom = 0
38
- # enforcing that trained linear model must satisfy
39
- # end-point criteria. In practice, it is sufficient to
40
- # make this weight substantially larger so setting this
41
- # weight to 1000000 (all other weights are 1).
42
- similarities = 1000000.0
43
- else :
44
- similarities = 1.0
45
- return torch .tensor ([similarities ])
46
-
47
-
48
- def kernel_shap_perturb_generator (
49
- original_inp , ** kwargs
50
- ) -> Generator [Tensor , None , None ]:
51
- assert "num_select_distribution" in kwargs and "num_interp_features" in kwargs , (
52
- "num_select_distribution and num_interp_features are necessary"
53
- " to use kernel_shap_perturb_func"
54
- )
55
- if isinstance (original_inp , Tensor ):
56
- device = original_inp .device
57
- else :
58
- device = original_inp [0 ].device
59
- num_features = kwargs ["num_interp_features" ]
60
- yield torch .ones (1 , num_features , device = device , dtype = torch .long )
61
- yield torch .zeros (1 , num_features , device = device , dtype = torch .long )
62
- while True :
63
- num_selected_features = kwargs ["num_select_distribution" ].sample ()
64
- rand_vals = torch .randn (1 , num_features )
65
- threshold = torch .kthvalue (
66
- rand_vals , num_features - num_selected_features
67
- ).values .item ()
68
- yield (rand_vals > threshold ).to (device = device ).long ()
69
-
70
-
71
19
class KernelShap (Lime ):
72
20
r"""
73
21
Kernel SHAP is a method that uses the LIME framework to compute
@@ -92,9 +40,10 @@ def __init__(self, forward_func: Callable) -> None:
92
40
self ,
93
41
forward_func ,
94
42
interpretable_model = SkLearnLinearRegression (),
95
- similarity_func = kernel_shap_similarity_kernel ,
96
- perturb_func = kernel_shap_perturb_generator ,
43
+ similarity_func = self . kernel_shap_similarity_kernel ,
44
+ perturb_func = self . kernel_shap_perturb_generator ,
97
45
)
46
+ self .inf_weight = 1000000.0
98
47
99
48
@log_usage ()
100
49
@lime_n_perturb_samples_deprecation_decorator
@@ -337,3 +286,46 @@ def attribute( # type: ignore
337
286
return_input_shape = return_input_shape ,
338
287
num_select_distribution = Categorical (probs ),
339
288
)
289
+
290
+ def kernel_shap_similarity_kernel (
291
+ self , _ , __ , interpretable_sample : Tensor , ** kwargs
292
+ ) -> Tensor :
293
+ assert (
294
+ "num_interp_features" in kwargs
295
+ ), "Must provide num_interp_features to use default similarity kernel"
296
+ num_selected_features = int (interpretable_sample .sum (dim = 1 ).item ())
297
+ num_features = kwargs ["num_interp_features" ]
298
+ if num_selected_features == 0 or num_selected_features == num_features :
299
+ # weight should be theoretically infinite when denom = 0
300
+ # enforcing that trained linear model must satisfy
301
+ # end-point criteria. In practice, it is sufficient to
302
+ # make this weight substantially larger so setting this
303
+ # weight to 1000000 (all other weights are 1).
304
+ similarities = self .inf_weight
305
+ else :
306
+ similarities = 1.0
307
+ return torch .tensor ([similarities ])
308
+
309
+ def kernel_shap_perturb_generator (
310
+ self , original_inp : Union [Tensor , Tuple [Tensor , ...]], ** kwargs
311
+ ) -> Generator [Tensor , None , None ]:
312
+ assert (
313
+ "num_select_distribution" in kwargs and "num_interp_features" in kwargs
314
+ ), (
315
+ "num_select_distribution and num_interp_features are necessary"
316
+ " to use kernel_shap_perturb_func"
317
+ )
318
+ if isinstance (original_inp , Tensor ):
319
+ device = original_inp .device
320
+ else :
321
+ device = original_inp [0 ].device
322
+ num_features = kwargs ["num_interp_features" ]
323
+ yield torch .ones (1 , num_features , device = device , dtype = torch .long )
324
+ yield torch .zeros (1 , num_features , device = device , dtype = torch .long )
325
+ while True :
326
+ num_selected_features = kwargs ["num_select_distribution" ].sample ()
327
+ rand_vals = torch .randn (1 , num_features )
328
+ threshold = torch .kthvalue (
329
+ rand_vals , num_features - num_selected_features
330
+ ).values .item ()
331
+ yield (rand_vals > threshold ).to (device = device ).long ()
0 commit comments