1
1
#!/usr/bin/env python3
2
+ import inspect
2
3
import math
3
4
import typing
4
5
import warnings
@@ -138,13 +139,18 @@ def __init__(
138
139
the original input space (matching type and tensor shapes
139
140
of original input) or in the interpretable input space,
140
141
which is a vector containing the interpretable features.
142
+ Alternatively, this function can return a generator
143
+ yielding samples to train the interpretable surrogate
144
+ model, and n_samples perturbations will be sampled
145
+ from this generator.
141
146
142
147
The expected signature of this callable is:
143
148
144
149
>>> perturb_func(
145
150
>>> original_input: Tensor or tuple of Tensors,
146
151
>>> **kwargs: Any
147
- >>> ) -> Tensor or tuple of Tensors
152
+ >>> ) -> Tensor or tuple of Tensors or
153
+ >>> generator yielding Tensor or tuple of Tensors
148
154
149
155
All kwargs passed to the attribute method are
150
156
provided as keyword arguments (kwargs) to this callable.
@@ -411,8 +417,22 @@ def attribute(
411
417
curr_model_inputs = []
412
418
expanded_additional_args = None
413
419
expanded_target = None
420
+ perturb_generator = None
421
+ if inspect .isgeneratorfunction (self .perturb_func ):
422
+ perturb_generator = self .perturb_func (inputs , ** kwargs )
423
+ batch_count = 0
414
424
for _ in range (n_samples ):
415
- curr_sample = self .perturb_func (inputs , ** kwargs )
425
+ if perturb_generator :
426
+ try :
427
+ curr_sample = next (perturb_generator )
428
+ except StopIteration :
429
+ warnings .warn (
430
+ "Generator completed prior to given n_samples iterations!"
431
+ )
432
+ break
433
+ else :
434
+ curr_sample = self .perturb_func (inputs , ** kwargs )
435
+ batch_count += 1
416
436
if self .perturb_interpretable_space :
417
437
interpretable_inps .append (curr_sample )
418
438
curr_model_inputs .append (
@@ -481,7 +501,7 @@ def attribute(
481
501
dataset = TensorDataset (
482
502
combined_interp_inps , combined_outputs , combined_sim
483
503
)
484
- self .interpretable_model .fit (DataLoader (dataset , batch_size = n_samples ))
504
+ self .interpretable_model .fit (DataLoader (dataset , batch_size = batch_count ))
485
505
return self .interpretable_model .representation ()
486
506
487
507
def _evaluate_batch (
@@ -602,6 +622,31 @@ def default_perturb_func(original_inp, **kwargs):
602
622
return torch .bernoulli (probs ).to (device = device ).long ()
603
623
604
624
625
def construct_feature_mask(feature_mask, formatted_inputs):
    """
    Normalize a user-supplied (or absent) feature mask and count its
    interpretable features.

    If ``feature_mask`` is None, a default per-element mask is built from
    ``formatted_inputs``. Otherwise the mask is formatted into a tuple of
    tensors and, when its minimum index is non-zero, every index is shifted
    down so that indices start at 0 (a warning is emitted in that case).

    Returns:
        Tuple of (feature_mask as a tuple of Tensors,
        num_interp_features as an int — max index + 1).
    """
    # No mask given: delegate entirely to the default constructor,
    # which also reports the feature count.
    if feature_mask is None:
        return _construct_default_feature_mask(formatted_inputs)

    feature_mask = _format_input(feature_mask)
    # Smallest index across all mask tensors; must be 0 for the
    # downstream one-hot / interpretable-space logic to be valid.
    smallest_idx = int(min(torch.min(mask).item() for mask in feature_mask))
    if smallest_idx != 0:
        warnings.warn(
            "Minimum element in feature mask is not 0, shifting indices to"
            " start at 0."
        )
        feature_mask = tuple(mask - smallest_idx for mask in feature_mask)

    # Feature indices are assumed contiguous from 0, so count = max + 1.
    num_interp_features = int(
        max(torch.max(mask).item() for mask in feature_mask) + 1
    )
    return feature_mask, num_interp_features
648
+
649
+
605
650
class Lime (LimeBase ):
606
651
r"""
607
652
Lime is an interpretability method that trains an interpretable surrogate model
@@ -713,7 +758,7 @@ def __init__(
713
758
(integer, determined from feature mask).
714
759
perturb_func (optional, callable): Function which returns a single
715
760
sampled input, which is a binary vector of length
716
- num_interp_features.
761
+ num_interp_features, or a generator of such tensors.
717
762
718
763
This function is optional, the default function returns
719
764
a binary vector where each element is selected
@@ -726,6 +771,7 @@ def __init__(
726
771
>>> original_input: Tensor or tuple of Tensors,
727
772
>>> **kwargs: Any
728
773
>>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
774
+ >>> or generator yielding such tensors
729
775
730
776
kwargs includes baselines, feature_mask, num_interp_features
731
777
(integer, determined from feature mask).
@@ -975,31 +1021,36 @@ def attribute( # type: ignore
975
1021
>>> # matching input shape.
976
1022
>>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
977
1023
"""
1024
+ return self ._attribute_kwargs (
1025
+ inputs = inputs ,
1026
+ baselines = baselines ,
1027
+ target = target ,
1028
+ additional_forward_args = additional_forward_args ,
1029
+ feature_mask = feature_mask ,
1030
+ n_samples = n_samples ,
1031
+ perturbations_per_eval = perturbations_per_eval ,
1032
+ return_input_shape = return_input_shape ,
1033
+ )
1034
+
1035
+ def _attribute_kwargs ( # type: ignore
1036
+ self ,
1037
+ inputs : TensorOrTupleOfTensorsGeneric ,
1038
+ baselines : BaselineType = None ,
1039
+ target : TargetType = None ,
1040
+ additional_forward_args : Any = None ,
1041
+ feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
1042
+ n_samples : int = 25 ,
1043
+ perturbations_per_eval : int = 1 ,
1044
+ return_input_shape : bool = True ,
1045
+ ** kwargs
1046
+ ) -> TensorOrTupleOfTensorsGeneric :
978
1047
is_inputs_tuple = _is_tuple (inputs )
979
1048
formatted_inputs , baselines = _format_input_baseline (inputs , baselines )
980
1049
bsz = formatted_inputs [0 ].shape [0 ]
981
1050
982
- if feature_mask is None :
983
- feature_mask , num_interp_features = _construct_default_feature_mask (
984
- formatted_inputs
985
- )
986
- else :
987
- feature_mask = _format_input (feature_mask )
988
- min_interp_features = int (
989
- min (torch .min (single_inp ).item () for single_inp in feature_mask )
990
- )
991
- if min_interp_features != 0 :
992
- warnings .warn (
993
- "Minimum element in feature mask is not 0, shifting indices to"
994
- " start at 0."
995
- )
996
- feature_mask = tuple (
997
- single_inp + min_interp_features for single_inp in feature_mask
998
- )
999
-
1000
- num_interp_features = int (
1001
- max (torch .max (single_inp ).item () for single_inp in feature_mask ) + 1
1002
- )
1051
+ feature_mask , num_interp_features = construct_feature_mask (
1052
+ feature_mask , formatted_inputs
1053
+ )
1003
1054
1004
1055
if num_interp_features > 10000 :
1005
1056
warnings .warn (
@@ -1051,6 +1102,7 @@ def attribute( # type: ignore
1051
1102
if is_inputs_tuple
1052
1103
else curr_feature_mask [0 ],
1053
1104
num_interp_features = num_interp_features ,
1105
+ ** kwargs
1054
1106
)
1055
1107
if return_input_shape :
1056
1108
output_list .append (
@@ -1087,6 +1139,7 @@ def attribute( # type: ignore
1087
1139
baselines = baselines if is_inputs_tuple else baselines [0 ],
1088
1140
feature_mask = feature_mask if is_inputs_tuple else feature_mask [0 ],
1089
1141
num_interp_features = num_interp_features ,
1142
+ ** kwargs
1090
1143
)
1091
1144
if return_input_shape :
1092
1145
return self ._convert_output_shape (
0 commit comments