diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index ef125031481..0e4979f82a1 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -335,6 +335,7 @@ Complete list of metrics MeanPairwiseDistance MeanSquaredError metric.Metric + metric_group.MetricGroup metrics_lambda.MetricsLambda MultiLabelConfusionMatrix MutualInformation diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index e4f4e24337c..142a13e5934 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -22,6 +22,7 @@ from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage +from ignite.metrics.metric_group import MetricGroup from ignite.metrics.metrics_lambda import MetricsLambda from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix from ignite.metrics.mutual_information import MutualInformation @@ -41,6 +42,7 @@ "Metric", "Accuracy", "Loss", + "MetricGroup", "MetricsLambda", "MeanAbsoluteError", "MeanPairwiseDistance", diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index fdbba17ad5c..8b925392bdc 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -1,12 +1,41 @@ -from typing import Any, Dict +from typing import Any, Callable, Dict from ignite.metrics import Metric class MetricGroup(Metric): - def __init__(self, metrics: Dict[str, Metric]): + """ + A class for grouping metrics so that user could manage them easier. + + Args: + metrics: a dictionary of names to metric instances. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. `output_transform` of each metric in the group is also + called upon its update. + + Examples: + We construct a group of metrics, attach them to the engine at once and retrieve their result. + + .. code-block:: python + metric_group = {'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())} + metric_group.attach(default_evaluator, "eval_metrics") + y_true = torch.tensor([1, 0, 1, 1, 0, 1]) + y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) + state = default_evaluator.run([[y_pred, y_true]]) + + # Metrics individually available in `state.metrics` + state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] + + # And also altogether + state.metrics["eval_metrics] + """ + + _state_dict_all_req_keys = ("metrics",) + + def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): self.metrics = metrics - super(MetricGroup, self).__init__() + super(MetricGroup, self).__init__(output_transform=output_transform) def reset(self): for m in self.metrics.values(): diff --git a/tests/ignite/metrics/test_metric_group.py b/tests/ignite/metrics/test_metric_group.py new file mode 100644 index 00000000000..237df966e05 --- /dev/null +++ b/tests/ignite/metrics/test_metric_group.py @@ -0,0 +1,118 @@ +import pytest +import torch + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.metrics import Accuracy, MetricGroup, Precision + +torch.manual_seed(41) + + +def test_update(): + precision = Precision() + accuracy = Accuracy() + + group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()}) + + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update((y_pred, y)) + accuracy.update((y_pred, y)) + group.update((y_pred, y)) + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +def test_output_transform(): + def drop_first(output): + y_pred, y = output + return (y_pred[1:], y[1:]) + + precision = Precision(output_transform=drop_first) + accuracy = Accuracy(output_transform=drop_first) + + group = MetricGroup( + {"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)} + ) + + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update(drop_first(drop_first((y_pred, y)))) + accuracy.update(drop_first(drop_first((y_pred, y)))) + group.update(drop_first((y_pred, y))) + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +def test_compute(): + precision = Precision() + accuracy = Accuracy() + + group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()}) + + for _ in range(3): + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update((y_pred, y)) + accuracy.update((y_pred, y)) + group.update((y_pred, y)) + + assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()} + + precision.reset() + accuracy.reset() + group.reset() + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + n_epochs = 3 + n_iters = 5 + batch_size = 10 + device = idist.device() + + y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device) + y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device) + + def update(_, i): + return ( + y_pred[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + + engine = Engine(update) + + precision = Precision() + precision.attach(engine, "precision") + + accuracy = Accuracy() + accuracy.attach(engine, "accuracy") + + group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()}) + group.attach(engine, "eval_metrics") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + assert "eval_metrics" in engine.state.metrics + assert "eval_metrics.accuracy" in engine.state.metrics + assert "eval_metrics.precision" in engine.state.metrics + + assert engine.state.metrics["eval_metrics"] == { + "eval_metrics.accuracy": engine.state.metrics["accuracy"], + "eval_metrics.precision": engine.state.metrics["precision"], + } + assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"] + assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]