Update version to 0.8.0 #1504

Closed · wants to merge 7 commits
7 changes: 5 additions & 2 deletions README.md
@@ -25,8 +25,6 @@ of integrated gradients, saliency maps, smoothgrad, vargrad and others for
PyTorch models. It has quick integration for models built with domain-specific
libraries such as torchvision, torchtext, and others.

- *Captum is currently in beta and under active development!*
-

#### About Captum

@@ -92,6 +90,7 @@ pip install -e .
To customize the installation, you can also run the following variants of the
above:
* `pip install -e .[insights]`: Also installs all packages necessary for running Captum Insights.
+ **NOTE**: Captum Insights is being deprecated. See further details [below](#captum-insights).
* `pip install -e .[dev]`: Also installs all tools necessary for development
(testing, linting, docs building; see [Contributing](#contributing) below).
* `pip install -e .[tutorials]`: Also installs all packages necessary for running the tutorial notebooks.
@@ -388,6 +387,10 @@ Captum on different types of models can be found in our tutorials.

## Captum Insights

+ **NOTE**: *Support for Captum Insights is being deprecated in an upcoming release.
+ While the code will still be available, there will no longer be active
+ development or support for it.*
+
Captum provides a web interface called Insights for easy visualization and
access to a number of our interpretability algorithms.

2 changes: 1 addition & 1 deletion captum/__init__.py
@@ -9,6 +9,6 @@
import captum.robust as robust


__version__ = "0.7.0"
__version__ = "0.8.0"

__all__ = ["attr", "concept", "influence", "log", "metrics", "robust"]
8 changes: 4 additions & 4 deletions captum/_utils/common.py
@@ -232,12 +232,12 @@ def _format_tensor_into_tuples(inputs: None) -> None: ...

@overload
def _format_tensor_into_tuples(
- inputs: Union[Tensor, Tuple[Tensor, ...]]
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tensor, ...]: ...


def _format_tensor_into_tuples(
- inputs: Union[None, Tensor, Tuple[Tensor, ...]]
+ inputs: Union[None, Tensor, Tuple[Tensor, ...]],
) -> Union[None, Tuple[Tensor, ...]]:
if inputs is None:
return None
@@ -261,7 +261,7 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:


def _format_float_or_tensor_into_tuples(
- inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
+ inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]],
) -> Tuple[Union[float, Tensor], ...]:
if not isinstance(inputs, tuple):
assert isinstance(
@@ -276,7 +276,7 @@ def _format_float_or_tensor_into_tuples(
@overload
def _format_additional_forward_args(
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
- additional_forward_args: Union[Tensor, Tuple]
+ additional_forward_args: Union[Tensor, Tuple],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Tuple: ...

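
A pattern worth noting before the remaining files: most signature edits in this PR just add a trailing comma after a sole multi-line parameter. This appears consistent with rerunning a newer black over the codebase (the PR also bumps black's `target-version` in `pyproject.toml` further down); the trailing "magic" comma tells black to keep the parameter list exploded. A minimal sketch of the before/after style (the function here is illustrative, not Captum's):

```python
from typing import Tuple, Union

from torch import Tensor


# Before: no comma after the sole parameter.
# def _to_tuple(
#     inputs: Union[Tensor, Tuple[Tensor, ...]]
# ) -> Tuple[Tensor, ...]: ...

# After: the trailing ("magic") comma, which black preserves on later runs.
def _to_tuple(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tensor, ...]:
    return inputs if isinstance(inputs, tuple) else (inputs,)
```
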
6 changes: 2 additions & 4 deletions captum/attr/_core/dataloader_attr.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
+
from collections import defaultdict
from copy import copy
from typing import Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
@@ -30,7 +31,6 @@ class InputRole:


# default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
- # pyre-fixme[2]: Parameter must be annotated.
def _concat_tensors(accum: Optional[Tensor], cur_output: Tensor, _) -> Tensor:
return cur_output if accum is None else torch.cat([accum, cur_output])

@@ -87,9 +87,7 @@ def _perturb_inputs(
else:
baseline = baselines[attr_inp_count]

- # pyre-fixme[58]: `*` is not supported for operand types `object` and
- # `Tensor`.
- perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
+ perturbed_inp = cast(Tensor, inp) * pert_mask + baseline * (1 - pert_mask)
perturbed_inputs.append(perturbed_inp)

attr_inp_count += 1
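
The rewritten line above is the usual mask-and-baseline blend: a feature keeps its original value where the perturbation mask is 1 and takes the baseline value where it is 0; the `cast` only narrows the static type and does nothing at runtime. A standalone sketch with made-up values:

```python
import torch

inp = torch.tensor([[1.0, 2.0, 3.0]])        # original features
baseline = torch.zeros_like(inp)             # reference values
pert_mask = torch.tensor([[1.0, 0.0, 1.0]])  # 1 = keep, 0 = replace

# Same expression as in _perturb_inputs above.
perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
print(perturbed_inp)  # tensor([[1., 0., 3.]])
```
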
2 changes: 1 addition & 1 deletion captum/attr/_core/layer/layer_lrp.py
@@ -307,7 +307,7 @@ def _get_output_relevance(

@staticmethod
def _convert_list_to_tuple(
- relevances: Union[List[T], Tuple[T, ...]]
+ relevances: Union[List[T], Tuple[T, ...]],
) -> Tuple[T, ...]:
if isinstance(relevances, list):
return tuple(relevances)
11 changes: 9 additions & 2 deletions captum/attr/_core/llm_attr.py
@@ -559,12 +559,19 @@ def _forward_func(
outputs.past_key_values = DynamicCache.from_legacy_cache(
outputs.past_key_values
)
+ # nn.Module typing suggests non-base attributes are modules or
+ # tensors
+ _update_model_kwargs_for_generation = (
+ self.model._update_model_kwargs_for_generation
+ )
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
- model_kwargs = self.model._update_model_kwargs_for_generation(
+ model_kwargs = _update_model_kwargs_for_generation( # type: ignore
outputs, model_kwargs
)
+ # nn.Module typing suggests non-base attributes are modules or tensors
+ prep_inputs_for_generation = self.model.prepare_inputs_for_generation
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
- model_inputs = self.model.prepare_inputs_for_generation(
+ model_inputs = prep_inputs_for_generation( # type: ignore
model_inp, **model_kwargs
)
outputs = self.model.forward(**model_inputs)
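
The refactor above exists because `nn.Module.__getattr__` is annotated as returning `Union[Module, Tensor]`, so calling a dynamically resolved attribute directly makes the type checker flag the call. Binding the attribute to a local name first pins the suppression to a single labeled line. A minimal sketch of the pattern (the wrapper class is illustrative; `prepare_inputs_for_generation` is assumed to exist on the wrapped model, as it does on Hugging Face generation models):

```python
import torch.nn as nn


class ModelWrapper:
    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def prepare(self, **kwargs):
        # nn.Module typing suggests non-base attributes are modules or
        # tensors, so the checker cannot tell this attribute is callable.
        prepare_fn = self.model.prepare_inputs_for_generation
        return prepare_fn(**kwargs)  # type: ignore
```
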
2 changes: 1 addition & 1 deletion captum/attr/_utils/common.py
@@ -373,7 +373,7 @@ def _find_output_mode_and_verify(


def _construct_default_feature_mask(
- inputs: Tuple[Tensor, ...]
+ inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask = []
current_num_features = 0
13 changes: 6 additions & 7 deletions captum/attr/_utils/stat.py
@@ -1,7 +1,8 @@
#!/usr/bin/env python3

# pyre-strict
- from typing import Any, Callable, List, Optional, TYPE_CHECKING
+
+ from typing import Any, Callable, cast, List, Optional, TYPE_CHECKING

import torch
from torch import Tensor
@@ -117,20 +118,18 @@ def get(self) -> Optional[Tensor]:
return self.rolling_mean

def init(self) -> None:
- # pyre-fixme[8]: Attribute has type `Optional[Count]`; used as `Optional[Stat]`.
- self.n = self._get_stat(Count()) # type: ignore
+ self.n = cast(Count, self._get_stat(Count()))

def update(self, x: Tensor) -> None:
- # pyre-fixme[16]: `Optional` has no attribute `get`.
- n = self.n.get() # type: ignore
+ n = cast(Count, self.n).get()

if self.rolling_mean is None:
# Ensures rolling_mean is a float tensor
self.rolling_mean = x.clone() if x.is_floating_point() else x.double()
else:
delta = x - self.rolling_mean
- # pyre-fixme[16]: `Optional` has no attribute `__iadd__`.
- self.rolling_mean += delta / n
+ # pyre-ignore[16]: `Optional` has no attribute `__iadd__` (false positive)
+ self.rolling_mean += delta / cast(int, n)


class MSE(Stat):
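
For readers unfamiliar with the replacement pattern: `typing.cast` is a no-op at runtime that asserts a narrower type to the checker, which is why it can stand in for the blanket `# pyre-fixme`/`# type: ignore` suppressions removed above. A self-contained sketch with illustrative names, not Captum's:

```python
from typing import Optional, cast


def first_positive(values: list) -> Optional[int]:
    for v in values:
        if v > 0:
            return v
    return None


# The caller knows a positive value exists, but the declared type is still
# Optional[int]; cast() narrows it without any runtime check or cost.
n = cast(int, first_positive([3, -1, 2]))
print(n + 1)  # 4 -- the checker now accepts the arithmetic
```
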
4 changes: 2 additions & 2 deletions captum/influence/_utils/common.py
@@ -338,7 +338,7 @@ def __len__(self) -> int:

def _format_inputs_dataset(
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
- inputs_dataset: Union[Tuple[Any, ...], DataLoader]
+ inputs_dataset: Union[Tuple[Any, ...], DataLoader],
) -> DataLoader:
# if `inputs_dataset` is not a `DataLoader`, turn it into one.
# `_DatasetFromList` turns a list into a `Dataset` where `__getitem__`
@@ -604,7 +604,7 @@ def _flatten_params(_params: Tuple[Tensor, ...]) -> Tensor:

# pyre-fixme[3]: Return type must be annotated.
def _unflatten_params_factory(
- param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]]
+ param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]],
):
"""
returns a function which is the inverse of `_flatten_params`
2 changes: 1 addition & 1 deletion captum/insights/attr_vis/_utils/transforms.py
@@ -7,7 +7,7 @@

def format_transforms(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
- transforms: Optional[Union[Callable, List[Callable]]]
+ transforms: Optional[Union[Callable, List[Callable]]],
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
) -> List[Callable]:
if transforms is None:
2 changes: 1 addition & 1 deletion captum/insights/attr_vis/frontend/package.json
@@ -1,6 +1,6 @@
{
"name": "frontend",
"version": "0.7.0",
"version": "0.8.0",
"private": true,
"homepage": ".",
"dependencies": {
9 changes: 6 additions & 3 deletions captum/metrics/_core/infidelity.py
@@ -59,7 +59,7 @@ def infidelity_perturb_func_decorator(
"""

def sub_infidelity_perturb_func_decorator(
- perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric]
+ perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric],
) -> Callable[
[TensorOrTupleOfTensorsGeneric, BaselineType],
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
@@ -611,6 +611,11 @@ def _next_infidelity_tensors(
targets_expanded,
additional_forward_args_expanded,
)
+ if isinstance(inputs_perturbed_fwd, torch.futures.Future):
+ raise NotImplementedError(
+ f"Outputs from forward_func of type {type(inputs_perturbed_fwd)} are "
+ "not yet supported."
+ )
inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args)
# _run_forward may return future of Tensor,
# but we don't support it here now
@@ -619,8 +624,6 @@
inputs_fwd = torch.repeat_interleave(
inputs_fwd, current_n_perturb_samples, dim=0
)
- # pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
- # `Union[Future[Tensor], Tensor]`.
perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
attributions_expanded = tuple(
torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)
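
The added `isinstance` guard reflects that `_run_forward` may return a `torch.futures.Future` rather than a `Tensor` (e.g., for asynchronous execution), while the subtraction a few lines later needs concrete tensors on both sides. A minimal sketch of the same defensive pattern (the helper name here is hypothetical):

```python
import torch
from torch import Tensor


def materialize(out) -> Tensor:
    # Reject asynchronous results up front so later tensor arithmetic
    # cannot fail with a confusing operand-type error.
    if isinstance(out, torch.futures.Future):
        raise NotImplementedError(
            f"Outputs from forward_func of type {type(out)} are not yet supported."
        )
    return out


fwd = materialize(torch.ones(2))
print(fwd - torch.zeros(2))  # tensor([1., 1.])
```
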
6 changes: 3 additions & 3 deletions captum/testing/helpers/influence/common.py
@@ -33,7 +33,7 @@ def _isSorted(x, key=lambda x: x, descending=True) -> bool:

# pyre-fixme[2]: Parameter must be annotated.
def _wrap_model_in_dataparallel(net) -> Module:
- alt_device_ids = [0] + [x for x in range(torch.cuda.device_count() - 1, 0, -1)]
+ alt_device_ids = [0] + list(range(torch.cuda.device_count() - 1, 0, -1))
net = net.cuda()
return torch.nn.DataParallel(net, device_ids=alt_device_ids)

Expand Down Expand Up @@ -505,7 +505,7 @@ def get_random_model_and_data(

# pyre-fixme[3]: Return type must be annotated.
def generate_symmetric_matrix_given_eigenvalues(
- eigenvalues: Union[Tensor, List[float]]
+ eigenvalues: Union[Tensor, List[float]],
):
"""
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L123 # noqa: E501
@@ -523,7 +523,7 @@ def generate_symmetric_matrix_given_eigenvalues(


def generate_assymetric_matrix_given_eigenvalues(
- eigenvalues: Union[Tensor, List[float]]
+ eigenvalues: Union[Tensor, List[float]],
) -> Tensor:
"""
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -2,4 +2,4 @@
first_party_detection = false

[tool.black]
- target-version = ['py36']
+ target-version = ['py39']
9 changes: 8 additions & 1 deletion setup.py
@@ -131,10 +131,14 @@ def get_package_files(root, subdirs):
"conda": "https://anaconda.org/pytorch/captum",
},
keywords=[
"Model Interpretability",
"Model Understanding",
"Model Interpretability",
"Model Understanding",
"Feature Importance",
"Neuron Importance",
"Data Attribution",
"Explainable AI",
"PyTorch",
],
classifiers=[
@@ -148,7 +152,10 @@ def get_package_files(root, subdirs):
],
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=">=3.9",
python_requires=">={required_major}.{required_minor}".format(
required_minor=REQUIRED_MINOR,
required_major=REQUIRED_MAJOR,
),
install_requires=[
"matplotlib",
"numpy<2.0",
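
The `python_requires` edit swaps a hard-coded string for one built from module-level constants, keeping the minimum Python version declared in one place. A sketch under the assumption that `REQUIRED_MAJOR` and `REQUIRED_MINOR` are defined near the top of `setup.py` with the values the removed literal encoded:

```python
# Assumed constants; the removed literal ">=3.9" implies these values.
REQUIRED_MAJOR = 3
REQUIRED_MINOR = 9

python_requires = ">={required_major}.{required_minor}".format(
    required_major=REQUIRED_MAJOR,
    required_minor=REQUIRED_MINOR,
)
assert python_requires == ">=3.9"
```
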
5 changes: 3 additions & 2 deletions tests/attr/layer/test_layer_integrated_gradients.py
@@ -137,7 +137,8 @@ def test_multiple_layers_multiple_inputs_shared_input(self) -> None:
self,
# last input for second layer is first input =>
# add the attributions
- (attribs_inputs[0] + attribs_inputs[1][-1],) + attribs_inputs[1][0:-1],
+ (attribs_inputs[0] + attribs_inputs[1][-1],) # type: ignore
+ + attribs_inputs[1][0:-1], # type: ignore
attribs_inputs_regular_ig,
delta=1e-5,
)
@@ -183,7 +184,7 @@ def test_multiple_layers_multiple_input_outputs(self) -> None:

assertTensorTuplesAlmostEqual(
self,
- (attribs_inputs[0],) + attribs_inputs[1],
+ (attribs_inputs[0],) + attribs_inputs[1], # type: ignore
attribs_inputs_regular_ig,
delta=1e-7,
)
2 changes: 1 addition & 1 deletion tests/concept/test_tcav.py
@@ -792,7 +792,7 @@ def _compute_cavs_interpret(
attribute_to_layer_input: bool = False,
) -> None:
def wrap_in_list_if_not_already(
- input: Union[str, float, List[float], List[str]]
+ input: Union[str, float, List[float], List[str]],
) -> Union[List[Union[float, str]], List[float], List[str]]:
return (
input
4 changes: 2 additions & 2 deletions tests/metrics/test_infidelity.py
@@ -41,7 +41,7 @@ def _local_perturb_func_default(
# pyre-ignore[43]: The implementation of `_local_perturb_func` does not accept all
# possible arguments of overload defined on line `43`.
def _local_perturb_func(
- inputs: Tuple[Tensor, ...]
+ inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...


@@ -83,7 +83,7 @@ def _global_perturb_func1_default(
# pyre-fixme[43]: The implementation of `_global_perturb_func1` does not accept all
# possible arguments of overload defined on line `74`.
def _global_perturb_func1(
- inputs: Tuple[Tensor, ...]
+ inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...

