Internal Batching Improvements #333
Conversation
@@ -53,6 +53,7 @@ def __init__(self, forward_func: Callable) -> None:
             modification of it
         """
         GradientAttribution.__init__(self, forward_func)
+        self.predefined_step_size_alphas = None  # used for internal batching
I think this can be dangerous if we call `attribute` multiple times for different inputs with different `n_steps`. Does it have to be an instance variable?
This attribute is only modified in `_batched_attribution` for recursive sub-calls and is always reset to `None` after each operation, so it shouldn't affect any separate attribution calls. The other option is to add another argument to `attribute` that is only used for this, but it could be confusing to users looking at the method signature, so I preferred this approach.
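For concreteness, here is a minimal, self-contained sketch of the set-and-reset pattern described above; `ToyAttribution` and its internals are illustrative stand-ins, not Captum's actual implementation:

```python
from typing import List, Optional


class ToyAttribution:
    """Illustrative stand-in for the pattern discussed above."""

    def __init__(self) -> None:
        # Used only for internal batching; reset after every batched call.
        self.predefined_step_size_alphas: Optional[List[float]] = None

    def attribute(self, n_steps: int) -> float:
        # Use pinned alphas if a batched sub-call set them; otherwise build
        # the full set of step points.
        alphas = self.predefined_step_size_alphas or [
            (i + 1) / n_steps for i in range(n_steps)
        ]
        return sum(alphas)  # placeholder for the real attribution math

    def _batched_attribution(self, n_steps: int, steps_per_batch: int) -> float:
        alphas = [(i + 1) / n_steps for i in range(n_steps)]
        total = 0.0
        for start in range(0, n_steps, steps_per_batch):
            chunk = alphas[start : start + steps_per_batch]
            self.predefined_step_size_alphas = chunk  # pin for the sub-call
            total += self.attribute(n_steps=len(chunk))
        self.predefined_step_size_alphas = None  # always reset afterwards
        return total
```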
I think that if this gets executed in a multi-threaded environment there can be some issues... One way to avoid them could be to move the core logic of IG into an auxiliary function such as `_attribute` that takes, in addition, `start` and `end` kwargs. `_batched_attribution` can then call `_attribute` instead of `attribute`?
https://github.com/pytorch/captum/pull/333/files#diff-67ff19e8dcc2510648379db71c2ee9fbR275
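A rough sketch of what that refactor might look like; names and signatures here follow the suggestion and are hypothetical, not Captum's actual code:

```python
from typing import Tuple


class ToyIG:
    """Illustrative sketch of the suggested _attribute refactor."""

    def attribute(self, n_steps: int = 50) -> float:
        # Public entry point: no batching-only kwargs leak into the signature.
        return self._attribute(n_steps, step_range=(0, n_steps))

    def _attribute(self, n_steps: int, step_range: Tuple[int, int]) -> float:
        # Core logic operates only on the requested slice of steps.
        start, end = step_range
        alphas = [(i + 1) / n_steps for i in range(start, end)]
        return sum(alphas)  # placeholder for the real integration

    def _batched_attribution(self, n_steps: int, steps_per_batch: int) -> float:
        # Each chunk goes straight to _attribute, so no shared mutable state
        # is needed and concurrent attribute() calls cannot interfere.
        total = 0.0
        for start in range(0, n_steps, steps_per_batch):
            end = min(start + steps_per_batch, n_steps)
            total += self._attribute(n_steps, step_range=(start, end))
        return total
```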
Sounds good, I can switch to that; it would avoid adding a user-facing parameter, which would be good! Thanks for the suggestion!
I think calling `attribute` from different threads on the same attribution object isn't broadly supported by our methods, since we have other cases where we store temporary info, such as hook lists, as an object attribute, which would have the same issue. But I agree that approach is cleaner; will update to it, thanks!
Thank you :) yes, that's a good point. Ideally we should make those hook holders local variables as well. IG is in a more sophisticated shape and more likely to be used than the others. We can keep this in mind and make the necessary adjustments in the other approaches as well.
@@ -43,7 +43,7 @@ commands:
       steps:
         - run:
             name: "Check import order with isort"
-            command: isort --check-only
+            command: isort --check-only -v
Switching this to verbose to have more details on incorrect orderings in CircleCI
Looks good! Thank you! Left some comments.

Instead of asserts for `internal_batch_size`, it might be good to default it and show warning messages.
captum/attr/_utils/batching.py (Outdated)
step_count = internal_batch_size // num_examples
assert (
    step_count > 0
), "Internal batch size must be at least equal to the number of input examples."
Wouldn't it be better to do `step_count = max(1, internal_batch_size // num_examples)` and show a warning to the user when `internal_batch_size` is too small? This might be a better user experience?
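Something along these lines, as a sketch of the suggested defaulting behavior (the function name here is hypothetical):

```python
import warnings


def _compute_step_count(internal_batch_size: int, num_examples: int) -> int:
    # Clamp to at least one step per internal batch and warn instead of
    # asserting when internal_batch_size is smaller than the batch itself.
    step_count = internal_batch_size // num_examples
    if step_count < 1:
        warnings.warn(
            f"internal_batch_size ({internal_batch_size}) is less than the "
            f"number of examples ({num_examples}); defaulting to one step "
            "per internal batch."
        )
        step_count = 1
    return step_count
```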
captum/attr/_utils/batching.py (Outdated)
if include_endpoint:
    assert (
        step_count > 1
    ), "Internal batch size must be at least twice the number of input examples."
Perhaps explaining why here would be better. Also, perhaps defaulting like the previous suggestion would be better.
Does this mean that if `n_steps=1`, this will always fail, because `internal_batch_size` can only be 1 in that case? Perhaps making a recommendation message would be helpful.
Yes, that will fail (and probably should), since it would've also failed in the implementation (LayerConductance) where this argument is used: we take the difference between consecutive steps to estimate the gradient, and we wouldn't be able to do that with a single step.
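As a toy illustration of that constraint (not Captum's actual code), a consecutive-difference estimate simply has nothing to work with when a batch contains a single step:

```python
from typing import List


def consecutive_differences(values: List[float]) -> List[float]:
    # LayerConductance-style estimates difference consecutive evaluations,
    # so each internal batch must contain at least two steps.
    if len(values) < 2:
        raise ValueError(
            "Need at least two steps per internal batch to take differences; "
            "increase internal_batch_size (or n_steps)."
        )
    return [b - a for a, b in zip(values, values[1:])]


print(consecutive_differences([0.0, 0.4, 0.9]))  # [0.4, 0.5]
```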
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-from typing import List, NamedTuple, Optional, Tuple, Dict
+from typing import Dict, List, NamedTuple, Optional, Tuple
Fixing isort failure on master
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This PR improves internal batching to avoid creating an expanded tensor initially; instead, it recursively computes each subset of steps and adds the resulting attributions. This adds the extra condition that `internal_batch_size` must be at least equal to the number of examples, and the documentation has been updated appropriately.

Pull Request resolved: pytorch#333
Reviewed By: NarineK
Differential Revision: D20796123
Pulled By: vivekmig
fbshipit-source-id: 78931a86b0e092ec0b257793aa6e20aadc081947
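To make the recursion concrete, here is a rough, self-contained sketch of the chunked accumulation the summary describes; `forward_func` is assumed to return per-input "gradients" with the same shape as its input, which is a simplification of the real pipeline:

```python
import torch


def batched_attribution_sketch(forward_func, inputs, n_steps, internal_batch_size):
    # Instead of expanding inputs to n_steps * num_examples rows up front,
    # process one chunk of steps at a time and accumulate partial results.
    num_examples = inputs.shape[0]
    steps_per_batch = max(1, internal_batch_size // num_examples)
    total = torch.zeros_like(inputs)
    for start in range(0, n_steps, steps_per_batch):
        end = min(start + steps_per_batch, n_steps)
        alphas = [(i + 1) / n_steps for i in range(start, end)]
        # Expand only this chunk: (chunk_steps * num_examples, *input_dims).
        scaled = torch.cat([alpha * inputs for alpha in alphas], dim=0)
        grads = forward_func(scaled)  # stand-in for the per-step gradient pass
        total += grads.view(end - start, *inputs.shape).sum(dim=0)
    return total


# Example with a trivial "gradient" function:
attr = batched_attribution_sketch(
    lambda t: t, torch.ones(4, 3), n_steps=10, internal_batch_size=8
)
```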