Add MutualInformation Metric #3230

Merged: 24 commits, Apr 8, 2024 (diff shows changes from 21 commits)
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -337,6 +337,7 @@ Complete list of metrics
     metric.Metric
     metrics_lambda.MetricsLambda
     MultiLabelConfusionMatrix
+    MutualInformation
     precision.Precision
     PSNR
     recall.Recall
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -21,6 +21,7 @@
 from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
 from ignite.metrics.metrics_lambda import MetricsLambda
 from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
+from ignite.metrics.mutual_information import MutualInformation
 from ignite.metrics.nlp.bleu import Bleu
 from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
 from ignite.metrics.precision import Precision
@@ -57,6 +58,7 @@
     "mIoU",
     "JaccardIndex",
     "MultiLabelConfusionMatrix",
+    "MutualInformation",
     "Precision",
     "PSNR",
     "Recall",
8 changes: 6 additions & 2 deletions ignite/metrics/entropy.py
@@ -67,7 +67,6 @@ def reset(self) -> None:
         self._sum_of_entropies = torch.tensor(0.0, device=self._device)
         self._num_examples = 0
 
-    @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         y_pred = output[0].detach()
         if y_pred.ndim >= 3:
@@ -80,9 +79,14 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
 
         prob = F.softmax(y_pred, dim=1)
         log_prob = F.log_softmax(y_pred, dim=1)
+
+        self._update(prob, log_prob)
+
+    @reinit__is_reduced
+    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
         entropy_sum = -torch.sum(prob * log_prob)
         self._sum_of_entropies += entropy_sum.to(self._device)
-        self._num_examples += y_pred.shape[0]
+        self._num_examples += prob.shape[0]
 
     @sync_all_reduce("_sum_of_entropies", "_num_examples")
     def compute(self) -> float:
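The entropy.py change above splits the metric into a public ``update`` (shape validation and softmax) and a ``_update(prob, log_prob)`` hook that accumulates the running statistics, which is what lets ``MutualInformation`` below subclass ``Entropy`` and override only the accumulation step. A minimal sketch of that pattern, written outside ignite with the distributed-reduction decorators omitted (the class names here are illustrative, not ignite's):

import torch
import torch.nn.functional as F


class EntropySketch:
    """Simplified stand-in for ignite's Entropy: update() handles shapes and
    softmax, _update() accumulates the statistic and is the subclass hook."""

    def __init__(self) -> None:
        self.sum_of_entropies = torch.tensor(0.0)
        self.num_examples = 0

    def update(self, y_pred: torch.Tensor) -> None:
        prob = F.softmax(y_pred, dim=1)
        log_prob = F.log_softmax(y_pred, dim=1)
        self._update(prob, log_prob)

    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
        self.sum_of_entropies += -torch.sum(prob * log_prob)
        self.num_examples += prob.shape[0]


class MutualInformationSketch(EntropySketch):
    """Overrides only the hook: additionally accumulates the per-class sum of
    probabilities needed later for the marginal entropy H(mean(p))."""

    def __init__(self) -> None:
        super().__init__()
        self.sum_of_probabilities = torch.tensor(0.0)

    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
        super()._update(prob, log_prob)
        self.sum_of_probabilities = self.sum_of_probabilities + prob.sum(dim=0)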
94 changes: 94 additions & 0 deletions ignite/metrics/mutual_information.py
@@ -0,0 +1,94 @@
import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = ["MutualInformation"]


class MutualInformation(Entropy):
r"""Calculates the `mutual information <https://en.wikipedia.org/wiki/Mutual_information>`_
between input :math:`X` and prediction :math:`Y`.

.. math::
\begin{align*}
I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right)
- \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\
H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c.
\end{align*}

where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for the :math:`i`-th input,
and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`.

Intuitively, this metric measures how well input data are clustered by classes in the feature space [1].

[1] https://proceedings.mlr.press/v70/hu17b.html

- ``update`` must receive output of the form ``(y_pred, y)``, though ``y`` is not used in this metric.
- ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = MutualInformation()
metric.attach(default_evaluator, 'mutual_information')
y_true = torch.tensor([0, 1, 2]) # not considered in the MutualInformation metric.
y_pred = torch.tensor([
[ 0.0000, 0.6931, 1.0986],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, -2.3026, -2.3026]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['mutual_information'])

.. testoutput::

0.18599730730056763
"""

_state_dict_all_req_keys = ("_sum_of_probabilities",)

@reinit__is_reduced
def reset(self) -> None:
super().reset()
self._sum_of_probabilities = torch.tensor(0.0, device=self._device)

@reinit__is_reduced
def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
super()._update(prob, log_prob)
self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device)

@sync_all_reduce("_sum_of_probabilities")
def compute(self) -> float:
n = self._num_examples
if n == 0:
raise NotComputableError("MutualInformation must have at least one example before it can be computed.")

marginal_prob = self._sum_of_probabilities / n
marginal_ent = -(marginal_prob * torch.log(marginal_prob)).sum()
conditional_ent = self._sum_of_entropies / n
mi = marginal_ent - conditional_ent
mi = torch.clamp(mi, min=0.0) # mutual information cannot be negative
return float(mi.item())
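As a quick engine-free check of the new metric, the snippet below reproduces the docstring example by calling ``reset``/``update``/``compute`` directly (the second tuple element is unused, as noted above) and then recomputes I(X;Y) = H(mean(p)) - mean(H(p)) by hand. This is a sketch for illustration only, not part of the PR:

import torch
import torch.nn.functional as F

from ignite.metrics import MutualInformation

y_pred = torch.tensor([
    [0.0000, 0.6931, 1.0986],
    [1.3863, 1.6094, 1.6094],
    [0.0000, -2.3026, -2.3026],
])

# Direct use without an Engine: reset, update, compute.
metric = MutualInformation()
metric.reset()
metric.update((y_pred, None))  # the second element (y) is ignored by this metric
print(metric.compute())  # ~0.186

# Manual recomputation of I(X;Y) = H(mean(p)) - mean(H(p)).
prob = F.softmax(y_pred, dim=1)
marginal = prob.mean(dim=0)
marginal_entropy = -(marginal * marginal.log()).sum()
conditional_entropy = -(prob * prob.log()).sum(dim=1).mean()
print((marginal_entropy - conditional_entropy).item())  # ~0.186, matches the docstring output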
145 changes: 145 additions & 0 deletions tests/ignite/metrics/test_mutual_information.py
@@ -0,0 +1,145 @@
from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.special import softmax
from scipy.stats import entropy
from torch import Tensor

import ignite.distributed as idist

from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import MutualInformation


def np_mutual_information(np_y_pred: np.ndarray) -> float:
prob = softmax(np_y_pred, axis=1)
marginal_ent = entropy(np.mean(prob, axis=0))
conditional_ent = np.mean(entropy(prob, axis=1))
return max(0.0, marginal_ent - conditional_ent)


def test_zero_sample():
mi = MutualInformation()
with pytest.raises(
NotComputableError, match=r"MutualInformation must have at least one example before it can be computed"
):
mi.compute()


def test_invalid_shape():
mi = MutualInformation()
y_pred = torch.randn(10).float()
with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
mi.update((y_pred, None))


@pytest.fixture(params=list(range(6)))
def test_case(request):
return [
(torch.randn((100, 10)).float(), torch.randint(0, 10, size=[100]), 1),
(torch.rand((100, 500)).float(), torch.randint(0, 500, size=[100]), 1),
# updated batches
(torch.normal(0.0, 5.0, size=(100, 10)).float(), torch.randint(0, 10, size=[100]), 16),
(torch.normal(5.0, 3.0, size=(100, 200)).float(), torch.randint(0, 200, size=[100]), 16),
# image segmentation
(torch.randn((100, 5, 32, 32)).float(), torch.randint(0, 5, size=(100, 32, 32)), 16),
(torch.randn((100, 5, 224, 224)).float(), torch.randint(0, 5, size=(100, 224, 224)), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
mi = MutualInformation()

y_pred, y, batch_size = test_case

mi.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
mi.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
mi.update((y_pred, y))

np_res = np_mutual_information(y_pred.numpy())
res = mi.compute()

assert isinstance(res, float)
assert pytest.approx(np_res, rel=1e-4) == res


def test_accumulator_detached():
mi = MutualInformation()

y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
y = torch.zeros(2)
mi.update((y_pred, y))

assert not mi._sum_of_probabilities.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_integration(self):
tol = 1e-4
n_iters = 100
batch_size = 10
n_cls = 50
device = idist.device()
rank = idist.get_rank()
torch.manual_seed(12 + rank)

metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device)
y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device)

engine = Engine(
lambda e, i: (
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)
)

m = MutualInformation(device=metric_device)
m.attach(engine, "mutual_information")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "mutual_information" in engine.state.metrics
res = engine.state.metrics["mutual_information"]

true_res = np_mutual_information(y_preds.cpu().numpy())

assert pytest.approx(true_res, rel=tol) == res

def test_accumulator_device(self):
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)
for metric_device in metric_devices:
mi = MutualInformation(device=metric_device)

devices = (mi._device, mi._sum_of_probabilities.device)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
y = torch.zeros(2)
mi.update((y_pred, y))

devices = (mi._device, mi._sum_of_probabilities.device)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"