KernelSHAP / Lime Improvements #619
Conversation
vivekmig
commented
Feb 18, 2021
- Adds support for generators as the perturb function for Lime, with corresponding tests
- Modifies KernelSHAP to first sample the number of selected features from a categorical distribution and then randomly sample binary vectors with that many selected features. This is theoretically equivalent to the previous approach of weighting randomly selected vectors, but it scales better computationally with larger numbers of features, since the weights for large feature counts lead to arithmetic underflow.
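The two-stage sampling described above can be sketched as a generator (a simplified illustration with hypothetical names, not Captum's exact implementation):

```python
import torch

def kernel_shap_perturb_sketch(num_features: int):
    """Generator yielding binary perturbation vectors for KernelSHAP.

    Stage 1: draw k (number of selected features) from the categorical
    distribution p(k) proportional to (M - 1) / (k * (M - k)), the Shapley
    kernel weight aggregated over all (M choose k) vectors with k ones.
    Stage 2: pick one of those (M choose k) vectors uniformly, by ranking
    i.i.d. Gaussian noise and keeping the k largest entries.
    """
    ks = torch.arange(1, num_features)  # k = 1 .. M - 1
    probs = (num_features - 1) / (ks * (num_features - ks))
    probs = probs / probs.sum()
    while True:
        k = int(torch.multinomial(probs, 1).item()) + 1
        rand_vals = torch.randn(num_features)
        threshold = torch.kthvalue(rand_vals, num_features - k).values.item()
        yield (rand_vals > threshold).long()
```

Because only k is drawn from the skewed distribution and each vector is then sampled directly, no per-sample weights with extreme magnitudes are needed, which avoids the underflow issue mentioned above.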
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
First pass - mostly nits
captum/attr/_core/kernel_shap.py
Outdated
```python
    return torch.tensor([similarities])


def kernel_shap_perturb_generator(
    original_inp, **kwargs
```
Type hint? I assume original_inp is Union[Tensor, Tuple[Tensor, ...]].
captum/attr/_core/kernel_shap.py
Outdated
```diff
-    # weight to 100 (all other weights are < 1).
-    similarities = 100.0
+    # weight to 1000000 (all other weights are 1).
+    similarities = 1000000.0
```
I doubt this would be a concern, but just in case, we could add this as a default parameter to this method. With this, users can use functools.partial to change the value in case it is not sufficient.
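The functools.partial suggestion could look like the following sketch (the function name, signature, and weighting logic are hypothetical simplifications, not Captum's actual code):

```python
import torch
from functools import partial

def similarity_kernel(original_inp, perturbed_inp, interpretable_sample,
                      max_weight: float = 1_000_000.0, **kwargs):
    # Hypothetical Shapley-kernel-style weight: the all-ones / all-zeros
    # samples get the large max_weight; everything else gets weight 1.
    if bool((interpretable_sample == 1).all()) or bool((interpretable_sample == 0).all()):
        return torch.tensor([max_weight])
    return torch.tensor([1.0])

# Users can override the default without changing any call sites:
custom_kernel = partial(similarity_kernel, max_weight=1e9)
```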
Yes, it's a good point that this could need to be customized. For now, to avoid having too many parameters, I can make this an instance variable that advanced users can override on the object after creation, but we can make it a parameter later if necessary.
captum/attr/_core/lime.py
Outdated
```diff
@@ -72,7 +73,7 @@ def __init__(
     forward_func: Callable,
     interpretable_model: Model,
     similarity_func: Callable,
-    perturb_func: Callable,
+    perturb_func: Union[Callable],
```
Missing type hint in the Union?
Good catch, thanks! Forgot to revert this change.
Thanks for the review, @miguelmartin75! Addressed comments.
Thank you for working on this PR, @vivekmig!
I left a couple of nits. I think it would be good to describe this trick a bit in the code, since the original approach in the paper is somewhat different in terms of the kernel similarity function.
```python
threshold = torch.kthvalue(
    rand_vals, num_features - num_selected_features
).values.item()
yield (rand_vals > threshold).to(device=device).long()
```
nit: Can we please describe a bit why we follow this logic instead of the default behavior in default_perturb_func:

captum/attr/_core/lime.py, line 612 (commit 6b71d66):

```python
def default_perturb_func(original_inp, **kwargs):
```

I think default_perturb_func is missing type hints too.
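For contrast, a default perturbation function like the one referenced above would typically sample each interpretable feature independently (a hedged sketch under that assumption, not Captum's exact implementation):

```python
import torch

def default_perturb_sketch(original_inp, **kwargs):
    # Assumed behavior: keep or drop each interpretable feature
    # independently with probability 0.5, rather than first drawing
    # the number of kept features as the KernelSHAP generator does.
    num_features = kwargs["num_interp_features"]
    return torch.bernoulli(torch.full((1, num_features), 0.5)).long()
```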
Sure, will add more documentation on this. There are a few other helper methods in Lime without type hints, so will add them together in a separate PR.
Thanks for the review, @NarineK! Addressed comments.
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
captum/attr/_core/kernel_shap.py
Outdated
```
Perturbations are sampled by the following process:
- Choose k (number of selected features), based on the distribution
      p(k) = (M - 1) / (k * (M - k))
  where M is the total number of features
```
nit: total number of features in the interpretable space?
Thank you for the explanation! Looks great! Maybe you could add to the description that each of the (M choose k) samples has equal probability of being chosen, which is why we do:

```python
rand_vals = torch.randn(1, num_features)
threshold = torch.kthvalue(...)
```

If I remember your explanation correctly.
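The equal-probability claim can be checked empirically: since the Gaussian draws are exchangeable, ranking them and keeping the top k selects every size-k subset with the same probability. A small self-contained check (names are illustrative):

```python
import torch
from collections import Counter

def sample_mask(num_features: int, k: int) -> tuple:
    # Keep the k largest of M i.i.d. Gaussian values; exchangeability
    # makes each of the (M choose k) masks equally likely.
    rand_vals = torch.randn(num_features)
    threshold = torch.kthvalue(rand_vals, num_features - k).values.item()
    return tuple((rand_vals > threshold).long().tolist())

torch.manual_seed(0)
counts = Counter(sample_mask(4, 2) for _ in range(12000))
assert len(counts) == 6  # 4 choose 2 = 6 distinct masks
# Each mask should appear roughly 12000 / 6 = 2000 times.
assert all(abs(c - 2000) < 300 for c in counts.values())
```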
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.