Skip to content

Commit d2b3684

Browse files
styusuf
authored and facebook-github-bot committed
Adding a new Module for unsupported layer. Adding test for unsupported layers. Simple logging for unsupported layers
Summary: We are adding test for unsupported gradient layers. Open to ideas if there is a better way to structure the test. Differential Revision: D69792994
1 parent 087325a commit d2b3684

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

captum/testing/helpers/basic_models.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44

5-
from typing import no_type_check, Optional, Tuple, Union
5+
from typing import List, no_type_check, Optional, Tuple, Union
66

77
import torch
88
import torch.nn as nn
@@ -428,8 +428,24 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:
428428
return (self.relu1(arg1), self.relu2(arg2))
429429

430430

431+
class GradientUnsupportedLayerOutput(nn.Module):
    """Test helper whose ``forward`` returns a plain Python list of tensors.

    A list output (rather than a tuple/tensor) is a layer-output container
    that some gradient-based attribution paths do not support, which makes
    this module useful for exercising the unsupported-layer code paths.
    """

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        # Two independent ReLU activations, one per forward argument.
        self.relu1 = nn.ReLU(inplace=inplace)
        self.relu2 = nn.ReLU(inplace=inplace)

    @no_type_check
    def forward(self, arg1: Tensor, arg2: Tensor) -> List[Tensor]:
        # Apply each activation to its paired input and collect the
        # results into a list (the deliberately unsupported container).
        pairs = ((self.relu1, arg1), (self.relu2, arg2))
        return [activation(inp) for activation, inp in pairs]
440+
441+
431442
class BasicModel_MultiLayer(nn.Module):
432-
def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
443+
def __init__(
444+
self,
445+
inplace: bool = False,
446+
multi_input_module: bool = False,
447+
unsupported_layer_output: bool = False,
448+
) -> None:
433449
super().__init__()
434450
# Linear 0 is simply identity transform
435451
self.multi_input_module = multi_input_module
@@ -443,7 +459,11 @@ def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> N
443459
self.linear1_alt = nn.Linear(3, 4)
444460
self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
445461
self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
446-
self.multi_relu = MultiRelu(inplace=inplace)
462+
self.multi_relu = (
463+
GradientUnsupportedLayerOutput(inplace=inplace)
464+
if unsupported_layer_output
465+
else MultiRelu(inplace=inplace)
466+
)
447467
self.relu = nn.ReLU(inplace=inplace)
448468

449469
self.linear2 = nn.Linear(4, 2)

0 commit comments

Comments
 (0)