diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py
index 77584594a9..72a26607ab 100644
--- a/captum/testing/helpers/basic_models.py
+++ b/captum/testing/helpers/basic_models.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 
-from typing import no_type_check, Optional, Tuple, Union
+from typing import Dict, no_type_check, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -467,7 +467,9 @@ def __init__(
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
 
     @no_type_check
-    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self, x: Tensor, add_input: Optional[Tensor] = None
+    ) -> Dict[str, Tensor]:
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
@@ -485,7 +487,14 @@ def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
 
         lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
 
-        return torch.cat((lin2_out, lin3_out), dim=1)
+        output_tensors = torch.cat((lin2_out, lin3_out), dim=1)
+
+        # we return a dictionary of tensors as an output to test the case
+        # where an output accessor is required
+        return {
+            "task {}".format(i + 1): output_tensors[:, i]
+            for i in range(output_tensors.shape[1])
+        }
 
 
 class MultiRelu(nn.Module):