-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add KeypointEndPointError (#92)
* [Feat] Add KeypointEndPointError * Update example & docs * Update examples & tests according to the comments * Update support_matrix doc * Restore support_matrix doc * Apply suggestions from code review * Update mmeval/metrics/keypoint_epe.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
- Loading branch information
Showing
5 changed files
with
185 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,3 +48,4 @@ Metrics | |
ConnectivityError | ||
DOTAMeanAP | ||
ROUGE | ||
KeypointEndPointError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,3 +48,4 @@ Metrics | |
ConnectivityError | ||
DOTAMeanAP | ||
ROUGE | ||
KeypointEndPointError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import logging | ||
import numpy as np | ||
from typing import Dict, Sequence | ||
|
||
from mmeval.core.base_metric import BaseMetric | ||
from .utils import calc_distances | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def keypoint_epe_accuracy(pred: np.ndarray, gt: np.ndarray, | ||
mask: np.ndarray) -> float: | ||
"""Calculate the end-point error. | ||
Note: | ||
- instance number: N | ||
- keypoint number: K | ||
Args: | ||
pred (np.ndarray[N, K, 2]): Predicted keypoint location. | ||
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. | ||
mask (np.ndarray[N, K]): Visibility of the target. False for invisible | ||
joints, and True for visible. Invisible joints will be ignored for | ||
accuracy calculation. | ||
Returns: | ||
float: Average end-point error. | ||
""" | ||
|
||
distances = calc_distances( | ||
pred, gt, mask, | ||
np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32)) | ||
distance_valid = distances[distances != -1] | ||
return distance_valid.sum() / max(1, len(distance_valid)) | ||
|
||
|
||
class KeypointEndPointError(BaseMetric): | ||
"""EPE evaluation metric. | ||
Calculate the end-point error (EPE) of keypoints. | ||
Note: | ||
- length of dataset: N | ||
- num_keypoints: K | ||
- number of keypoint dimensions: D (typically D = 2) | ||
Examples: | ||
>>> from mmeval.metrics import KeypointEndPointError | ||
>>> import numpy as np | ||
>>> output = np.array([[[10., 4.], | ||
... [10., 18.], | ||
... [ 0., 0.], | ||
... [40., 40.], | ||
... [20., 10.]]]) | ||
>>> target = np.array([[[10., 0.], | ||
... [10., 10.], | ||
... [ 0., -1.], | ||
... [30., 30.], | ||
... [ 0., 10.]]]) | ||
>>> keypoints_visible = np.array([[True, True, False, True, True]]) | ||
>>> predictions = [{'coords': output}] | ||
>>> groundtruths = [{'coords': target, 'mask': keypoints_visible}] | ||
>>> epe_metric = KeypointEndPointError() | ||
>>> epe_metric(predictions, groundtruths) | ||
{'EPE': 11.535533905029297} | ||
""" | ||
|
||
def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None: # type: ignore # yapf: disable # noqa: E501 | ||
"""Process one batch of predictions and groundtruths and add the | ||
intermediate results to `self._results`. | ||
Args: | ||
predictions (Sequence[dict]): Predictions from the model. | ||
Each prediction dict has the following keys: | ||
- coords (np.ndarray, [1, K, D]): predicted keypoints | ||
coordinates | ||
groundtruths (Sequence[dict]): The ground truth labels. | ||
Each groundtruth dict has the following keys: | ||
- coords (np.ndarray, [1, K, D]): ground truth keypoints | ||
coordinates | ||
- mask (np.ndarray, [1, K]): ground truth keypoints_visible | ||
""" | ||
for prediction, groundtruth in zip(predictions, groundtruths): | ||
self._results.append((prediction, groundtruth)) | ||
|
||
def compute_metric(self, results: list) -> Dict[str, float]: | ||
"""Compute the metrics from processed results. | ||
Args: | ||
results (list): The processed results of each batch. | ||
Returns: | ||
Dict[str, float]: The computed metrics. The keys are the names of | ||
the metrics, and the values are corresponding results. | ||
""" | ||
# split gt and prediction list | ||
preds, gts = zip(*results) | ||
|
||
# pred_coords: [N, K, D] | ||
pred_coords = np.concatenate([pred['coords'] for pred in preds]) | ||
# gt_coords: [N, K, D] | ||
gt_coords = np.concatenate([gt['coords'] for gt in gts]) | ||
# mask: [N, K] | ||
mask = np.concatenate([gt['mask'] for gt in gts]) | ||
|
||
logger.info(f'Evaluating {self.__class__.__name__}...') | ||
|
||
epe = keypoint_epe_accuracy(pred_coords, gt_coords, mask) | ||
|
||
return {'EPE': epe} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
from unittest import TestCase | ||
|
||
from mmeval.metrics import KeypointEndPointError | ||
|
||
|
||
class TestKeypointEndPointError(TestCase): | ||
|
||
def setUp(self): | ||
"""Setup some variables which are used in every test method. | ||
TestCase calls functions in this order: setUp() -> testMethod() -> | ||
tearDown() -> cleanUp() | ||
""" | ||
self.output = np.zeros((1, 5, 2)) | ||
self.target = np.zeros((1, 5, 2)) | ||
# first channel | ||
self.output[0, 0] = [10, 4] | ||
self.target[0, 0] = [10, 0] | ||
# second channel | ||
self.output[0, 1] = [10, 18] | ||
self.target[0, 1] = [10, 10] | ||
# third channel | ||
self.output[0, 2] = [0, 0] | ||
self.target[0, 2] = [0, -1] | ||
# fourth channel | ||
self.output[0, 3] = [40, 40] | ||
self.target[0, 3] = [30, 30] | ||
# fifth channel | ||
self.output[0, 4] = [20, 10] | ||
self.target[0, 4] = [0, 10] | ||
|
||
self.keypoints_visible = np.array([[True, True, False, True, True]]) | ||
|
||
def test_epe_evaluate(self): | ||
"""test EPE evaluation metric.""" | ||
# case 1: test normal use case | ||
epe_metric = KeypointEndPointError() | ||
|
||
prediction = {'coords': self.output} | ||
groundtruth = {'coords': self.target, 'mask': self.keypoints_visible} | ||
predictions = [prediction] | ||
groundtruths = [groundtruth] | ||
|
||
epe_results = epe_metric(predictions, groundtruths) | ||
self.assertAlmostEqual(epe_results['EPE'], 11.5355339) | ||
|
||
# case 2: use ``add`` multiple times then ``compute`` | ||
epe_metric._results = [] | ||
preds1 = [{'coords': self.output[:3]}] | ||
preds2 = [{'coords': self.output[3:]}] | ||
gts1 = [{ | ||
'coords': self.target[:3], | ||
'mask': self.keypoints_visible[:3] | ||
}] | ||
gts2 = [{ | ||
'coords': self.target[3:], | ||
'mask': self.keypoints_visible[3:] | ||
}] | ||
|
||
epe_metric.add(preds1, gts1) | ||
epe_metric.add(preds2, gts2) | ||
|
||
epe_results = epe_metric.compute_metric(epe_metric._results) | ||
self.assertAlmostEqual(epe_results['EPE'], 11.5355339) |