Skip to content

Adding a new Module for unsupported layer. Adding test for unsupported layers. Simple logging for unsupported layers #1505

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TupleOrTensorOrBoolGeneric = TypeVar(
"TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
)
PassThroughOutputType = TypeVar("PassThroughOutputType")
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
Expand Down
78 changes: 77 additions & 1 deletion captum/testing/helpers/basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum._utils.typing import PassThroughOutputType
from torch import Tensor
from torch.futures import Future

Expand Down Expand Up @@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
return self.linear2(self.relu(self.linear1(embeddings))).sum(1)


class GradientUnsupportedLayerOutput(nn.Module):
"""
This layer is used to test the case where the model returns a layer that
is not supported by the gradient computation.
"""

def __init__(self) -> None:
super().__init__()

@no_type_check
def forward(
self, unsupported_layer_output: PassThroughOutputType
) -> PassThroughOutputType:
return unsupported_layer_output


class BasicModel_GradientLayerAttribution(nn.Module):
def __init__(
self,
inplace: bool = False,
unsupported_layer_output: PassThroughOutputType = None,
) -> None:
super().__init__()
# Linear 0 is simply identity transform
self.unsupported_layer_output = unsupported_layer_output
self.linear0 = nn.Linear(3, 3)
self.linear0.weight = nn.Parameter(torch.eye(3))
self.linear0.bias = nn.Parameter(torch.zeros(3))
self.linear1 = nn.Linear(3, 4)
self.linear1.weight = nn.Parameter(torch.ones(4, 3))
self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))

self.linear1_alt = nn.Linear(3, 4)
self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))

self.relu = nn.ReLU(inplace=inplace)
self.relu_alt = nn.ReLU(inplace=False)
self.unsupportedLayer = GradientUnsupportedLayerOutput()

self.linear2 = nn.Linear(4, 2)
self.linear2.weight = nn.Parameter(torch.ones(2, 4))
self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))

self.linear3 = nn.Linear(4, 2)
self.linear3.weight = nn.Parameter(torch.ones(2, 4))
self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))

@no_type_check
def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
input = x if add_input is None else x + add_input
lin0_out = self.linear0(input)
lin1_out = self.linear1(lin0_out)
lin1_out_alt = self.linear1_alt(lin0_out)

if self.unsupported_layer_output is not None:
self.unsupportedLayer(self.unsupported_layer_output)
# unsupportedLayer is unused in the forward func.
self.relu_alt(
lin1_out_alt
) # relu_alt's output is supported but it's unused in the forward func.

relu_out = self.relu(lin1_out)
lin2_out = self.linear2(relu_out)

lin3_out = self.linear3(lin1_out_alt).to(torch.int64)

return torch.cat((lin2_out, lin3_out), dim=1)


class MultiRelu(nn.Module):
def __init__(self, inplace: bool = False) -> None:
super().__init__()
Expand All @@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:


class BasicModel_MultiLayer(nn.Module):
def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
def __init__(
self,
inplace: bool = False,
multi_input_module: bool = False,
) -> None:
super().__init__()
# Linear 0 is simply identity transform
self.multi_input_module = multi_input_module
Expand Down Expand Up @@ -461,6 +536,7 @@ def forward(
input = x if add_input is None else x + add_input
lin0_out = self.linear0(input)
lin1_out = self.linear1(lin0_out)

if self.multi_input_module:
relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
relu_out = relu_out1 + relu_out2
Expand Down