From ad4212ca6c04fb421a3d0e8b4b78355be40dfcaf Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 23 Dec 2020 10:10:57 -0800 Subject: [PATCH] Lime Type Fixes Summary: This makes Lime work appropriately with int / long features; currently input only worked appropriately with float features. Differential Revision: D25693888 fbshipit-source-id: d1da7fccb405c4b86780d3f8a1b2835568f38d21 --- captum/attr/_core/lime.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index e28964fbd..5292f5e69 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -527,15 +527,15 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs): ), "Must provide baselines to use default interpretable representation transfrom" feature_mask = kwargs["feature_mask"] if isinstance(feature_mask, Tensor): - binary_mask = curr_sample[0][feature_mask] + binary_mask = curr_sample[0][feature_mask].to(original_inputs.dtype) return binary_mask * original_inputs + (1 - binary_mask) * kwargs["baselines"] else: binary_mask = tuple( curr_sample[0][feature_mask[j]] for j in range(len(feature_mask)) ) return tuple( - binary_mask[j] * original_inputs[j] - + (1 - binary_mask[j]) * kwargs["baselines"][j] + binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j] + + (1 - binary_mask[j].to(original_inputs[j].dtype)) * kwargs["baselines"][j] for j in range(len(feature_mask)) ) @@ -575,8 +575,8 @@ def get_exp_kernel_similarity_function( """ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): - flattened_original_inp = _flatten_tensor_or_tuple(original_inp) - flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp) + flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float() + flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float() if distance_mode == "cosine": cos_sim = CosineSimilarity(dim=0) distance = 1 - cos_sim(flattened_original_inp, flattened_perturbed_inp) @@ -599,7 +599,7 @@ def default_perturb_func(original_inp, **kwargs): device = original_inp[0].device probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5 - return torch.bernoulli(probs).to(device=device) + return torch.bernoulli(probs).to(device=device).long() class Lime(LimeBase): @@ -1130,7 +1130,10 @@ def _convert_output_shape( is_inputs_tuple: bool, ) -> Union[Tensor, Tuple[Tensor, ...]]: coefs = coefs.flatten() - attr = [torch.zeros_like(single_inp) for single_inp in formatted_inp] + attr = [ + torch.zeros_like(single_inp, dtype=torch.float) + for single_inp in formatted_inp + ] for tensor_ind in range(len(formatted_inp)): for single_feature in range(num_interp_features): attr[tensor_ind] += (