
Commit b8e4186

vivekmig authored and facebook-github-bot committed
DataParallel DeepLift Fixes (#335)
Summary: This PR adds only the bug fixes identified from refactoring and adding dynamic tests.

1. Additional forward arguments were not working appropriately with DeepLift on a DataParallel model, because the device split of the expanded additional forward args didn't necessarily match that of the inputs. The behavior has been changed to expand the additional args in the hook function (after the device split), which ensures the additional args and inputs remain matched.
2. Different targets per example were not working appropriately with DeepLift on a DataParallel model, because the model output concatenated the outputs of the devices in DataParallel, which mixed input / baseline outputs and inhibited appropriate matching between input example and target. Additional forward hooks have been added to return the output with all inputs followed by all baselines.
3. GradCAM is primarily intended for layers with >= 3 dimensions, since it computes the average gradient for each example / channel. For layers with 2 dimensions, the mean gradient over all dimensions was being taken. This has been updated to use the layer gradients directly in this case, which better aligns with the behavior for >= 3 dimensions.
4. DeepLiftShap (and its Neuron / Layer variants) was incorrectly repeating additional forward args; this has been fixed to use repeat_interleave instead.

Pull Request resolved: #335
Reviewed By: edward-io
Differential Revision: D20844511
Pulled By: vivekmig
fbshipit-source-id: c895b348c3d5c56355c39d429947f2f36dda37a7
1 parent 72294ff commit b8e4186
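For context on fix 4 above: DeepLiftShap tiles each input example across every baseline, so additional forward arguments have to be repeated per example rather than tiled as a whole batch. A minimal sketch of the difference between the two expansion modes using plain PyTorch ops (the tensor `extra_arg` and the factor of 2 are made up for illustration; Captum's `_expand_additional_forward_args` helper operates on tuples of arguments rather than a single tensor):

import torch

# Hypothetical additional forward arg for a batch of 2 examples.
extra_arg = torch.tensor([10, 20])

# ExpansionTypes.repeat tiles the whole batch, like Tensor.repeat:
repeated = extra_arg.repeat(2)
print(repeated.tolist())     # [10, 20, 10, 20]

# ExpansionTypes.repeat_interleave repeats each example in place, which is
# what matches inputs that DeepLiftShap expands per baseline:
interleaved = torch.repeat_interleave(extra_arg, 2, dim=0)
print(interleaved.tolist())  # [10, 10, 20, 20]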

File tree

4 files changed: +44 −20 lines changed


captum/attr/_core/deep_lift.py

+37 −12

@@ -20,6 +20,7 @@
     _format_tensor_into_tuples,
     _is_tuple,
     _run_forward,
+    _select_targets,
 )
 from ..._utils.typing import (
     BaselineType,
@@ -304,22 +305,20 @@ def attribute(  # type: ignore
         )

         baselines = _tensorize_baseline(inputs, baselines)
-        main_model_pre_hook = self._pre_hook_main_model()
+        main_model_pre_hooks = self._hook_main_model()

         self.model.apply(self._register_hooks)

         additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        input_base_additional_args = _expand_additional_forward_args(
-            additional_forward_args, 2, ExpansionTypes.repeat
-        )
+
         expanded_target = _expand_target(
             target, 2, expansion_type=ExpansionTypes.repeat
         )

         wrapped_forward_func = self._construct_forward_func(
-            self.model, (inputs, baselines), expanded_target, input_base_additional_args
+            self.model, (inputs, baselines), expanded_target, additional_forward_args
         )
         gradients = self.gradient_func(wrapped_forward_func, inputs)
         if custom_attribution_func is None:
@@ -332,7 +331,9 @@ def attribute(  # type: ignore
                 custom_attribution_func, gradients, inputs, baselines
             )
         # remove hooks from all activations
-        main_model_pre_hook.remove()
+        for hook in main_model_pre_hooks:
+            hook.remove()
+
         self._remove_hooks()

         undo_gradient_requirements(inputs, gradient_mask)
@@ -355,7 +356,12 @@ def _construct_forward_func(
         additional_forward_args: Any = None,
     ) -> Callable:
         def forward_fn():
-            return _run_forward(forward_func, inputs, target, additional_forward_args)
+            model_out = _run_forward(
+                forward_func, inputs, None, additional_forward_args
+            )
+            return _select_targets(
+                torch.cat((model_out[:, 0], model_out[:, 1])), target
+            )

         if hasattr(forward_func, "device_ids"):
             forward_fn.device_ids = forward_func.device_ids  # type: ignore
@@ -501,7 +507,7 @@ def _remove_hooks(self) -> None:
         for backward_handle in self.backward_handles:
             backward_handle.remove()

-    def _pre_hook_main_model(self) -> RemovableHandle:
+    def _hook_main_model(self) -> List[RemovableHandle]:
         def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
             inputs = baseline_inputs_add_args[0]
             baselines = baseline_inputs_add_args[1]
@@ -514,13 +520,28 @@ def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
                 for input, baseline in zip(inputs, baselines)
             )
             if additional_args is not None:
-                return (*baseline_input_tsr, *additional_args)
+                expanded_additional_args = cast(
+                    Tuple,
+                    _expand_additional_forward_args(
+                        additional_args, 2, ExpansionTypes.repeat
+                    ),
+                )
+                return (*baseline_input_tsr, *expanded_additional_args)
             return baseline_input_tsr

+        def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
+            return torch.stack(torch.chunk(outputs, 2), dim=1)
+
         if isinstance(self.model, nn.DataParallel):
-            return self.model.module.register_forward_pre_hook(pre_hook)  # type: ignore
+            return [
+                self.model.module.register_forward_pre_hook(pre_hook),  # type: ignore
+                self.model.module.register_forward_hook(forward_hook),
+            ]  # type: ignore
         else:
-            return self.model.register_forward_pre_hook(pre_hook)  # type: ignore
+            return [
+                self.model.register_forward_pre_hook(pre_hook),  # type: ignore
+                self.model.register_forward_hook(forward_hook),
+            ]  # type: ignore

     def has_convergence_delta(self) -> bool:
         return True
@@ -810,7 +831,11 @@ def _expand_inputs_baselines_targets(
             target, base_bsz, expansion_type=ExpansionTypes.repeat_interleave
         )
         input_additional_args = (
-            _expand_additional_forward_args(additional_forward_args, base_bsz)
+            _expand_additional_forward_args(
+                additional_forward_args,
+                base_bsz,
+                expansion_type=ExpansionTypes.repeat_interleave,
+            )
             if additional_forward_args is not None
             else None
         )
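To see why the new forward_hook and the torch.cat((model_out[:, 0], model_out[:, 1])) in forward_fn are needed: the pre-hook on each DataParallel replica concatenates that device's input chunk with its baseline chunk, so each replica's raw output is "its inputs then its baselines", and a plain gather across devices mixes the groups. A self-contained sketch with dummy scalar outputs and two simulated device chunks (values and variable names are illustrative only, not Captum internals):

import torch

# A batch of 4 examples split across 2 devices (2 per device). Each replica
# outputs its 2 input results followed by its 2 baseline results.
dev0_out = torch.tensor([1.0, 2.0, -1.0, -2.0])
dev1_out = torch.tensor([3.0, 4.0, -3.0, -4.0])

# Without the forward hook, the gather concatenates per-device outputs and
# mixes input / baseline groups:
naive = torch.cat([dev0_out, dev1_out])
# tensor([ 1.,  2., -1., -2.,  3.,  4., -3., -4.])

# The new forward hook regroups each device's output into (input, baseline)
# pairs before the gather:
dev0_paired = torch.stack(torch.chunk(dev0_out, 2), dim=1)  # shape (2, 2)
dev1_paired = torch.stack(torch.chunk(dev1_out, 2), dim=1)
gathered = torch.cat([dev0_paired, dev1_paired])            # what forward_fn sees

# forward_fn then rebuilds "all inputs followed by all baselines", so targets
# can be selected per example:
ordered = torch.cat((gathered[:, 0], gathered[:, 1]))
# tensor([ 1.,  2.,  3.,  4., -1., -2., -3., -4.])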

captum/attr/_core/feature_ablation.py

+1 −1

@@ -243,7 +243,7 @@ def attribute(
         feature_mask = _format_input(feature_mask) if feature_mask is not None else None
         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
-        ), "Ablations per evaluation must be at least 1."
+        ), "Perturbations per evaluation must be an integer and at least 1."
         with torch.no_grad():
             # Computes initial evaluation with all features, which is compared
             # to each ablated result.

captum/attr/_core/layer/grad_cam.py

+2 −0

@@ -207,6 +207,8 @@ def attribute(
                 dim=tuple(x for x in range(2, len(layer_grad.shape))),
                 keepdim=True,
             )
+            if len(layer_grad.shape) > 2
+            else layer_grad
             for layer_grad in layer_gradients
         )
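A small illustration of this GradCAM change (plain PyTorch; the shapes and names are chosen for illustration, not taken from Captum): layer gradients with 3 or more dimensions are still averaged over the trailing spatial dimensions per example and channel, while 2-dimensional layer gradients are now used directly instead of being collapsed to a single mean.

import torch

# Gradients w.r.t. a conv layer output: (batch, channels, H, W).
conv_grad = torch.randn(8, 16, 7, 7)
conv_weights = (
    # Average over the spatial dims per example / channel, keepdim for broadcasting.
    torch.mean(conv_grad, dim=tuple(range(2, conv_grad.dim())), keepdim=True)
    if conv_grad.dim() > 2
    else conv_grad
)  # shape (8, 16, 1, 1)

# Gradients w.r.t. a 2-D (e.g. fully connected) layer output: (batch, features).
fc_grad = torch.randn(8, 16)
fc_weights = (
    torch.mean(fc_grad, dim=tuple(range(2, fc_grad.dim())), keepdim=True)
    if fc_grad.dim() > 2
    else fc_grad  # no spatial dims to average over, use the gradients as-is
)  # shape (8, 16)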

captum/attr/_core/layer/layer_deep_lift.py

+4 −7

@@ -8,7 +8,6 @@

 from ...._utils.common import (
     ExpansionTypes,
-    _expand_additional_forward_args,
     _expand_target,
     _format_additional_forward_args,
     _format_input,
@@ -277,21 +276,18 @@ def attribute(

         baselines = _tensorize_baseline(inputs, baselines)

-        main_model_pre_hook = self._pre_hook_main_model()
+        main_model_hooks = self._hook_main_model()

         self.model.apply(self._register_hooks)

         additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        input_base_additional_args = _expand_additional_forward_args(
-            additional_forward_args, 2, ExpansionTypes.repeat
-        )
         expanded_target = _expand_target(
             target, 2, expansion_type=ExpansionTypes.repeat
         )
         wrapped_forward_func = self._construct_forward_func(
-            self.model, (inputs, baselines), expanded_target, input_base_additional_args
+            self.model, (inputs, baselines), expanded_target, additional_forward_args,
         )

         def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
@@ -323,8 +319,9 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric,) -> Sequence:
                 custom_attribution_func, gradients, attr_inputs, attr_baselines
             )
         # remove hooks from all activations
-        main_model_pre_hook.remove()
         self._remove_hooks()
+        for hook in main_model_hooks:
+            hook.remove()

         undo_gradient_requirements(inputs, gradient_mask)
         return _compute_conv_delta_and_format_attrs(
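Both attribute() methods above now track a list of hook handles instead of a single pre-hook handle. A toy sketch of that registration-and-cleanup pattern on an arbitrary module (the module and hook bodies below are placeholders, not Captum code):

import torch
import torch.nn as nn

# Stand-in for self.model; purely for illustration.
model = nn.Linear(4, 2)

def pre_hook(module, args):
    # Runs before forward; a real hook could repack the positional args here.
    return args

def forward_hook(module, inputs, output):
    # Runs after forward; a real hook could reshape the output here.
    return output

# The hooking helper now returns both handles so the caller can clean up both.
handles = [
    model.register_forward_pre_hook(pre_hook),
    model.register_forward_hook(forward_hook),
]

_ = model(torch.randn(3, 4))

# Cleanup mirrors the loop added in both attribute() methods.
for handle in handles:
    handle.remove()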
