using captum (integrated gradients) for the CIFAR10 in greyscale as the input for training a resnet18 #378

Closed
mrdupadupa opened this issue May 10, 2020 · 12 comments
@mrdupadupa

Hi all,
I have an issue using Captum on grayscale CIFAR10 with ResNet18.
I used the example from the tutorial "Interpreting vision with CIFAR".
However, I get errors during execution that I could not solve:

  • for DeepLift I have:
    A Module ReLU(inplace=True) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network. Please, ensure that module is being used only once in the network.

Does this mean that I cannot use the DeepLift method with ResNet because ReLU is used more than once?

  • for the visualization part with the Overlayed Gradient Magnitudes, Overlayed Integrated Gradients, etc. methods (exactly as in the tutorial), I get the error:
    Invalid shape (32, 32, 1) for image data

  • for computing attributions using Integrated Gradients and visualizing them on the image:
    attributions_ig = integrated_gradients.attribute(input, target = pred_label_idx, n_steps=100)
    I get the error: tuple index out of range

I am pretty sure the errors are connected to the grayscale format of my CIFAR dataset (for RGB it works fine), but I don't know what to change in the Captum code to adapt it to my data.

Here is the code:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from captum.attr import Saliency, IntegratedGradients, NoiseTunnel, DeepLift
from captum.attr import visualization as viz

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Changing the transform argument for augmentation
transform_trainset = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),    
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])
transform_testset = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='/home/andrei/Study/master_thesis/data', train=True, download=True, transform=transform_trainset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='/home/andrei/Study/master_thesis/data', train=False, download=True, transform=transform_testset)
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=True, num_workers=4)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#Define a Convolutional Neural Network
class MyResNet(nn.Module):
    def __init__(self, in_channels=1):
        super(MyResNet, self).__init__()
        self.model = torchvision.models.resnet18()
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    def forward(self, x):
        return self.model(x)

my_resnet = MyResNet()

input = torch.randn((8,1,32,32))
output = my_resnet(input)
print(output.shape)
net = my_resnet.to(device)
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
dataiter = iter(trainloader)
images, labels = next(dataiter)
#Train the network
for epoch in range(1):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
#load some images from the test dataset and perform predictions
def imshow(img, one_channel=True):
    img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(npimg, cmap="Greys")
dataiter = iter(testloader)
images, labels = next(dataiter)
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(8)))
outputs = net(images.to(device))
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(8)))
#choose a test image at index ind and apply some of our attribution algorithms on it.
ind = 6
input = images[ind].unsqueeze(0).to(device)
input.requires_grad = True
net.eval()
def attribute_image_features(algorithm, input, **kwargs):
    net.zero_grad()
    tensor_attributions = algorithm.attribute(input,
                                              target=labels[ind].item(),
                                              **kwargs
                                             )
    return tensor_attributions

saliency = Saliency(net)
grads = saliency.attribute(input, target=labels[ind].item())
grads = np.transpose(grads.squeeze().cpu().detach().numpy())
ig = IntegratedGradients(net)
attr_ig, delta = attribute_image_features(ig, input, baselines=input * 0, return_convergence_delta=True)
attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy())
print('Approximation delta: ', abs(delta))
#use integrated gradients and noise tunnel with smoothgrad square option on the test image

ig = IntegratedGradients(net)
nt = NoiseTunnel(ig)
attr_ig_nt = attribute_image_features(nt, input, baselines=input * 0, nt_type='smoothgrad_sq',
                                      n_samples=100, stdevs=0.2)
attr_ig_nt = np.transpose(attr_ig_nt.squeeze(0).cpu().detach().numpy())
#Applies DeepLift on test image
dl = DeepLift(net)
attr_dl = attribute_image_features(dl, input, baselines=input * 0)
attr_dl = np.transpose(attr_dl.squeeze(0).cpu().detach().numpy())
#visualize the attributions for Saliency Maps, DeepLift, Integrated Gradients and Integrated Gradients with SmoothGrad
print('Original Image')
print('Predicted:', classes[predicted[ind]], 
      ' Probability:', torch.max(F.softmax(outputs, 1)).item())

original_image = np.transpose((images[ind].cpu().detach().numpy() / 2) + 0.5)

_ = viz.visualize_image_attr(None, original_image, 
                      method="original_image", title="Original Image")

_ = viz.visualize_image_attr(grads, original_image, method="blended_heat_map", sign="absolute_value",
                          show_colorbar=True, title="Overlayed Gradient Magnitudes")

_ = viz.visualize_image_attr(attr_ig, original_image, method="blended_heat_map",sign="all",
                          show_colorbar=True, title="Overlayed Integrated Gradients")

_ = viz.visualize_image_attr(attr_ig_nt, original_image, method="blended_heat_map", sign="absolute_value", 
                             outlier_perc=10, show_colorbar=True, 
                             title="Overlayed Integrated Gradients \n with SmoothGrad Squared")

_ = viz.visualize_image_attr(attr_dl, original_image, method="blended_heat_map",sign="all",show_colorbar=True, 
                          title="Overlayed DeepLift")
#compute attributions using Integrated Gradients and visualize them on the image.
integrated_gradients = ig
pred_label_idx = predicted[ind]
attributions_ig = integrated_gradients.attribute(input, target = pred_label_idx, n_steps=100)
transformed_img = input

bilalsal self-assigned this May 11, 2020
@bilalsal
Contributor

bilalsal commented May 11, 2020

Hi @mrdupadupa, thank you for bringing this up!
We will fix the visualization module to automatically handle gray-scale images.

As a quick fix you can manually reshape the visualization input to fit with what the methods expect, as illustrated in the example below.

  • The "original_image" method expects a 224 x 224 input for the grayscale image.
  • The "blended_heat_map" mode expects a 224 x 224 x 1 input for the grayscale image to match the computed attribution.
original_image = np.zeros([224, 224, 1])
attr = ig.attribute(torch.zeros([1, 1, 224, 224]), target=1) # returns a [1, 1, 224, 224] tensor
attr = attr.squeeze().unsqueeze(2).cpu().detach().numpy() # returns a [224, 224, 1] numpy array

_ = viz.visualize_image_attr(attr, original_image, method="blended_heat_map",sign="all",
                          show_colorbar=True, title="Overlayed Integrated Gradients")

_ = viz.visualize_image_attr_multiple(attr,
                                      np.squeeze(original_image),
                                      ["original_image", "heat_map"],
                                      ["all", "absolute_value"],
                                      show_colorbar=True)
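
For the 32 x 32 grayscale CIFAR case in the original post, a minimal sketch of the same reshaping (the variable names ig, input, images, ind and pred_label_idx are assumed from that post's code, not from the Captum tutorial):

attr = ig.attribute(input, target=pred_label_idx)  # returns a [1, 1, 32, 32] tensor
attr = attr.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()  # (32, 32, 1) numpy array

# unnormalize the grayscale test image and move channels last: (32, 32, 1)
orig_img = images[ind].permute(1, 2, 0).cpu().numpy() / 2 + 0.5

_ = viz.visualize_image_attr(attr, orig_img, method="blended_heat_map", sign="all",
                             show_colorbar=True, title="Overlayed Integrated Gradients")

_ = viz.visualize_image_attr_multiple(attr,
                                      np.squeeze(orig_img),
                                      ["original_image", "heat_map"],
                                      ["all", "absolute_value"],
                                      show_colorbar=True)

The key point is that both the attribution and the grayscale image are passed channels-last, with the singleton channel kept for the blended heat map and squeezed away for the original-image view.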

The DeepLift error seems related to your model and not to the fact that your input is grayscale (I could reproduce it with in_channels = 3). @vivekmig could you have a look?

Hope this helps

@vivekmig
Contributor

Hi @mrdupadupa , the issue with DeepLift seems to be due to repeated use of relu in the torchvision ResNet basic block. You can resolve this by copying the ResNet source from here and just replacing the existing BasicBlock with the modifications shown below:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Added another relu here
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        # Modified to use relu2
        out = self.relu2(out)

        return out

@mrdupadupa
Author

Hi @bilalsal, @vivekmig, thank you for your answers. They help a lot.

@psteinb

psteinb commented May 15, 2020

@vivekmig can you shed some light on why DeepLift has issues with a reused layer in the network?

@vivekmig
Contributor

Hi @psteinb, sure. This limitation is particular to the current implementation: intermediate activations for both the baselines and the inputs need to be stored in the forward pass and used to override the gradient in the backward pass. Currently, those activations are stored as (temporary) attributes on the corresponding modules themselves, so a reused activation module overwrites the stored attributes. (A similar issue also indirectly affects layer and neuron attribution methods in Captum: they always attribute with respect to the last execution of a reused module.)

In most cases, we may be able to get around this issue by refactoring the implementation to allow storing multiple activations for a single module, keyed on the execution count, essentially storing the activations separately for the 1st, 2nd, 3rd, etc. time the module is executed. We haven't yet worked on this refactor, but we will consider prioritizing it for future releases if this is a common issue.
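
To make the failure mode concrete, here is a minimal sketch (not from this thread) of the reuse pattern that triggers the error, next to a rewrite that avoids it by giving each application of ReLU its own module instance:

import torch
import torch.nn as nn
from captum.attr import DeepLift

class Reused(nn.Module):
    # The same nn.ReLU instance is applied twice, so DeepLift's per-module
    # stored activations get overwritten and attribution fails.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(self.relu(x))))

class Separate(nn.Module):
    # One nn.ReLU instance per use; DeepLift can handle this.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu2(self.fc1(self.relu1(x))))

x = torch.randn(1, 4)
attr = DeepLift(Separate()).attribute(x, baselines=x * 0, target=0)  # works
# DeepLift(Reused()).attribute(...) would raise the "module used more than once" error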

@NarineK
Contributor

NarineK commented May 24, 2020

@psteinb, in a broader context, PyTorch currently does not tell us where exactly in the computation graph an operator is executed. Forward and backward hooks do not provide that information. This affects all hooks, including all layer and neuron attributions; in those cases you'll receive the attribution with respect to the last execution of that hook only.

There are plans to expand PyTorch to give access to that information. JIT provides information about the graph structure but, unfortunately, hooks aren't currently supported there; some folks are working on it.

One way of trying to solve this, as Vivek mentioned, is to count how many times a hook gets hit, but that can be hacky (not elegant) to implement, and it might be easier to redefine an activation instead of reusing it, which is pretty straightforward to do in PyTorch.
I don't think a refactoring is needed at this point; a more elegant way to solve the problem is to align it with the extended / improved functionality of PyTorch.

@unnir

unnir commented Feb 22, 2021

to have it in one place:

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)
    
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Added another relu here
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        # Modified to use relu2
        out = self.relu2(out)

        return out

and then you can:

test_model = resnet18(pretrained=False, progress=True)
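
Tying this back to the grayscale CIFAR setup in the original post, a rough sketch of using the modified model with DeepLift (the conv1/fc replacements, input size, and target index below are illustrative assumptions, not part of this thread):

model = resnet18(pretrained=False, progress=True)
# swap the stem to accept 1-channel input and the head to 10 CIFAR classes,
# mirroring the MyResNet wrapper from the original post
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model.eval()

x = torch.randn(1, 1, 32, 32, requires_grad=True)  # one grayscale CIFAR-sized image
dl = DeepLift(model)
attr_dl = dl.attribute(x, baselines=x * 0, target=3)
print(attr_dl.shape)  # torch.Size([1, 1, 32, 32])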

@cdsarto

cdsarto commented May 19, 2021

Hey everybody,

I have a similar issue to the one reported by @mrdupadupa. I tried to use DeepLift with ResNet-50. As recommended, I got the ResNet source code from [https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py] and adjusted the BasicBlock, but I still get the same error:

A Module ReLU(inplace=True) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network. Please, ensure that module is being used only once in the network.

I also found the comment from @giangnguyen2412 in #480 where he wrote:

Hi @vivekmig , thanks for your help. I want to add some details to your advice (for those who are new). First, copy the definitions from here and execute them as your source. You should import the two functions load_state_dict_from_url and _get_torch_home from torch.hub. Then define your model as model=resnet50(pretrained=True).eval(). One more thing is that you are not only required to modify the BasicBlock class but also the Bottleneck class to make it work!

But I'm not sure how to modify the Bottleneck class and/or why to import the two mentioned functions.

Thank you in advance :)

@shikhar-srivastava

shikhar-srivastava commented Jul 17, 2021

@vivekmig's comment solves this.
@cdsarto: The comment by @unnir misses the Bottleneck reference in his code. Have a look at https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for the complete source and replace the BasicBlock class with that from @vivekmig's comment here. It's a straightforward change.

Here is the full ResNet src modified to have no reused ReLU blocks.
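
For completeness, the analogous change for the Bottleneck block (used by ResNet-50) gives each of its three ReLU applications its own instance. This is a sketch based on the torchvision source, not an official Captum fix:

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

        # One ReLU instance per use instead of a single reused self.relu
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)

        return out

With this in place, ResNet-50 can be built the same way as resnet18 above, e.g. _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs).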

@ericotjo001

Hi, although there are already good answers, I have added a repo that implements the change. Feel free to take a look:
https://github.com/ericotjo001/pytorch_captum_fix

Basically, each Bottleneck or BasicBlock is replaced with AdjustedBottleneck and AdjustedBasicBlock, and then DeepLIFT will work.

@youyinnn

@NarineK @vivekmig The solution no longer seems to work with torch 2.1. In particular, I am using a Mac with the MPS device.

@skengman1312

I'm still facing the same issue with residual connections.
