Fix layer_gradient_x_activation and add logging for metrics #643
Conversation
Thanks for adding this 👍 ! Just one question / comment.
```diff
@@ -170,7 +170,10 @@ def attribute(
         additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        gradient_mask = apply_gradient_requirements(inputs)
+        if inputs[0].is_floating_point() or inputs[0].is_complex():
```
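(For context, not part of the PR: the new guard passes only for floating-point or complex tensors, so integer inputs such as embedding indices skip the gradient-requirement path entirely. A quick check illustrating which dtypes are affected:)

```python
import torch

# Embedding indices are typically int64, so the guard above evaluates to False
# for them and apply_gradient_requirements is not called on those inputs.
idx = torch.tensor([1, 2, 3, 4])
print(idx.is_floating_point(), idx.is_complex())  # False False

x = torch.rand(4, 8)
print(x.is_floating_point(), x.is_complex())      # True False
```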
One question on this: it seems like apply_gradient_requirements already checks input types and avoids enabling grads for any int / long inputs, so is there some case that isn't covered by that check?
Alternatively, it might make sense to remove the gradient requirement here and instead set requires_grad on only the layer inputs / outputs within the forward hook in _forward_layer_distributed_eval, based on a flag, since we only need gradients starting from the target layer. This would avoid requiring gradients before the target layer and should work within torch.no_grad even if the original inputs are integers, as long as the target layer can require gradients. What do you think?
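(For illustration, a minimal sketch of that idea, not the actual captum implementation; ToyModel and save_and_require_grad are hypothetical names. A forward hook on the target layer marks its output as requiring grad, so gradients can be taken starting from that layer even when the model inputs are integer indices and the parameters are frozen.)

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 1000, emb_dim: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)

    def forward(self, idx):
        return torch.sum(self.embed(idx)).reshape(1,)

model = ToyModel()
for param in model.parameters():
    param.requires_grad = False  # even with frozen parameters

saved = {}

def save_and_require_grad(module, inputs, output):
    # Mark the target layer's output as requiring grad and keep a reference,
    # so gradients can later be computed with respect to this tensor.
    output.requires_grad_(True)
    saved["layer_output"] = output

handle = model.embed.register_forward_hook(save_and_require_grad)

idx = torch.tensor([1, 2, 3, 4])          # integer inputs
with torch.enable_grad():                 # works even if the caller used torch.no_grad()
    out = model(idx)
    grads = torch.autograd.grad(torch.unbind(out), saved["layer_output"])

handle.remove()
print(grads[0].shape)                     # same shape as the embedding output: (4, 128)
```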
Ah, thank you. I was thinking we might already have that check. apply_gradient_requirements does perform it, but it occasionally didn't seem to be enough: an external user reported the problem to me, and I could reproduce it at the time, though neither of us can reproduce it now (for him it later went away; it may have been a different PyTorch version). Since apply_gradient_requirements already has the check, there is no need to repeat it here, but we'll keep an eye on this. If it reoccurs, we'll investigate further.
Makes sense, sounds good! One issue I can imagine is a case like the example below, which currently fails since the gradient requirement isn't set on integer inputs. Is this similar to the issue you had in mind?
```python
import torch
import torch.nn as nn

from captum.attr import LayerGradientXActivation


class TestModuleSingleEmbedding(nn.Module):
    def __init__(self, vocab_size: int = 1000, emb_dim: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)

    def forward(self, idx):
        return torch.sum(self.embed(idx)).reshape(1,)


mod = TestModuleSingleEmbedding()
for param in mod.parameters():
    param.requires_grad = False  # frozen parameters

lga = LayerGradientXActivation(mod, mod.embed)
idx = torch.tensor([1, 2, 3, 4])  # integer inputs: no gradient requirement is set on them
attr = lga.attribute(idx)
```
Setting the gradient requirement within the forward hook in _forward_layer_distributed_eval should likely resolve this.
I think the error I saw was related to the input data type, but I've seen this one too. It happens because the tensors do not carry gradients, so the resulting output has no grad_fn. I've tried a couple of things for that as well. Where do you recommend setting the gradient requirement? _forward_layer_distributed_eval is called in an enabled-gradient context. If the forward hook returns an output that requires gradients, the model output will require gradients too, but torch may see the hooked tensor as unused in the computation graph. Alternatively, we would need to iterate over mod.parameters() and set requires_grad to True if the forward function is a model.
I think I saw this issue when using a torch.no_grad() context as well.
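(A minimal sketch of the failure mode being described, a hypothetical repro rather than code from the PR: with integer inputs and frozen parameters nothing in the forward pass requires grad, so the output has no grad_fn and torch.autograd.grad raises a RuntimeError.)

```python
import torch
import torch.nn as nn

embed = nn.Embedding(1000, 128)
embed.weight.requires_grad = False   # frozen parameters

idx = torch.tensor([1, 2, 3, 4])     # integer inputs cannot require grad
emb = embed(idx)
out = torch.sum(emb).reshape(1,)

print(out.grad_fn)                   # None: nothing in the graph required grad

try:
    torch.autograd.grad(torch.unbind(out), emb)
except RuntimeError as err:
    # typically: "element 0 of tensors does not require grad and does not have a grad_fn"
    print(err)
```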
Makes sense! The fix I had in mind is #647. I think this should be sufficient for any layer method that needs grads with respect to a given layer input / output, and it will also avoid enabling grads between the input and the target layer when unnecessary. I don't think it should have any issues with unused tensors in the computation graph; that should only occur if the inputs provided to torch.autograd.grad were not used in the compute graph for the outputs.
Thank you for the PR. I think the problem that I was seeing is that the output has no grad_fn.

captum/captum/_utils/gradient.py, line 644 in 1b19bcb:
saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)

In #647, requires_grad is being set on the inputs w.r.t. which the gradients will be computed, right? But I was thinking that if the output has no grad_fn, it can still cause errors.
Makes sense, yeah I think the error is generally something like this for the output:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I think for any tensor operation, if any input requires gradients, then the output requires gradients and has a corresponding grad_fn. In this case, if the inputs are not floating-point, then the inputs never actually require gradients, so if either the context is no_grad or the parameters don't require gradients, the output has no grad function, causing this issue. By requiring grads on the target layer, gradients should be enabled on the output as well, resolving this issue.
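(As a quick illustration of that rule, not code from the PR: an operation's output only gets a grad_fn if some input to it requires grad, and marking an intermediate tensor such as the target layer's output as requiring grad restores grad_fn for everything downstream.)

```python
import torch

a = torch.tensor([1.0, 2.0])                       # does not require grad
b = torch.tensor([3.0, 4.0], requires_grad=True)

print((a + a).grad_fn)   # None: no input requires grad
print((a + b).grad_fn)   # <AddBackward0 ...>: one input requires grad

# With integer inputs and frozen parameters nothing requires grad, but marking
# an intermediate tensor (standing in for the target layer's output) as
# requiring grad gives downstream results a grad_fn again.
x = torch.zeros(3)
x.requires_grad_(True)
print(x.sum().grad_fn)   # <SumBackward0 ...>
```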
@NarineK has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
We still need apply_gradient_requirements; otherwise, if attribute is called in a torch.no_grad() context, it fails.