non-leaf tensor warning in some attribution algorithms #491

Closed · hal-314 opened this issue Oct 13, 2020 · 1 comment

hal-314 commented Oct 13, 2020
🐛 Bug

I get the following warning when using the Saliency or InputXGradient attribution methods, but not with IntegratedGradients or GradientShap:

UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
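
For reference, PyTorch itself emits this warning whenever .grad is read on a non-leaf tensor, i.e. one produced by an operation rather than created directly. A minimal sketch without Captum (illustrative only, not taken from the Captum code):

import torch

x = torch.randn(3, requires_grad=True)  # leaf tensor
y = x * 2                               # non-leaf tensor (result of an op)
y.sum().backward()

print(x.grad)  # populated: x is a leaf
print(y.grad)  # None; accessing it emits the UserWarning quoted above
# Calling y.retain_grad() before backward() would populate y.grad without a warning.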

To Reproduce

Execute this:

import warnings
import torch
import torch.nn as nn

from captum.attr import Saliency, IntegratedGradients, InputXGradient
from captum.attr import configure_interpretable_embedding_layer

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.embedding = nn.Embedding(10, 3)
        self.fc = nn.Linear(5 + 3, 2)

    def forward(self, x_cat, x):
        x_cat = self.embedding(x_cat)
        x = torch.cat([x, x_cat], dim=1)
        x = self.fc(x)
        return x

def sal(model, X, node_index):
    explainer = Saliency(model)
    #explainer = IntegratedGradients(model)
    #explainer = InputXGradient(model)
    grads = explainer.attribute(X, target=node_index)
    return grads


net = Net()

X_cont = torch.rand(1, 5)
X_cat = torch.randint(0, 9, (1,))

X_cont.requires_grad = True

net(X_cat, X_cont)

#with torch.no_grad(): # <- Uncomment (and indent the next line) to remove the warning.
X_cat_emb = net.embedding(X_cat)

X_cat_emb.requires_grad_()

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    emb_net = configure_interpretable_embedding_layer(net)

# The warning is the same if we compute the embeddings through InterpretableEmbeddingBase, as in the BERT tutorial:
#X_cat_emb2 = emb_net.indices_to_embeddings(X_cat)
#X_cat_emb2.requires_grad_()


attr = sal(net, (X_cat_emb, X_cont), 0)
#attr = sal(net, (X_cat_emb2, X_cont), 0)

Expected behavior

All algorithms should behave consistently. In my opinion, Captum should either not raise the warning, or correct the tutorials that use models with embeddings and add a note in the API docs.
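
A possible user-side workaround (a sketch under the assumptions of the repro above, not an official Captum recommendation) is to detach the precomputed embedding so that the tensor passed to attribute() is a leaf; X_cat_emb_leaf below is an illustrative name:

# Detach the embedding computed in the repro so it becomes a leaf tensor;
# gradients with respect to this input are unchanged, and the warning disappears.
X_cat_emb_leaf = X_cat_emb.detach().requires_grad_()
attr = sal(net, (X_cat_emb_leaf, X_cont), 0)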

Environment

PyTorch 1.6.0 + Captum 0.2 + Ubuntu 20.04

Additional context

I tested with the Saliency, InputXGradient, IntegratedGradients, and GradientShap gradient-based methods.
Finally, I think this bug is similar to #421.

facebook-github-bot pushed a commit that referenced this issue Jan 27, 2021
Summary:
This removes the resetting of grad attribute to zero, which is causing warnings as mentioned in #491 and #421 . Based on torch [documentation](https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad), resetting of grad is only needed when using torch.autograd.backward, which accumulates results into the grad attribute for leaf nodes. Since we only utilize torch.autograd.grad (with only_inputs always set to True), the gradients obtained in Captum are never actually accumulated into grad attributes, so resetting the attribute is not actually necessary.

This also adds a test to confirm that the grad attribute is not altered when gradients are utilized through Saliency.

Pull Request resolved: #597

Reviewed By: bilalsal

Differential Revision: D26079970

Pulled By: vivekmig

fbshipit-source-id: f7ccee02a17f66ee75e2176f1b328672b057dbfa
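
For context, a minimal plain-PyTorch sketch (illustrative only) of the distinction the commit summary relies on: torch.autograd.grad returns gradients directly and does not touch .grad, while torch.autograd.backward accumulates into the .grad attribute of leaf tensors.

import torch

x = torch.randn(4, requires_grad=True)
loss = (x ** 2).sum()

(g,) = torch.autograd.grad(loss, x)  # gradient returned as a value
print(g)                             # tensor of gradients
print(x.grad)                        # still None: nothing was accumulated

loss2 = (x ** 2).sum()               # rebuild the graph
torch.autograd.backward(loss2)       # equivalent to loss2.backward()
print(x.grad)                        # now populated by accumulation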
NarineK (Contributor) commented Jan 28, 2021

Fixed in the PR: #597

NarineK closed this as completed Jan 28, 2021