Commit 172f835

Vivek Miglani authored and facebook-github-bot committed
Fix Lime output dimension in batch forward (#1513)
Summary: Currently, when a batch of inputs is provided with a forward function that returns a single scalar per batch, Lime and KernelShap still return attributions matching the full input shape. This is inconsistent with other perturbation-based methods, particularly Feature Ablation and Shapley Value Sampling, which return attributions with a first dimension of 1 in this case. The change breaks backward compatibility for OSS users, but since it affects only this specific case (a single scalar per batch), it should be fine to ship with only a documentation update.

Reviewed By: craymichael

Differential Revision: D70096644
1 parent 3188421 commit 172f835
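To illustrate the affected call pattern, here is a minimal sketch; the linear model, input sizes, feature mask, and sample count are hypothetical and not part of this commit. With a forward function that collapses the whole batch to one scalar, Lime now returns attributions with a leading dimension of 1 instead of repeating the same row for every input.

import torch
from captum.attr import Lime

# Hypothetical setup: a linear model whose outputs are summed over the batch,
# so forward_func returns a single scalar for the entire batch of inputs.
linear = torch.nn.Linear(3, 1)

def forward_func(inp: torch.Tensor) -> torch.Tensor:
    return linear(inp).sum()

lime = Lime(forward_func)
inputs = torch.randn(4, 3)                # batch of 4 examples, 3 features each
feature_mask = torch.tensor([[0, 1, 2]])  # one interpretable feature per column

attr = lime.attribute(inputs, feature_mask=feature_mask, n_samples=50)

# Before this change: attr.shape == torch.Size([4, 3]), the same row repeated
# for every example. After this change: attr.shape == torch.Size([1, 3]),
# consistent with FeatureAblation and ShapleyValueSampling.
print(attr.shape)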

3 files changed (+20, -7 lines)

Diff for: captum/attr/_core/lime.py

+14, -1

@@ -1038,7 +1038,12 @@ def attribute( # type: ignore
                         coefficient of the corresponding interpretale feature.
                         All elements with the same value in the feature mask
                         will contain the same coefficient in the returned
-                        attributions. If return_input_shape is False, a 1D
+                        attributions.
+                        If forward_func returns a single element per batch, then the
+                        first dimension of each tensor will be 1, and the remaining
+                        dimensions will have the same shape as the original input
+                        tensor.
+                        If return_input_shape is False, a 1D
                         tensor is returned, containing only the coefficients
                         of the trained interpreatable models, with length
                         num_interp_features.
@@ -1242,6 +1247,7 @@ def _attribute_kwargs( # type: ignore
                 coefs,
                 num_interp_features,
                 is_inputs_tuple,
+                leading_dim_one=(bsz > 1),
             )
         else:
             return coefs
@@ -1254,6 +1260,7 @@ def _convert_output_shape(
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: Literal[True],
+        leading_dim_one: bool = False,
     ) -> Tuple[Tensor, ...]: ...

     @typing.overload
@@ -1264,6 +1271,7 @@ def _convert_output_shape( # type: ignore
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: Literal[False],
+        leading_dim_one: bool = False,
     ) -> Tensor: ...

     @typing.overload
@@ -1274,6 +1282,7 @@ def _convert_output_shape(
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: bool,
+        leading_dim_one: bool = False,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

     def _convert_output_shape(
@@ -1283,6 +1292,7 @@ def _convert_output_shape(
         coefs: Tensor,
         num_interp_features: int,
         is_inputs_tuple: bool,
+        leading_dim_one: bool = False,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         coefs = coefs.flatten()
         attr = [
@@ -1295,4 +1305,7 @@ def _convert_output_shape(
                     coefs[single_feature].item()
                     * (feature_mask[tensor_ind] == single_feature).float()
                 )
+        if leading_dim_one:
+            for i in range(len(attr)):
+                attr[i] = attr[i][0:1]
         return _format_output(is_inputs_tuple, tuple(attr))
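As a standalone illustration (not code from the commit itself) of what the new leading_dim_one branch does: in the scalar-per-batch case every row of each attribution tensor is identical, so slicing with [0:1] keeps a single representative row while preserving the number of dimensions.

import torch

# _convert_output_shape builds attributions that repeat the same coefficients
# for every example in the batch when the forward output is one scalar, e.g.:
attr = [
    torch.tensor([[3850.6666, 3850.6666, 3850.6666]] * 2),  # shape (2, 3)
    torch.tensor([[306.6666, 3850.6666, 410.6666]] * 2),    # shape (2, 3)
]

# leading_dim_one=True trims each tensor to its first row, so the result has a
# leading dimension of 1 rather than matching the batch size.
leading_dim_one = True
if leading_dim_one:
    for i in range(len(attr)):
        attr[i] = attr[i][0:1]

print([tuple(t.shape) for t in attr])  # [(1, 3), (1, 3)]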

Diff for: tests/attr/test_kernel_shap.py

+3, -3

@@ -348,9 +348,9 @@ def _multi_input_scalar_kernel_shap_assert(self, func: Callable) -> None:
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]] * 2,
-            [[306.6666, 3850.6666, 410.6666]] * 2,
-            [[306.6666, 3850.6666, 410.6666]] * 2,
+            [[3850.6666, 3850.6666, 3850.6666]],
+            [[306.6666, 3850.6666, 410.6666]],
+            [[306.6666, 3850.6666, 410.6666]],
         )

         self._kernel_shap_test_assert(

Diff for: tests/attr/test_lime.py

+3, -3

@@ -494,9 +494,9 @@ def _multi_input_scalar_lime_assert(self, func: Callable) -> None:
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]] * 2,
-            [[305.5, 3850.6666, 410.1]] * 2,
-            [[305.5, 3850.6666, 410.1]] * 2,
+            [[3850.6666, 3850.6666, 3850.6666]],
+            [[305.5, 3850.6666, 410.1]],
+            [[305.5, 3850.6666, 410.1]],
         )

         self._lime_test_assert(
