diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst index 0e9d75fd..ab7bb3f7 100644 --- a/docs/en/api/metrics.rst +++ b/docs/en/api/metrics.rst @@ -48,3 +48,4 @@ Metrics ConnectivityError DOTAMeanAP ROUGE + KeypointEndPointError diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst index 0e9d75fd..ab7bb3f7 100644 --- a/docs/zh_cn/api/metrics.rst +++ b/docs/zh_cn/api/metrics.rst @@ -48,3 +48,4 @@ Metrics ConnectivityError DOTAMeanAP ROUGE + KeypointEndPointError diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 171e7196..0d845d30 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -12,6 +12,7 @@ from .f1_score import F1Score from .gradient_error import GradientError from .hmean_iou import HmeanIoU +from .keypoint_epe import KeypointEndPointError from .mae import MeanAbsoluteError from .matting_mse import MattingMeanSquaredError from .mean_iou import MeanIoU @@ -36,7 +37,7 @@ 'StructuralSimilarity', 'SignalNoiseRatio', 'MultiLabelMetric', 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError', - 'ConnectivityError', 'ROUGE' + 'ConnectivityError', 'ROUGE', 'KeypointEndPointError' ] _deprecated_msg = ( diff --git a/mmeval/metrics/keypoint_epe.py b/mmeval/metrics/keypoint_epe.py new file mode 100644 index 00000000..f2fe914a --- /dev/null +++ b/mmeval/metrics/keypoint_epe.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import numpy as np +from typing import Dict, Sequence + +from mmeval.core.base_metric import BaseMetric +from .utils import calc_distances + +logger = logging.getLogger(__name__) + + +def keypoint_epe_accuracy(pred: np.ndarray, gt: np.ndarray, + mask: np.ndarray) -> float: + """Calculate the end-point error. + + Note: + - instance number: N + - keypoint number: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + + Returns: + float: Average end-point error. + """ + + distances = calc_distances( + pred, gt, mask, + np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32)) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +class KeypointEndPointError(BaseMetric): + """EPE evaluation metric. + + Calculate the end-point error (EPE) of keypoints. + + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + + Examples: + + >>> from mmeval.metrics import KeypointEndPointError + >>> import numpy as np + >>> output = np.array([[[10., 4.], + ... [10., 18.], + ... [ 0., 0.], + ... [40., 40.], + ... [20., 10.]]]) + >>> target = np.array([[[10., 0.], + ... [10., 10.], + ... [ 0., -1.], + ... [30., 30.], + ... [ 0., 10.]]]) + >>> keypoints_visible = np.array([[True, True, False, True, True]]) + >>> predictions = [{'coords': output}] + >>> groundtruths = [{'coords': target, 'mask': keypoints_visible}] + >>> epe_metric = KeypointEndPointError() + >>> epe_metric(predictions, groundtruths) + {'EPE': 11.535533905029297} + """ + + def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None: # type: ignore # yapf: disable # noqa: E501 + """Process one batch of predictions and groundtruths and add the + intermediate results to `self._results`. + + Args: + predictions (Sequence[dict]): Predictions from the model. + Each prediction dict has the following keys: + + - coords (np.ndarray, [1, K, D]): predicted keypoints + coordinates + + groundtruths (Sequence[dict]): The ground truth labels. + Each groundtruth dict has the following keys: + + - coords (np.ndarray, [1, K, D]): ground truth keypoints + coordinates + - mask (np.ndarray, [1, K]): ground truth keypoints_visible + """ + for prediction, groundtruth in zip(predictions, groundtruths): + self._results.append((prediction, groundtruth)) + + def compute_metric(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + # split gt and prediction list + preds, gts = zip(*results) + + # pred_coords: [N, K, D] + pred_coords = np.concatenate([pred['coords'] for pred in preds]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([gt['coords'] for gt in gts]) + # mask: [N, K] + mask = np.concatenate([gt['mask'] for gt in gts]) + + logger.info(f'Evaluating {self.__class__.__name__}...') + + epe = keypoint_epe_accuracy(pred_coords, gt_coords, mask) + + return {'EPE': epe} diff --git a/tests/test_metrics/test_keypoint_epe.py b/tests/test_metrics/test_keypoint_epe.py new file mode 100644 index 00000000..c25539df --- /dev/null +++ b/tests/test_metrics/test_keypoint_epe.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from unittest import TestCase + +from mmeval.metrics import KeypointEndPointError + + +class TestKeypointEndPointError(TestCase): + + def setUp(self): + """Setup some variables which are used in every test method. + + TestCase calls functions in this order: setUp() -> testMethod() -> + tearDown() -> cleanUp() + """ + self.output = np.zeros((1, 5, 2)) + self.target = np.zeros((1, 5, 2)) + # first channel + self.output[0, 0] = [10, 4] + self.target[0, 0] = [10, 0] + # second channel + self.output[0, 1] = [10, 18] + self.target[0, 1] = [10, 10] + # third channel + self.output[0, 2] = [0, 0] + self.target[0, 2] = [0, -1] + # fourth channel + self.output[0, 3] = [40, 40] + self.target[0, 3] = [30, 30] + # fifth channel + self.output[0, 4] = [20, 10] + self.target[0, 4] = [0, 10] + + self.keypoints_visible = np.array([[True, True, False, True, True]]) + + def test_epe_evaluate(self): + """test EPE evaluation metric.""" + # case 1: test normal use case + epe_metric = KeypointEndPointError() + + prediction = {'coords': self.output} + groundtruth = {'coords': self.target, 'mask': self.keypoints_visible} + predictions = [prediction] + groundtruths = [groundtruth] + + epe_results = epe_metric(predictions, groundtruths) + self.assertAlmostEqual(epe_results['EPE'], 11.5355339) + + # case 2: use ``add`` multiple times then ``compute`` + epe_metric._results = [] + preds1 = [{'coords': self.output[:3]}] + preds2 = [{'coords': self.output[3:]}] + gts1 = [{ + 'coords': self.target[:3], + 'mask': self.keypoints_visible[:3] + }] + gts2 = [{ + 'coords': self.target[3:], + 'mask': self.keypoints_visible[3:] + }] + + epe_metric.add(preds1, gts1) + epe_metric.add(preds2, gts2) + + epe_results = epe_metric.compute_metric(epe_metric._results) + self.assertAlmostEqual(epe_results['EPE'], 11.5355339)