From f57f53117ce91c5e70a9a51869e9a3cca021d560 Mon Sep 17 00:00:00 2001 From: huyingfan Date: Tue, 17 Jan 2023 14:28:08 +0800 Subject: [PATCH 1/7] refactor single and multi label metrics --- mmeval/metrics/__init__.py | 18 +- mmeval/metrics/average_precision.py | 384 ++++++++ mmeval/metrics/multi_label.py | 865 ------------------ ...e_label.py => precision_recall_f1score.py} | 424 ++++++++- mmeval/metrics/utils/__init__.py | 4 +- mmeval/metrics/utils/multi_label.py | 164 ++++ ...t_multi_label_precision_recall_f1score.py} | 31 +- .../test_precision_recall_f1score.py | 38 + ..._single_label_precision_recall_f1score.py} | 41 +- 9 files changed, 1049 insertions(+), 920 deletions(-) create mode 100644 mmeval/metrics/average_precision.py delete mode 100644 mmeval/metrics/multi_label.py rename mmeval/metrics/{single_label.py => precision_recall_f1score.py} (53%) create mode 100644 mmeval/metrics/utils/multi_label.py rename tests/test_metrics/{test_multi_label.py => test_multi_label_precision_recall_f1score.py} (91%) create mode 100644 tests/test_metrics/test_precision_recall_f1score.py rename tests/test_metrics/{test_single_label_metric.py => test_single_label_precision_recall_f1score.py} (85%) diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 9ae21aef..533a54a1 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -4,6 +4,7 @@ from .accuracy import Accuracy from .ava_map import AVAMeanAP +from .average_precision import AveragePrecision from .bleu import BLEU from .coco_detection import COCODetection from .connectivity_error import ConnectivityError @@ -19,16 +20,17 @@ from .matting_mse import MattingMeanSquaredError from .mean_iou import MeanIoU from .mse import MeanSquaredError -from .multi_label import AveragePrecision, MultiLabelMetric from .niqe import NaturalImageQualityEvaluator from .oid_map import OIDMeanAP from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy from .perplexity import Perplexity +from .precision_recall_f1score import (MultiLabelPrecsionRecallF1score, + PrecsionRecallF1score, + SingleLabelPrecsionRecallF1score) from .proposal_recall import ProposalRecall from .psnr import PeakSignalNoiseRatio from .rouge import ROUGE from .sad import SumAbsoluteDifferences -from .single_label import SingleLabelMetric from .snr import SignalNoiseRatio from .ssim import StructuralSimilarity from .voc_map import VOCMeanAP @@ -36,15 +38,15 @@ __all__ = [ 'Accuracy', 'MeanIoU', 'VOCMeanAP', 'OIDMeanAP', 'EndPointError', - 'F1Score', 'HmeanIoU', 'SingleLabelMetric', 'COCODetection', 'PCKAccuracy', - 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall', - 'PeakSignalNoiseRatio', 'MeanAbsoluteError', 'MeanSquaredError', - 'StructuralSimilarity', 'SignalNoiseRatio', 'MultiLabelMetric', - 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', + 'F1Score', 'HmeanIoU', 'COCODetection', 'PCKAccuracy', 'MpiiPCKAccuracy', + 'JhmdbPCKAccuracy', 'ProposalRecall', 'PeakSignalNoiseRatio', + 'MeanAbsoluteError', 'MeanSquaredError', 'StructuralSimilarity', + 'SignalNoiseRatio', 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError', 'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError', 'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator', - 'WordAccuracy' + 'WordAccuracy', 'PrecsionRecallF1score', + 'SingleLabelPrecsionRecallF1score', 'MultiLabelPrecsionRecallF1score' ] _deprecated_msg = ( diff --git a/mmeval/metrics/average_precision.py 
b/mmeval/metrics/average_precision.py new file mode 100644 index 00000000..496d02e9 --- /dev/null +++ b/mmeval/metrics/average_precision.py @@ -0,0 +1,384 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from typing import (TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, + Union, overload) + +from mmeval.core.base_metric import BaseMetric +from mmeval.core.dispatcher import dispatch +from mmeval.metrics.utils import MultiLabelMixin, format_data +from mmeval.utils import try_import + +if TYPE_CHECKING: + import oneflow + import oneflow as flow + import torch +else: + torch = try_import('torch') + flow = try_import('oneflow') + +NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], Union[np.ndarray, + np.number]] +TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] +ONEFLOW_IMPL_HINTS = Tuple['oneflow.Tensor', 'oneflow.Tensor'] +BUILTIN_IMPL_HINTS = Tuple[Union[int, Sequence[Union[int, float]]], + Union[int, Sequence[int]]] + + +def _average_precision_torch(preds: 'torch.Tensor', labels: 'torch.Tensor', + average) -> 'torch.Tensor': + r"""Calculate the average precision for torch. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + preds (torch.Tensor): The model prediction with shape + ``(N, num_classes)``. + labels (torch.Tensor): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + torch.Tensor: average precision result. + """ + # sort examples along classes + sorted_pred_inds = torch.argsort(preds, dim=0, descending=True) + sorted_target = torch.gather(labels, 0, sorted_pred_inds) + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = torch.cumsum(pos_inds, 0) + total_pos = tps[-1].clone() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(preds.device) + + tps[torch.logical_not(pos_inds)] = 0 + precision = tps / pred_pos_nums.unsqueeze(-1).float() # divide along rows + ap = torch.sum(precision, 0) / torch.clamp(total_pos, min=1) + + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 + + +def _average_precision_oneflow(preds: 'oneflow.Tensor', + labels: 'oneflow.Tensor', + average) -> 'oneflow.Tensor': + r"""Calculate the average precision for oneflow. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + preds (oneflow.Tensor): The model prediction with shape + ``(N, num_classes)``. + labels (oneflow.Tensor): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + oneflow.Tensor: average precision result. 
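[reviewer note] As a readability aid for the vectorised AP computation introduced in this file, here is a minimal NumPy sketch run on the data from the `AveragePrecision` docstring example further below; the function name `average_precision_np` is illustrative and not part of the patch, but the steps mirror `_average_precision` (sort each class column, accumulate true positives, average precision over the positive ranks):

import numpy as np

def average_precision_np(preds, labels):
    # sort every class column by descending score and reorder the targets
    order = np.argsort(-preds, axis=0)
    sorted_target = np.take_along_axis(labels, order, axis=0)
    pos = sorted_target == 1
    tps = np.cumsum(pos, 0)                    # cumulative true positives per class
    total_pos = tps[-1].copy()                 # number of positives per class
    ranks = np.arange(1, len(sorted_target) + 1)
    tps[~pos] = 0                              # keep precision only at positive ranks
    precision = tps / ranks[:, None]
    ap = precision.sum(0) / np.clip(total_pos, 1, None)
    return ap * 100

preds = np.array([[0.9, 0.8, 0.3, 0.2],
                  [0.1, 0.2, 0.2, 0.1],
                  [0.7, 0.5, 0.9, 0.3],
                  [0.8, 0.1, 0.1, 0.2]])
labels = np.array([[1, 1, 0, 0],
                   [0, 1, 0, 0],
                   [0, 0, 1, 0],
                   [1, 0, 0, 0]])
ap = average_precision_np(preds, labels)
print(ap.round(2))           # [100.  83.33 100.  0.]  -> class-wise AP
print(round(ap.mean(), 3))   # 70.833, the 'mAP' value in the docstring example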
+ """ + # sort examples along classes + sorted_pred_inds = flow.argsort(preds, dim=0, descending=True) + sorted_target = flow.gather(labels, 0, sorted_pred_inds) + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = flow.cumsum(pos_inds, 0) + total_pos = tps[-1].clone() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = flow.arange(1, len(sorted_target) + 1).to(preds.device) + + tps[flow.logical_not(pos_inds)] = 0 + precision = tps / pred_pos_nums.unsqueeze(-1).float() # divide along rows + ap = flow.sum(precision, 0) / flow.clamp(total_pos, min=1) + + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 + + +def _average_precision(preds: np.ndarray, labels: np.ndarray, + average) -> np.ndarray: + r"""Calculate the average precision for numpy. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + preds (np.ndarray): The model prediction with shape + ``(N, num_classes)``. + labels (np.ndarray): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + np.ndarray: average precision result. + """ + # sort examples along classes + sorted_pred_inds = np.argsort(-preds, axis=0) + sorted_target = np.take_along_axis(labels, sorted_pred_inds, axis=0) + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = np.cumsum(pos_inds, 0) + total_pos = tps[-1].copy() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = np.arange(1, len(sorted_target) + 1) + + tps[np.logical_not(pos_inds)] = 0 + precision = np.divide( + tps, np.expand_dims(pred_pos_nums, -1), dtype=np.float32) + ap = np.divide( + np.sum(precision, 0), np.clip(total_pos, 1, np.inf), dtype=np.float32) + + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 + + +class AveragePrecision(MultiLabelMixin, BaseMetric): + """Calculate the average precision with respect of classes. + + Args: + average (str, optional): The average method. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `None`: Return scores of all categories. + + Defaults to "macro". + + References + ---------- + .. 
[1] `Wikipedia entry for the Average precision + `_ + + Examples: + + >>> from mmeval import AveragePrecision + >>> average_precision = AveragePrecision() + + Use Builtin implementation with label-format labels: + + >>> preds = [[0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2]] + >>> labels = [[0, 1], [1], [2], [0]] + >>> average_precision(preds, labels) + {'mAP': 70.833..} + + Use Builtin implementation with one-hot encoding labels: + + >>> preds = [[0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2]] + >>> labels = [[1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [1, 0, 0, 0]] + >>> average_precision(preds, labels) + {'mAP': 70.833..} + + Use NumPy implementation with label-format labels: + + >>> import numpy as np + >>> preds = np.array([[0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2]]) + >>> labels = [np.array([0, 1]), np.array([1]), np.array([2]), np.array([0])] # noqa + >>> average_precision(preds, labels) + {'mAP': 70.833..} + + Use PyTorch implementation with one-hot encoding labels:: + + >>> import torch + >>> preds = torch.Tensor([[0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2]]) + >>> labels = torch.Tensor([[1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [1, 0, 0, 0]]) + >>> average_precision(preds, labels) + {'mAP': 70.833..} + + Computing with `None` average mode: + + >>> preds = np.array([[0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2]]) + >>> labels = [np.array([0, 1]), np.array([1]), np.array([2]), np.array([0])] # noqa + >>> average_precision = AveragePrecision(average=None) + >>> average_precision(preds, labels) + {'AP_classwise': [100.0, 83.33, 100.00, 0.0]} # rounded results + + Accumulate batch: + + >>> for i in range(10): + ... preds = torch.randint(0, 4, size=(100, 10)) + ... labels = torch.randint(0, 4, size=(100, )) + ... average_precision.add(preds, labels) + >>> average_precision.compute() # doctest: +SKIP + """ + + def __init__(self, average: Optional[str] = 'macro', **kwargs) -> None: + super().__init__(**kwargs) + average_options = ['macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + self.average = average + + def add(self, preds: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 + """Add the intermediate results to `self._results`. + + Args: + preds (Sequence): Predictions from the model. It should + be scores of every class (N, C). + labels (Sequence): The ground truth labels. It should be (N, ). + """ + for pred, target in zip(preds, labels): + self._results.append((pred, target)) + + def _format_metric_results(self, ap): + """Format the given metric results into a dictionary. + + Args: + results (list): Results of precision, recall, f1 and support. + + Returns: + dict: The formatted dictionary. 
+ """ + result_metrics = dict() + + if self.average is None: + _result = ap[0].tolist() + result_metrics['AP_classwise'] = [round(_r, 4) for _r in _result] + else: + result_metrics['mAP'] = round(ap[0].item(), 4) + + return result_metrics + + @overload + @dispatch + def _compute_metric(self, preds: Sequence['torch.Tensor'], + labels: Sequence['torch.Tensor']) -> List[List]: + """A PyTorch implementation that computes the metric.""" + + preds = torch.stack(preds) + num_classes = preds.shape[1] + labels = format_data(labels, num_classes, self._label_is_onehot).long() + + assert preds.shape[0] == labels.shape[0], \ + 'Number of samples does not match between preds' \ + f'({preds.shape[0]}) and labels ({labels.shape[0]}).' + + return _average_precision_torch(preds, labels, self.average) + + @overload # type: ignore + @dispatch + def _compute_metric( # type: ignore + self, preds: Sequence['oneflow.Tensor'], + labels: Sequence['oneflow.Tensor']) -> List[List]: + """A OneFlow implementation that computes the metric.""" + + preds = flow.stack(preds) + num_classes = preds.shape[1] + labels = format_data(labels, num_classes, self._label_is_onehot).long() + + assert preds.shape[0] == labels.shape[0], \ + 'Number of samples does not match between preds' \ + f'({preds.shape[0]}) and labels ({labels.shape[0]}).' + + return _average_precision_oneflow(preds, labels, self.average) + + @overload + @dispatch + def _compute_metric( + self, preds: Sequence[Union[int, Sequence[Union[int, float]]]], + labels: Sequence[Union[int, Sequence[int]]]) -> List[List]: + """A Builtin implementation that computes the metric.""" + + return self._compute_metric([np.array(pred) for pred in preds], + [np.array(target) for target in labels]) + + @dispatch + def _compute_metric( + self, preds: Sequence[Union[np.ndarray, np.number]], + labels: Sequence[Union[np.ndarray, np.number]]) -> List[List]: + """A NumPy implementation that computes the metric.""" + + preds = np.stack(preds) + num_classes = preds.shape[1] + labels = format_data(labels, num_classes, + self._label_is_onehot).astype(np.int64) + + assert preds.shape[0] == labels.shape[0], \ + 'Number of samples does not match between preds' \ + f'({preds.shape[0]}) and labels ({labels.shape[0]}).' + + return _average_precision(preds, labels, self.average) + + def compute_metric( + self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, + ONEFLOW_IMPL_HINTS, BUILTIN_IMPL_HINTS]] + ) -> Dict[str, float]: + """Compute the metric. + + Currently, there are 3 implementations of this method: NumPy and + PyTorch and OneFlow. Which implementation to use is determined by the + type of the calling parameters. e.g. `numpy.ndarray` or + `torch.Tensor`, `oneflow.Tensor`. + + This method would be invoked in `BaseMetric.compute` after distributed + synchronization. + Args: + results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, + ONEFLOW_IMPL_HINTS]]): A list of tuples that consisting the + prediction and label. This list has already been synced across + all ranks. + + Returns: + Dict[str, float]: The computed metric. + """ + preds = [res[0] for res in results] + labels = [res[1] for res in results] + assert self._pred_is_onehot is False, '`self._pred_is_onehot` should' \ + f'be `False` for {self.__class__.__name__}, because scores are' \ + 'necessary for compute the metric.' 
+ metric_results = self._compute_metric(preds, labels) + return self._format_metric_results(metric_results) diff --git a/mmeval/metrics/multi_label.py b/mmeval/metrics/multi_label.py deleted file mode 100644 index 5de5114e..00000000 --- a/mmeval/metrics/multi_label.py +++ /dev/null @@ -1,865 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import warnings -from typing import (TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, - Union, overload) - -from mmeval.core.base_metric import BaseMetric -from mmeval.core.dispatcher import dispatch -from mmeval.utils import try_import -from .single_label import _precision_recall_f1_support - -if TYPE_CHECKING: - import oneflow - import oneflow as flow - import torch -else: - torch = try_import('torch') - flow = try_import('oneflow') - -NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], Union[np.ndarray, - np.number]] -TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] -ONEFLOW_IMPL_HINTS = Tuple['oneflow.Tensor', 'oneflow.Tensor'] -BUILTIN_IMPL_HINTS = Tuple[Union[int, Sequence[Union[int, float]]], - Union[int, Sequence[int]]] - - -def label_to_onehot( - label: Union[np.ndarray, 'torch.Tensor', - 'oneflow.Tensor'], num_classes: int -) -> Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']: - """Convert the label-format input to one-hot encodings. - - Args: - label (torch.Tensor or oneflow.Tensor or np.ndarray): - The label-format input. The format of item must be label-format. - num_classes (int): The number of classes. - - Return: - torch.Tensor or oneflow.Tensor or np.ndarray: - The converted one-hot encodings. - """ - if torch and isinstance(label, torch.Tensor): - label = label.long() - onehot = label.new_zeros((num_classes, )) - elif flow and isinstance(label, flow.Tensor): - label = label.long() - onehot = label.new_zeros((num_classes, )) - else: - label = label.astype(np.int64) - onehot = np.zeros((num_classes, ), dtype=np.int64) - assert label.max().item() < num_classes, \ - 'Max index is out of `num_classes` {num_classes}' - assert label.min().item() >= 0 - onehot[label] = 1 - return onehot - - -def format_data( - data: Union[Sequence[Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']], - np.ndarray, 'torch.Tensor', 'oneflow.Tensor'], - num_classes: int, - is_onehot: bool = False -) -> Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']: - """Format data from different inputs such as prediction scores, label- - format data and one-hot encodings into the same output shape of `(N, - num_classes)`. - - Args: - data (Union[Sequence[np.ndarray, 'torch.Tensor', 'oneflow.Tensor'], - np.ndarray, 'torch.Tensor', 'oneflow.Tensor']): - The input data of prediction or labels. - num_classes (int): The number of classes. - is_onehot (bool): Whether the data is one-hot encodings. - - Return: - torch.Tensor or oneflow.Tensor or np.ndarray: - One-hot encodings or predict scores. - """ - if torch and isinstance(data[0], torch.Tensor): - stack_func = torch.stack - elif flow and isinstance(data[0], flow.Tensor): - stack_func = flow.stack - elif isinstance(data[0], (np.ndarray, np.number)): - stack_func = np.stack - else: - raise NotImplementedError(f'Data type of {type(data[0])}' - 'is not supported.') - - try: - # try stack scores or one-hot indices directly - formated_data = stack_func(data) - # all assertions below is to find labels that are - # raw indices which should be caught in exception - # to convert to one-hot indices. - # - # 1. all raw indices has only 1 dims - assert formated_data.ndim == 2 - # 2. 
all raw indices has the same dims - assert formated_data.shape[1] == num_classes - # 3. all raw indices has the same dims as num_classes - # then max indices should greater than 1 for num_classes > 2 - assert formated_data.max() <= 1 - # 4. corner case, num_classes=2, then one-hot indices - # and raw indices are undistinguishable, for instance: - # [[0, 1], [0, 1]] can be one-hot indices of 2 positives - # or raw indices of 4 positives. - # Extra induction is needed. - if num_classes == 2: - warnings.warn('Ambiguous data detected, reckoned as scores' - ' or label-format data as defaults. Please set ' - 'parms related to `is_onehot` if use one-hot ' - 'encoding data to compute metrics.') - assert is_onehot - # Error corresponds to np, torch, oneflow, stack_func respectively - except (ValueError, RuntimeError, AssertionError): - # convert label-format inputs to one-hot encodings - formated_data = stack_func( - [label_to_onehot(sample, num_classes) for sample in data]) - return formated_data - - -class MultiLabelMixin: - """A Mixin for Multilabel Metrics to clarify whether the input is one-hot - encodings or label-format inputs for corner case with minimal user - awareness.""" - - def __init__(self, *args, **kwargs) -> None: - # pass arguments for multiple inheritances - super().__init__(*args, **kwargs) # type: ignore - self._pred_is_onehot = False - self._label_is_onehot = False - - @property - def pred_is_onehot(self) -> bool: - """Whether prediction is one-hot encodings. - - Only works for corner case when num_classes=2 to distinguish one-hot - encodings or label-format. - """ - return self._pred_is_onehot - - @pred_is_onehot.setter - def pred_is_onehot(self, is_onehot: bool): - """Set a flag of whether prediction is one-hot encodings. - - Only works for corner case when num_classes=2 to distinguish one-hot - encodings or label-format. - """ - self._pred_is_onehot = is_onehot - - @property - def label_is_onehot(self) -> bool: - """Whether label is one-hot encodings. - - Only works for corner case when num_classes=2 to distinguish one-hot - encodings or label-format. - """ - return self._label_is_onehot - - @label_is_onehot.setter - def label_is_onehot(self, is_onehot: bool): - """Set a flag of whether label is one-hot encodings. - - Only works for corner case when num_classes=2 to distinguish one-hot - encodings or label-format. - """ - self._label_is_onehot = is_onehot - - -class MultiLabelMetric(MultiLabelMixin, BaseMetric): - """A collection of metrics for multi-label multi-class classification task - based on confusion matrix. - - It includes precision, recall, f1-score and support. - - Args: - num_classes (int): Number of classes. Needed for different inputs - as extra check. - thr (float, optional): Predictions with scores under the thresholds - are considered as negative. Defaults to None. - topk (int, optional): Predictions with the k-th highest scores are - considered as positive. Defaults to None. - items (Sequence[str]): The detailed metric items to evaluate. Here is - the available options: - - - `"precision"`: The ratio tp / (tp + fp) where tp is the - number of true positives and fp the number of false - positives. - - `"recall"`: The ratio tp / (tp + fn) where tp is the number - of true positives and fn the number of false negatives. - - `"f1-score"`: The f1-score is the harmonic mean of the - precision and recall. - - `"support"`: The total number of positive of each category - in the target. - - Defaults to ('precision', 'recall', 'f1-score'). 
- average (str | None): The average method. It supports three average - modes: - - - `"macro"`: Calculate metrics for each category, and calculate - the mean value over all categories. - - `"micro"`: Calculate metrics globally by counting the total - true positives, false negatives and false positives. - - `None`: Return scores of all categories. - - Defaults to "macro". - - .. note:: - MultiLabelMetric supports different kinds of inputs. Such as: - 1. Each sample has scores for every classes. (Only for predictions) - 2. Each sample has one-hot indices for every classes. - 3. Each sample has label-format indices. - - Examples: - - >>> from mmeval import MultiLabelMetric - >>> multi_lable_metic = MultiLabelMetric(num_classes=4) - - Use Builtin implementation with raw indices: - - >>> preds = [[0], [1], [2], [0, 3]] - >>> labels = [[0], [1, 2], [2], [3]] - >>> multi_lable_metic(preds, labels) - {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} - - Use Builtin implementation with one-hot indices: - - >>> preds = [[1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 1]] - >>> labels = [[1, 0, 0, 0], - [0, 1, 1, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]] - >>> multi_lable_metic(preds, labels) - {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} - - Use Builtin implementation with scores: - - >>> preds = [[0.9, 0.1, 0.2, 0.3], - [0.1, 0.8, 0.1, 0.1], - [0.4, 0.3, 0.7, 0.1], - [0.8, 0.1, 0.1, 0.9]] - >>> labels = [[1, 0, 0, 0], - [0, 1, 1, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]] - >>> multi_lable_metic(preds, labels) - {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} - - Use NumPy implementation with raw indices: - - >>> import numpy as np - >>> preds = [np.array([0]), np.array([1, 2]), np.array([2]), np.array([3])] # noqa - >>> labels = [np.array([0]), np.array([1]), np.array([2]), np.array([0, 3])] # noqa - >>> multi_lable_metic(preds, labels) - {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} - - Use PyTorch implementation: - - >>> import torch - >>> preds = [torch.tensor([0]), torch.tensor([1, 2]), torch.tensor([2]), torch.tensor([3])] # noqa - >>> labels = [torch.tensor([0]), torch.tensor([1]), torch.tensor([2]), torch.tensor([0, 3])] # noqa - >>> multi_lable_metic(preds, labels) - {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} - - Computing with `micro` average mode with `topk=2`: - - >>> preds = np.array([ - [0.7, 0.1, 0.1, 0.1], - [0.1, 0.3, 0.4, 0.2], - [0.3, 0.4, 0.2, 0.1], - [0.0, 0.0, 0.1, 0.9]]) - >>> labels = np.array([0, 1, 2, 3]) - >>> multi_lable_metic = MultiLabelMetric(4, average='micro', topk=2) - >>> multi_lable_metic(preds, labels) - {'precision_top2_micro': 37.5, 'recall_top2_micro': 75.0, 'f1-score_top2_micro': 50.0} # noqa - - Accumulate batch: - - >>> for i in range(10): - ... labels = torch.randint(0, 4, size=(100, )) - ... predicts = torch.randint(0, 4, size=(100, )) - ... 
multi_lable_metic.add(predicts, labels) - >>> multi_lable_metic.compute() # doctest: +SKIP - """ - - def __init__(self, - num_classes: int, - thr: Optional[float] = None, - topk: Optional[int] = None, - items: Sequence[str] = ('precision', 'recall', 'f1-score'), - average: Optional[str] = 'macro', - **kwargs) -> None: - super().__init__(**kwargs) - - if thr is None and topk is None: - thr = 0.5 - warnings.warn('Neither thr nor k is given, set thr as 0.5 by ' - 'default.') - elif thr is not None and topk is not None: - warnings.warn('Both thr and topk are given, ' - 'use threshold in favor of top-k.') - - self.thr = thr - self.topk = topk - - for item in items: - assert item in ['precision', 'recall', 'f1-score', 'support'], \ - f'The metric {item} is not supported by `MultiLabelMetric`,' \ - ' please specify from "precision", "recall", "f1-score" and ' \ - '"support".' - self.items = tuple(items) - - average_options = ['micro', 'macro', None] - assert average in average_options, 'Invalid `average` argument, ' \ - f'please specify from {average_options}.' - self.average = average - self.num_classes = num_classes - - def add(self, predictions: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 - """Add the intermediate results to `self._results`. - - Args: - predictions (Sequence): Predictions from the model. It can be - labels (N, ), or scores of every class (N, C). - labels (Sequence): The ground truth labels. It should be (N, ). - """ - for pred, label in zip(predictions, labels): - self._results.append((pred, label)) - - def _format_metric_results(self, results: List) -> Dict: - """Format the given metric results into a dictionary. - - Args: - results (list): Results of precision, recall, f1 and support. - - Returns: - dict: The formatted dictionary. - """ - metrics = {} - - def pack_results(precision, recall, f1_score, support): - single_metrics = {} - if 'precision' in self.items: - single_metrics['precision'] = precision - if 'recall' in self.items: - single_metrics['recall'] = recall - if 'f1-score' in self.items: - single_metrics['f1-score'] = f1_score - if 'support' in self.items: - single_metrics['support'] = support - return single_metrics - - if self.thr: - suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}' - for k, v in pack_results(*results).items(): - metrics[k + suffix] = v - else: - for k, v in pack_results(*results).items(): - metrics[k + f'_top{self.topk}'] = v - - result_metrics = dict() - for k, v in metrics.items(): - - if self.average is None: - result_metrics[k + '_classwise'] = v.tolist() - elif self.average == 'micro': - result_metrics[k + f'_{self.average}'] = v.item() - else: - result_metrics[k] = v.item() - - return result_metrics - - @overload - @dispatch - def _compute_metric(self, predictions: Sequence['torch.Tensor'], - labels: Sequence['torch.Tensor']) -> List: - """A PyTorch implementation that computes the metric.""" - - preds = format_data(predictions, self.num_classes, - self._pred_is_onehot) - labels = format_data(labels, self.num_classes, - self._label_is_onehot).long() - - # cannot be raised in current implementation because - # `and` method will guarantee the equal length. - # However length check should remain somewhere. - assert preds.shape[0] == labels.shape[0], \ - 'Number of samples does not match between preds' \ - f'({preds.shape[0]}) and labels ({labels.shape[0]}).' - - if self.thr is not None: - # a label is predicted positive if larger than self. 
- # work for index as well - pos_inds = (preds >= self.thr).long() - else: - # top-k labels will be predicted positive for any example - _, topk_indices = preds.topk(self.topk) - pos_inds = torch.zeros_like(preds).scatter_(1, topk_indices, 1) - pos_inds = pos_inds.long() - - return _precision_recall_f1_support( # type: ignore - pos_inds, labels, self.average) - - @overload # type: ignore - @dispatch - def _compute_metric( # type: ignore - self, predictions: Sequence['oneflow.Tensor'], - labels: Sequence['oneflow.Tensor']) -> List: - """A OneFlow implementation that computes the metric.""" - - preds = format_data(predictions, self.num_classes, - self._pred_is_onehot) - labels = format_data(labels, self.num_classes, - self._label_is_onehot).long() - - # cannot be raised in current implementation because - # `and` method will guarantee the equal length. - # However length check should remain somewhere. - assert preds.shape[0] == labels.shape[0], \ - 'Number of samples does not match between preds' \ - f'({preds.shape[0]}) and labels ({labels.shape[0]}).' - - if self.thr is not None: - # a label is predicted positive if larger than self. - # work for index as well - pos_inds = (preds >= self.thr).long() - else: - # top-k labels will be predicted positive for any example - _, topk_indices = preds.topk(self.topk, dim=-1) - pos_inds = flow.zeros_like(preds).scatter_(1, topk_indices, 1) - pos_inds = pos_inds.long() - - return _precision_recall_f1_support( # type: ignore - pos_inds, labels, self.average) - - @overload - @dispatch - def _compute_metric(self, preds: Sequence[Union[int, - Sequence[Union[int, - float]]]], - labels: Sequence[Union[int, Sequence[int]]]) -> List: - """A Builtin implementation that computes the metric.""" - - return self._compute_metric([np.array(pred) for pred in preds], - [np.array(target) for target in labels]) - - @dispatch - def _compute_metric( - self, preds: Sequence[Union[np.ndarray, np.number]], - labels: Sequence[Union[np.ndarray, np.number]]) -> List: - """A NumPy implementation that computes the metric.""" - - preds = format_data(preds, self.num_classes, self._pred_is_onehot) - labels = format_data(labels, self.num_classes, - self._label_is_onehot).astype(np.int64) - - # cannot be raised in current implementation because - # `and` method will guarantee the equal length. - # However length check should remain somewhere. - assert preds.shape[0] == labels.shape[0], \ - 'Number of samples does not match between preds' \ - f'({preds.shape[0]}) and labels ({labels.shape[0]}).' - - if self.thr is not None: - # a label is predicted positive if larger than self. - # work for index as well - pos_inds = (preds >= self.thr).astype(np.int64) - else: - # top-k labels will be predicted positive for any example - topk_indices = np.argpartition( - preds, -self.topk, axis=-1)[:, -self.topk:] # type: ignore - pos_inds = np.zeros(preds.shape, dtype=np.int64) - np.put_along_axis(pos_inds, topk_indices, 1, axis=1) - - return _precision_recall_f1_support( # type: ignore - pos_inds, labels, self.average) - - def compute_metric( - self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, - ONEFLOW_IMPL_HINTS, BUILTIN_IMPL_HINTS]] - ) -> Dict[str, float]: - """Compute the metric. - - Currently, there are 3 implementations of this method: NumPy and - PyTorch and OneFlow. Which implementation to use is determined by the - type of the calling parameters. e.g. `numpy.ndarray` or `torch.Tensor` - or `oneflow.Tensor`. 
- This method would be invoked in `BaseMetric.compute` after distributed - synchronization. - - Args: - results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, - ONEFLOW_IMPL_HINTS]]): A listof tuples that consisting the - prediction and label. This list has already been synced across all - ranks. - Returns: - Dict[str, float]: The computed metric. - """ - preds = [res[0] for res in results] - labels = [res[1] for res in results] - metric_results = self._compute_metric(preds, labels) - return self._format_metric_results(metric_results) - - -def _average_precision_torch(preds: 'torch.Tensor', labels: 'torch.Tensor', - average) -> 'torch.Tensor': - r"""Calculate the average precision for torch. - - AP summarizes a precision-recall curve as the weighted mean of maximum - precisions obtained for any r'>r, where r is the recall: - - .. math:: - \text{AP} = \sum_n (R_n - R_{n-1}) P_n - - Note that no approximation is involved since the curve is piecewise - constant. - - Args: - preds (torch.Tensor): The model prediction with shape - ``(N, num_classes)``. - labels (torch.Tensor): The target of predictions with shape - ``(N, num_classes)``. - - Returns: - torch.Tensor: average precision result. - """ - # sort examples along classes - sorted_pred_inds = torch.argsort(preds, dim=0, descending=True) - sorted_target = torch.gather(labels, 0, sorted_pred_inds) - - # get indexes when gt_true is positive - pos_inds = sorted_target == 1 - - # Calculate cumulative tp case numbers - tps = torch.cumsum(pos_inds, 0) - total_pos = tps[-1].clone() # the last of tensor may change later - - # Calculate cumulative tp&fp(pred_poss) case numbers - pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(preds.device) - - tps[torch.logical_not(pos_inds)] = 0 - precision = tps / pred_pos_nums.unsqueeze(-1).float() # divide along rows - ap = torch.sum(precision, 0) / torch.clamp(total_pos, min=1) - - if average == 'macro': - return ap.mean() * 100.0 - else: - return ap * 100 - - -def _average_precision_oneflow(preds: 'oneflow.Tensor', - labels: 'oneflow.Tensor', - average) -> 'oneflow.Tensor': - r"""Calculate the average precision for oneflow. - - AP summarizes a precision-recall curve as the weighted mean of maximum - precisions obtained for any r'>r, where r is the recall: - - .. math:: - \text{AP} = \sum_n (R_n - R_{n-1}) P_n - - Note that no approximation is involved since the curve is piecewise - constant. - - Args: - preds (oneflow.Tensor): The model prediction with shape - ``(N, num_classes)``. - labels (oneflow.Tensor): The target of predictions with shape - ``(N, num_classes)``. - - Returns: - oneflow.Tensor: average precision result. 
- """ - # sort examples along classes - sorted_pred_inds = flow.argsort(preds, dim=0, descending=True) - sorted_target = flow.gather(labels, 0, sorted_pred_inds) - - # get indexes when gt_true is positive - pos_inds = sorted_target == 1 - - # Calculate cumulative tp case numbers - tps = flow.cumsum(pos_inds, 0) - total_pos = tps[-1].clone() # the last of tensor may change later - - # Calculate cumulative tp&fp(pred_poss) case numbers - pred_pos_nums = flow.arange(1, len(sorted_target) + 1).to(preds.device) - - tps[flow.logical_not(pos_inds)] = 0 - precision = tps / pred_pos_nums.unsqueeze(-1).float() # divide along rows - ap = flow.sum(precision, 0) / flow.clamp(total_pos, min=1) - - if average == 'macro': - return ap.mean() * 100.0 - else: - return ap * 100 - - -def _average_precision(preds: np.ndarray, labels: np.ndarray, - average) -> np.ndarray: - r"""Calculate the average precision for numpy. - - AP summarizes a precision-recall curve as the weighted mean of maximum - precisions obtained for any r'>r, where r is the recall: - - .. math:: - \text{AP} = \sum_n (R_n - R_{n-1}) P_n - - Note that no approximation is involved since the curve is piecewise - constant. - - Args: - preds (np.ndarray): The model prediction with shape - ``(N, num_classes)``. - labels (np.ndarray): The target of predictions with shape - ``(N, num_classes)``. - - Returns: - np.ndarray: average precision result. - """ - # sort examples along classes - sorted_pred_inds = np.argsort(-preds, axis=0) - sorted_target = np.take_along_axis(labels, sorted_pred_inds, axis=0) - - # get indexes when gt_true is positive - pos_inds = sorted_target == 1 - - # Calculate cumulative tp case numbers - tps = np.cumsum(pos_inds, 0) - total_pos = tps[-1].copy() # the last of tensor may change later - - # Calculate cumulative tp&fp(pred_poss) case numbers - pred_pos_nums = np.arange(1, len(sorted_target) + 1) - - tps[np.logical_not(pos_inds)] = 0 - precision = np.divide( - tps, np.expand_dims(pred_pos_nums, -1), dtype=np.float32) - ap = np.divide( - np.sum(precision, 0), np.clip(total_pos, 1, np.inf), dtype=np.float32) - - if average == 'macro': - return ap.mean() * 100.0 - else: - return ap * 100 - - -class AveragePrecision(MultiLabelMixin, BaseMetric): - """Calculate the average precision with respect of classes. - - Args: - average (str, optional): The average method. It supports two modes: - - - `"macro"`: Calculate metrics for each category, and calculate - the mean value over all categories. - - `None`: Return scores of all categories. - - Defaults to "macro". - - References - ---------- - .. 
[1] `Wikipedia entry for the Average precision - `_ - - Examples: - - >>> from mmeval import AveragePrecision - >>> average_precision = AveragePrecision() - - Use Builtin implementation with label-format labels: - - >>> preds = [[0.9, 0.8, 0.3, 0.2], - [0.1, 0.2, 0.2, 0.1], - [0.7, 0.5, 0.9, 0.3], - [0.8, 0.1, 0.1, 0.2]] - >>> labels = [[0, 1], [1], [2], [0]] - >>> average_precision(preds, labels) - {'mAP': 70.833..} - - Use Builtin implementation with one-hot encoding labels: - - >>> preds = [[0.9, 0.8, 0.3, 0.2], - [0.1, 0.2, 0.2, 0.1], - [0.7, 0.5, 0.9, 0.3], - [0.8, 0.1, 0.1, 0.2]] - >>> labels = [[1, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 0]] - >>> average_precision(preds, labels) - {'mAP': 70.833..} - - Use NumPy implementation with label-format labels: - - >>> import numpy as np - >>> preds = np.array([[0.9, 0.8, 0.3, 0.2], - [0.1, 0.2, 0.2, 0.1], - [0.7, 0.5, 0.9, 0.3], - [0.8, 0.1, 0.1, 0.2]]) - >>> labels = [np.array([0, 1]), np.array([1]), np.array([2]), np.array([0])] # noqa - >>> average_precision(preds, labels) - {'mAP': 70.833..} - - Use PyTorch implementation with one-hot encoding labels:: - - >>> import torch - >>> preds = torch.Tensor([[0.9, 0.8, 0.3, 0.2], - [0.1, 0.2, 0.2, 0.1], - [0.7, 0.5, 0.9, 0.3], - [0.8, 0.1, 0.1, 0.2]]) - >>> labels = torch.Tensor([[1, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 0]]) - >>> average_precision(preds, labels) - {'mAP': 70.833..} - - Computing with `None` average mode: - - >>> preds = np.array([[0.9, 0.8, 0.3, 0.2], - [0.1, 0.2, 0.2, 0.1], - [0.7, 0.5, 0.9, 0.3], - [0.8, 0.1, 0.1, 0.2]]) - >>> labels = [np.array([0, 1]), np.array([1]), np.array([2]), np.array([0])] # noqa - >>> average_precision = AveragePrecision(average=None) - >>> average_precision(preds, labels) - {'AP_classwise': [100.0, 83.33, 100.00, 0.0]} # rounded results - - Accumulate batch: - - >>> for i in range(10): - ... preds = torch.randint(0, 4, size=(100, 10)) - ... labels = torch.randint(0, 4, size=(100, )) - ... average_precision.add(preds, labels) - >>> average_precision.compute() # doctest: +SKIP - """ - - def __init__(self, average: Optional[str] = 'macro', **kwargs) -> None: - super().__init__(**kwargs) - average_options = ['macro', None] - assert average in average_options, 'Invalid `average` argument, ' \ - f'please specify from {average_options}.' - self.average = average - - def add(self, preds: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 - """Add the intermediate results to `self._results`. - - Args: - preds (Sequence): Predictions from the model. It should - be scores of every class (N, C). - labels (Sequence): The ground truth labels. It should be (N, ). - """ - for pred, target in zip(preds, labels): - self._results.append((pred, target)) - - def _format_metric_results(self, ap): - """Format the given metric results into a dictionary. - - Args: - results (list): Results of precision, recall, f1 and support. - - Returns: - dict: The formatted dictionary. 
- """ - result_metrics = dict() - - if self.average is None: - result_metrics['AP_classwise'] = ap[0].tolist() - else: - result_metrics['mAP'] = ap[0].item() - - return result_metrics - - @overload - @dispatch - def _compute_metric(self, preds: Sequence['torch.Tensor'], - labels: Sequence['torch.Tensor']) -> List[List]: - """A PyTorch implementation that computes the metric.""" - - preds = torch.stack(preds) - num_classes = preds.shape[1] - labels = format_data(labels, num_classes, self._label_is_onehot).long() - - assert preds.shape[0] == labels.shape[0], \ - 'Number of samples does not match between preds' \ - f'({preds.shape[0]}) and labels ({labels.shape[0]}).' - - return _average_precision_torch(preds, labels, self.average) - - @overload # type: ignore - @dispatch - def _compute_metric( # type: ignore - self, preds: Sequence['oneflow.Tensor'], - labels: Sequence['oneflow.Tensor']) -> List[List]: - """A OneFlow implementation that computes the metric.""" - - preds = flow.stack(preds) - num_classes = preds.shape[1] - labels = format_data(labels, num_classes, self._label_is_onehot).long() - - assert preds.shape[0] == labels.shape[0], \ - 'Number of samples does not match between preds' \ - f'({preds.shape[0]}) and labels ({labels.shape[0]}).' - - return _average_precision_oneflow(preds, labels, self.average) - - @overload - @dispatch - def _compute_metric( - self, preds: Sequence[Union[int, Sequence[Union[int, float]]]], - labels: Sequence[Union[int, Sequence[int]]]) -> List[List]: - """A Builtin implementation that computes the metric.""" - - return self._compute_metric([np.array(pred) for pred in preds], - [np.array(target) for target in labels]) - - @dispatch - def _compute_metric( - self, preds: Sequence[Union[np.ndarray, np.number]], - labels: Sequence[Union[np.ndarray, np.number]]) -> List[List]: - """A NumPy implementation that computes the metric.""" - - preds = np.stack(preds) - num_classes = preds.shape[1] - labels = format_data(labels, num_classes, - self._label_is_onehot).astype(np.int64) - - assert preds.shape[0] == labels.shape[0], \ - 'Number of samples does not match between preds' \ - f'({preds.shape[0]}) and labels ({labels.shape[0]}).' - - return _average_precision(preds, labels, self.average) - - def compute_metric( - self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, - ONEFLOW_IMPL_HINTS, BUILTIN_IMPL_HINTS]] - ) -> Dict[str, float]: - """Compute the metric. - - Currently, there are 3 implementations of this method: NumPy and - PyTorch and OneFlow. Which implementation to use is determined by the - type of the calling parameters. e.g. `numpy.ndarray` or - `torch.Tensor`, `oneflow.Tensor`. - - This method would be invoked in `BaseMetric.compute` after distributed - synchronization. - Args: - results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, - ONEFLOW_IMPL_HINTS]]): A list of tuples that consisting the - prediction and label. This list has already been synced across - all ranks. - - Returns: - Dict[str, float]: The computed metric. - """ - preds = [res[0] for res in results] - labels = [res[1] for res in results] - assert self._pred_is_onehot is False, '`self._pred_is_onehot` should' \ - f'be `False` for {self.__class__.__name__}, because scores are' \ - 'necessary for compute the metric.' 
- metric_results = self._compute_metric(preds, labels) - return self._format_metric_results(metric_results) diff --git a/mmeval/metrics/single_label.py b/mmeval/metrics/precision_recall_f1score.py similarity index 53% rename from mmeval/metrics/single_label.py rename to mmeval/metrics/precision_recall_f1score.py index 176c5145..83e9776f 100644 --- a/mmeval/metrics/single_label.py +++ b/mmeval/metrics/precision_recall_f1score.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. - import numpy as np +import warnings from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union, overload) from mmeval.core.base_metric import BaseMetric from mmeval.core.dispatcher import dispatch +from mmeval.metrics.utils import MultiLabelMixin, format_data from mmeval.utils import try_import if TYPE_CHECKING: @@ -20,13 +21,69 @@ flow = try_import('oneflow') of_F = try_import('oneflow.nn.functional') -NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], np.number] +NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], Union[np.ndarray, + np.number]] TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] ONEFLOW_IMPL_HINTS = Tuple['oneflow.Tensor', 'oneflow.Tensor'] BUILTIN_IMPL_HINTS = Tuple[Union[int, Sequence[Union[int, float]]], Union[int, Sequence[int]]] +class PrecsionRecallF1score: + """Wrapper to get different task of PrecsionRecallF1score calculation, by + setting the ``task`` argument to either ``'singlelabel'`` or + ``multilabel``. + + See the documentation of :mod:`SingleLabelPrecsionRecallF1score` and + :mod:`MultiLabelPrecsionRecallF1score` for the detailed usages and + examples. + + Examples: + >>> import torch + >>> preds = torch.tensor([2, 0, 1, 1]) + >>> labels = torch.tensor([2, 1, 2, 0]) + >>> metric = PrecsionRecallF1score(num_classes=3) + >>> metric(preds, labels) + {'precision': 33.3333, 'recall': 16.6667, 'f1-score': 22.2222} + >>> metric = PrecsionRecallF1score( + task="multilabel", average='micro', num_classes=3) + >>> metric(preds, labels) + {'precision_micro': 25.0, 'recall_micro': 25.0, 'f1-score_micro': 25.0} + """ + + def __new__(cls, + task: str = 'singlelabel', + num_classes: Optional[int] = None, + thrs: Union[float, Sequence[Optional[float]], None] = None, + topk: Optional[int] = None, + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + **kwargs): + + if task == 'singlelabel': + return SingleLabelPrecsionRecallF1score( + num_classes=num_classes, + thrs=thrs, + items=items, + average=average, + **kwargs) + if task == 'multilabel': + assert isinstance(thrs, float) or thrs is None, \ + "task `'multilabel'` only supports single threshold or None." + assert isinstance(num_classes, int), \ + '`num_classes` is necessary for multi-label metrics.' + return MultiLabelPrecsionRecallF1score( + num_classes=num_classes, + thr=thrs, + topk=topk, + items=items, + average=average, + **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'singlelabel'` or " + f"`'multilabel'` but got {task}") + + def _precision_recall_f1_support(pred_positive: Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor'], @@ -112,7 +169,7 @@ def _precision_recall_f1_support(pred_positive: Union[np.ndarray, return precision, recall, f1_score, support -class SingleLabelMetric(BaseMetric): +class SingleLabelPrecsionRecallF1score(BaseMetric): """A collection of metrics for single-label multi-class classification task based on confusion matrix. 
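[reviewer note] The `PrecsionRecallF1score` wrapper above only selects between the two concrete classes; both ultimately report per-class precision, recall and F1 reduced by the chosen `average` mode. As a sanity check of the numbers quoted in the wrapper docstring, here is a minimal NumPy sketch of the standard per-class definitions under macro averaging (an illustration only, not a copy of `_precision_recall_f1_support`):

import numpy as np

preds = np.array([2, 0, 1, 1])     # predicted class indices from the docstring example
labels = np.array([2, 1, 2, 0])
num_classes = 3

precision, recall, f1 = [], [], []
for c in range(num_classes):
    tp = np.sum((preds == c) & (labels == c))
    fp = np.sum((preds == c) & (labels != c))
    fn = np.sum((preds != c) & (labels == c))
    p = tp / max(tp + fp, 1)       # guard against classes never predicted
    r = tp / max(tp + fn, 1)       # guard against classes absent from the labels
    precision.append(p)
    recall.append(r)
    f1.append(2 * p * r / max(p + r, 1e-12))

# macro average, scaled to percentages as the metric classes report them
print(round(np.mean(precision) * 100, 4),  # 33.3333
      round(np.mean(recall) * 100, 4),     # 16.6667
      round(np.mean(f1) * 100, 4))         # 22.2222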
@@ -153,8 +210,8 @@ class SingleLabelMetric(BaseMetric): Examples: - >>> from mmeval import SingleLabelMetric - >>> single_lable_metic = SingleLabelMetric(num_classes=4) + >>> from mmeval import SingleLabelPrecsionRecallF1score + >>> single_lable_metic = SingleLabelPrecsionRecallF1score(num_classes=4) Use NumPy implementation: @@ -180,7 +237,7 @@ class SingleLabelMetric(BaseMetric): [0.3, 0.4, 0.2, 0.1], [0.0, 0.0, 0.1, 0.9]]) >>> labels = np.asarray([0, 1, 2, 3]) - >>> single_lable_metic = SingleLabelMetric(average='micro') + >>> single_lable_metic = SingleLabelPrecsionRecallF1score(average='micro') >>> single_lable_metic(preds, labels) {'precision_micro': 50.0, 'recall_micro': 50.0, 'f1-score_micro': 50.0} # noqa @@ -194,10 +251,10 @@ class SingleLabelMetric(BaseMetric): """ def __init__(self, - thrs: Union[float, Sequence[Optional[float]], None] = 0., + num_classes: Optional[int] = None, + thrs: Union[float, Sequence[Optional[float]], None] = None, items: Sequence[str] = ('precision', 'recall', 'f1-score'), average: Optional[str] = 'macro', - num_classes: Optional[int] = None, **kwargs) -> None: super().__init__(**kwargs) @@ -271,11 +328,14 @@ def pack_results(precision, recall, f1_score, support): for k, v in metrics.items(): if self.average is None: - result_metrics[k + '_classwise'] = v.tolist() + _result = v.tolist() + result_metrics[k + '_classwise'] = [ + round(_r, 4) for _r in _result + ] elif self.average == 'micro': - result_metrics[k + f'_{self.average}'] = v.item() + result_metrics[k + f'_{self.average}'] = round(v.item(), 4) else: - result_metrics[k] = v.item() + result_metrics[k] = round(v.item(), 4) return result_metrics @@ -470,3 +530,345 @@ def compute_metric( labels = [res[1] for res in results] metric_results = self._compute_metric(predictions, labels) return self._format_metric_results(metric_results) + + +class MultiLabelPrecsionRecallF1score(MultiLabelMixin, BaseMetric): + """A collection of metrics for multi-label multi-class classification task + based on confusion matrix. + + It includes precision, recall, f1-score and support. + + Args: + num_classes (int): Number of classes. Needed for different inputs + as extra check. + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + items (Sequence[str]): The detailed metric items to evaluate. Here is + the available options: + + - `"precision"`: The ratio tp / (tp + fp) where tp is the + number of true positives and fp the number of false + positives. + - `"recall"`: The ratio tp / (tp + fn) where tp is the number + of true positives and fn the number of false negatives. + - `"f1-score"`: The f1-score is the harmonic mean of the + precision and recall. + - `"support"`: The total number of positive of each category + in the target. + + Defaults to ('precision', 'recall', 'f1-score'). + average (str | None): The average method. It supports three average + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Calculate metrics globally by counting the total + true positives, false negatives and false positives. + - `None`: Return scores of all categories. + + Defaults to "macro". + + .. note:: + MultiLabelPrecsionRecallF1score supports different kinds of inputs. Such as: + 1. Each sample has scores for every classes. (Only for predictions) + 2. 
Each sample has one-hot indices for every classes. + 3. Each sample has label-format indices. + + Examples: + + >>> from mmeval import MultiLabelPrecsionRecallF1score + >>> multi_lable_metic = MultiLabelPrecsionRecallF1score(num_classes=4) + + Use Builtin implementation with raw indices: + + >>> preds = [[0], [1], [2], [0, 3]] + >>> labels = [[0], [1, 2], [2], [3]] + >>> multi_lable_metic(preds, labels) + {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} + + Use Builtin implementation with one-hot indices: + + >>> preds = [[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [1, 0, 0, 1]] + >>> labels = [[1, 0, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]] + >>> multi_lable_metic(preds, labels) + {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} + + Use Builtin implementation with scores: + + >>> preds = [[0.9, 0.1, 0.2, 0.3], + [0.1, 0.8, 0.1, 0.1], + [0.4, 0.3, 0.7, 0.1], + [0.8, 0.1, 0.1, 0.9]] + >>> labels = [[1, 0, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]] + >>> multi_lable_metic(preds, labels) + {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} + + Use NumPy implementation with raw indices: + + >>> import numpy as np + >>> preds = [np.array([0]), np.array([1, 2]), np.array([2]), np.array([3])] # noqa + >>> labels = [np.array([0]), np.array([1]), np.array([2]), np.array([0, 3])] # noqa + >>> multi_lable_metic(preds, labels) + {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} + + Use PyTorch implementation: + + >>> import torch + >>> preds = [torch.tensor([0]), torch.tensor([1, 2]), torch.tensor([2]), torch.tensor([3])] # noqa + >>> labels = [torch.tensor([0]), torch.tensor([1]), torch.tensor([2]), torch.tensor([0, 3])] # noqa + >>> multi_lable_metic(preds, labels) + {'precision': 87.5, 'recall': 87.5, 'f1-score': 83.33} + + Computing with `micro` average mode with `topk=2`: + + >>> preds = np.array([ + [0.7, 0.1, 0.1, 0.1], + [0.1, 0.3, 0.4, 0.2], + [0.3, 0.4, 0.2, 0.1], + [0.0, 0.0, 0.1, 0.9]]) + >>> labels = np.array([0, 1, 2, 3]) + >>> multi_lable_metic = MultiLabelPrecsionRecallF1score(4, average='micro', topk=2) + >>> multi_lable_metic(preds, labels) + {'precision_top2_micro': 37.5, 'recall_top2_micro': 75.0, 'f1-score_top2_micro': 50.0} # noqa + + Accumulate batch: + + >>> for i in range(10): + ... labels = torch.randint(0, 4, size=(100, )) + ... predicts = torch.randint(0, 4, size=(100, )) + ... multi_lable_metic.add(predicts, labels) + >>> multi_lable_metic.compute() # doctest: +SKIP + """ + + def __init__(self, + num_classes: int, + thr: Optional[float] = None, + topk: Optional[int] = None, + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + **kwargs) -> None: + super().__init__(**kwargs) + + if thr is None and topk is None: + thr = 0.5 + warnings.warn('Neither thr nor k is given, set thr as 0.5 by ' + 'default.') + elif thr is not None and topk is not None: + warnings.warn('Both thr and topk are given, ' + 'use threshold in favor of top-k.') + + self.thr = thr + self.topk = topk + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by' \ + ' `MultiLabelPrecsionRecallF1score`,' \ + ' please specify from "precision", "recall", "f1-score" and ' \ + '"support".' + self.items = tuple(items) + + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' 
+ self.average = average + self.num_classes = num_classes + + def add(self, predictions: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 + """Add the intermediate results to `self._results`. + + Args: + predictions (Sequence): Predictions from the model. It can be + labels (N, ), or scores of every class (N, C). + labels (Sequence): The ground truth labels. It should be (N, ). + """ + for pred, label in zip(predictions, labels): + self._results.append((pred, label)) + + def _format_metric_results(self, results: List) -> Dict: + """Format the given metric results into a dictionary. + + Args: + results (list): Results of precision, recall, f1 and support. + + Returns: + dict: The formatted dictionary. + """ + metrics = {} + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + if self.thr: + suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}' + for k, v in pack_results(*results).items(): + metrics[k + suffix] = v + else: + for k, v in pack_results(*results).items(): + metrics[k + f'_top{self.topk}'] = v + + result_metrics = dict() + for k, v in metrics.items(): + + if self.average is None: + _result = v.tolist() + result_metrics[k + '_classwise'] = [ + round(_r, 4) for _r in _result + ] + elif self.average == 'micro': + result_metrics[k + f'_{self.average}'] = round(v.item(), 4) + else: + result_metrics[k] = round(v.item(), 4) + + return result_metrics + + @overload + @dispatch + def _compute_metric(self, predictions: Sequence['torch.Tensor'], + labels: Sequence['torch.Tensor']) -> List: + """A PyTorch implementation that computes the metric.""" + + preds = format_data(predictions, self.num_classes, + self._pred_is_onehot) + labels = format_data(labels, self.num_classes, + self._label_is_onehot).long() + + # cannot be raised in current implementation because + # `and` method will guarantee the equal length. + # However length check should remain somewhere. + assert preds.shape[0] == labels.shape[0], \ + 'Number of samples does not match between preds' \ + f'({preds.shape[0]}) and labels ({labels.shape[0]}).' + + if self.thr is not None: + # a label is predicted positive if larger than self. + # work for index as well + pos_inds = (preds >= self.thr).long() + else: + # top-k labels will be predicted positive for any example + _, topk_indices = preds.topk(self.topk) + pos_inds = torch.zeros_like(preds).scatter_(1, topk_indices, 1) + pos_inds = pos_inds.long() + + return _precision_recall_f1_support( # type: ignore + pos_inds, labels, self.average) + + @overload # type: ignore + @dispatch + def _compute_metric( # type: ignore + self, predictions: Sequence['oneflow.Tensor'], + labels: Sequence['oneflow.Tensor']) -> List: + """A OneFlow implementation that computes the metric.""" + + preds = format_data(predictions, self.num_classes, + self._pred_is_onehot) + labels = format_data(labels, self.num_classes, + self._label_is_onehot).long() + + # cannot be raised in current implementation because + # `and` method will guarantee the equal length. + # However length check should remain somewhere. 
+ assert preds.shape[0] == labels.shape[0], \ + 'Number of samples does not match between preds' \ + f'({preds.shape[0]}) and labels ({labels.shape[0]}).' + + if self.thr is not None: + # a label is predicted positive if larger than self. + # work for index as well + pos_inds = (preds >= self.thr).long() + else: + # top-k labels will be predicted positive for any example + _, topk_indices = preds.topk(self.topk, dim=-1) + pos_inds = flow.zeros_like(preds).scatter_(1, topk_indices, 1) + pos_inds = pos_inds.long() + + return _precision_recall_f1_support( # type: ignore + pos_inds, labels, self.average) + + @overload + @dispatch + def _compute_metric(self, preds: Sequence[Union[int, + Sequence[Union[int, + float]]]], + labels: Sequence[Union[int, Sequence[int]]]) -> List: + """A Builtin implementation that computes the metric.""" + + return self._compute_metric([np.array(pred) for pred in preds], + [np.array(target) for target in labels]) + + @dispatch + def _compute_metric( + self, preds: Sequence[Union[np.ndarray, np.number]], + labels: Sequence[Union[np.ndarray, np.number]]) -> List: + """A NumPy implementation that computes the metric.""" + + preds = format_data(preds, self.num_classes, self._pred_is_onehot) + labels = format_data(labels, self.num_classes, + self._label_is_onehot).astype(np.int64) + + # cannot be raised in current implementation because + # `and` method will guarantee the equal length. + # However length check should remain somewhere. + assert preds.shape[0] == labels.shape[0], \ + 'Number of samples does not match between preds' \ + f'({preds.shape[0]}) and labels ({labels.shape[0]}).' + + if self.thr is not None: + # a label is predicted positive if larger than self. + # work for index as well + pos_inds = (preds >= self.thr).astype(np.int64) + else: + # top-k labels will be predicted positive for any example + topk_indices = np.argpartition( + preds, -self.topk, axis=-1)[:, -self.topk:] # type: ignore + pos_inds = np.zeros(preds.shape, dtype=np.int64) + np.put_along_axis(pos_inds, topk_indices, 1, axis=1) + + return _precision_recall_f1_support( # type: ignore + pos_inds, labels, self.average) + + def compute_metric( + self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, + ONEFLOW_IMPL_HINTS, BUILTIN_IMPL_HINTS]] + ) -> Dict[str, float]: + """Compute the metric. + + Currently, there are 3 implementations of this method: NumPy and + PyTorch and OneFlow. Which implementation to use is determined by the + type of the calling parameters. e.g. `numpy.ndarray` or `torch.Tensor` + or `oneflow.Tensor`. + This method would be invoked in `BaseMetric.compute` after distributed + synchronization. + + Args: + results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, + ONEFLOW_IMPL_HINTS]]): A listof tuples that consisting the + prediction and label. This list has already been synced across all + ranks. + Returns: + Dict[str, float]: The computed metric. 
+ """ + preds = [res[0] for res in results] + labels = [res[1] for res in results] + metric_results = self._compute_metric(preds, labels) + return self._format_metric_results(metric_results) diff --git a/mmeval/metrics/utils/__init__.py b/mmeval/metrics/utils/__init__.py index a85e1b1f..53380b82 100644 --- a/mmeval/metrics/utils/__init__.py +++ b/mmeval/metrics/utils/__init__.py @@ -5,6 +5,7 @@ from .grammar import get_n_gram, get_tokenizer, infer_language from .image_transforms import reorder_and_crop from .keypoint import calc_distances, distance_acc +from .multi_label import MultiLabelMixin, format_data from .polygon import (poly2shapely, poly_intersection, poly_iou, poly_make_valid, poly_union, polys2shapely) @@ -13,5 +14,6 @@ 'poly_make_valid', 'poly_iou', 'calc_distances', 'distance_acc', 'calculate_overlaps', 'calculate_bboxes_area', 'reorder_and_crop', 'calculate_bboxes_area_rotated', 'calculate_overlaps_rotated', - 'get_n_gram', 'get_tokenizer', 'infer_language' + 'get_n_gram', 'get_tokenizer', 'infer_language', 'MultiLabelMixin', + 'format_data' ] diff --git a/mmeval/metrics/utils/multi_label.py b/mmeval/metrics/utils/multi_label.py new file mode 100644 index 00000000..73b30eb4 --- /dev/null +++ b/mmeval/metrics/utils/multi_label.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import warnings +from typing import TYPE_CHECKING, Sequence, Tuple, Union + +from mmeval.utils import try_import + +if TYPE_CHECKING: + import oneflow + import oneflow as flow + import torch +else: + torch = try_import('torch') + flow = try_import('oneflow') + +NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], Union[np.ndarray, + np.number]] +TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] +ONEFLOW_IMPL_HINTS = Tuple['oneflow.Tensor', 'oneflow.Tensor'] +BUILTIN_IMPL_HINTS = Tuple[Union[int, Sequence[Union[int, float]]], + Union[int, Sequence[int]]] + + +def label_to_onehot( + label: Union[np.ndarray, 'torch.Tensor', + 'oneflow.Tensor'], num_classes: int +) -> Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']: + """Convert the label-format input to one-hot encodings. + + Args: + label (torch.Tensor or oneflow.Tensor or np.ndarray): + The label-format input. The format of item must be label-format. + num_classes (int): The number of classes. + + Return: + torch.Tensor or oneflow.Tensor or np.ndarray: + The converted one-hot encodings. + """ + if torch and isinstance(label, torch.Tensor): + label = label.long() + onehot = label.new_zeros((num_classes, )) + elif flow and isinstance(label, flow.Tensor): + label = label.long() + onehot = label.new_zeros((num_classes, )) + else: + label = label.astype(np.int64) + onehot = np.zeros((num_classes, ), dtype=np.int64) + assert label.max().item() < num_classes, \ + 'Max index is out of `num_classes` {num_classes}' + assert label.min().item() >= 0 + onehot[label] = 1 + return onehot + + +def format_data( + data: Union[Sequence[Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']], + np.ndarray, 'torch.Tensor', 'oneflow.Tensor'], + num_classes: int, + is_onehot: bool = False +) -> Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']: + """Format data from different inputs such as prediction scores, label- + format data and one-hot encodings into the same output shape of `(N, + num_classes)`. + + Args: + data (Union[Sequence[np.ndarray, 'torch.Tensor', 'oneflow.Tensor'], + np.ndarray, 'torch.Tensor', 'oneflow.Tensor']): + The input data of prediction or labels. + num_classes (int): The number of classes. 
+ is_onehot (bool): Whether the data is one-hot encodings. + + Return: + torch.Tensor or oneflow.Tensor or np.ndarray: + One-hot encodings or predict scores. + """ + if torch and isinstance(data[0], torch.Tensor): + stack_func = torch.stack + elif flow and isinstance(data[0], flow.Tensor): + stack_func = flow.stack + elif isinstance(data[0], (np.ndarray, np.number)): + stack_func = np.stack + else: + raise NotImplementedError(f'Data type of {type(data[0])}' + 'is not supported.') + + try: + # try stack scores or one-hot indices directly + formated_data = stack_func(data) + # all assertions below is to find labels that are + # raw indices which should be caught in exception + # to convert to one-hot indices. + # + # 1. all raw indices has only 1 dims + assert formated_data.ndim == 2 + # 2. all raw indices has the same dims + assert formated_data.shape[1] == num_classes + # 3. all raw indices has the same dims as num_classes + # then max indices should greater than 1 for num_classes > 2 + assert formated_data.max() <= 1 + # 4. corner case, num_classes=2, then one-hot indices + # and raw indices are undistinguishable, for instance: + # [[0, 1], [0, 1]] can be one-hot indices of 2 positives + # or raw indices of 4 positives. + # Extra induction is needed. + if num_classes == 2: + warnings.warn('Ambiguous data detected, reckoned as scores' + ' or label-format data as defaults. Please set ' + 'parms related to `is_onehot` if use one-hot ' + 'encoding data to compute metrics.') + assert is_onehot + # Error corresponds to np, torch, oneflow, stack_func respectively + except (ValueError, RuntimeError, AssertionError): + # convert label-format inputs to one-hot encodings + formated_data = stack_func( + [label_to_onehot(sample, num_classes) for sample in data]) + return formated_data + + +class MultiLabelMixin: + """A Mixin for Multilabel Metrics to clarify whether the input is one-hot + encodings or label-format inputs for corner case with minimal user + awareness.""" + + def __init__(self, *args, **kwargs) -> None: + # pass arguments for multiple inheritances + super().__init__(*args, **kwargs) # type: ignore + self._pred_is_onehot = False + self._label_is_onehot = False + + @property + def pred_is_onehot(self) -> bool: + """Whether prediction is one-hot encodings. + + Only works for corner case when num_classes=2 to distinguish one-hot + encodings or label-format. + """ + return self._pred_is_onehot + + @pred_is_onehot.setter + def pred_is_onehot(self, is_onehot: bool): + """Set a flag of whether prediction is one-hot encodings. + + Only works for corner case when num_classes=2 to distinguish one-hot + encodings or label-format. + """ + self._pred_is_onehot = is_onehot + + @property + def label_is_onehot(self) -> bool: + """Whether label is one-hot encodings. + + Only works for corner case when num_classes=2 to distinguish one-hot + encodings or label-format. + """ + return self._label_is_onehot + + @label_is_onehot.setter + def label_is_onehot(self, is_onehot: bool): + """Set a flag of whether label is one-hot encodings. + + Only works for corner case when num_classes=2 to distinguish one-hot + encodings or label-format. 
+ """ + self._label_is_onehot = is_onehot diff --git a/tests/test_metrics/test_multi_label.py b/tests/test_metrics/test_multi_label_precision_recall_f1score.py similarity index 91% rename from tests/test_metrics/test_multi_label.py rename to tests/test_metrics/test_multi_label_precision_recall_f1score.py index 85835e1d..8bdd31b6 100644 --- a/tests/test_metrics/test_multi_label.py +++ b/tests/test_metrics/test_multi_label_precision_recall_f1score.py @@ -7,7 +7,7 @@ from distutils.version import LooseVersion from mmeval.core.base_metric import BaseMetric -from mmeval.metrics import MultiLabelMetric +from mmeval.metrics import MultiLabelPrecsionRecallF1score from mmeval.utils import try_import torch = try_import('torch') @@ -17,10 +17,10 @@ def test_metric_init_assertion(): with pytest.raises(AssertionError, match='Invalid `average` argument'): - MultiLabelMetric(3, average='mean') + MultiLabelPrecsionRecallF1score(3, average='mean') with pytest.raises(AssertionError, match='The metric map is not supported'): - MultiLabelMetric(3, items=('map',)) + MultiLabelPrecsionRecallF1score(3, items=('map',)) @pytest.mark.parametrize( @@ -40,7 +40,8 @@ def test_metric_init_assertion(): ) def test_metric_inputs(metric_kwargs): # test predictions with labels - multi_label_metric = MultiLabelMetric(num_classes=3, **metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score( + num_classes=3, **metric_kwargs) assert isinstance(multi_label_metric, BaseMetric) results = multi_label_metric( np.asarray([[0.1, 0.9, 0.8], [0.5, 0.5, 0.8]]), np.asarray([0, 1])) @@ -58,7 +59,7 @@ def test_metric_inputs(metric_kwargs): ]) def test_metric_interface_builtin(metric_kwargs, preds, labels): """Test builtin inputs.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -77,7 +78,7 @@ def test_metric_interface_builtin(metric_kwargs, preds, labels): ]) def test_metric_interface_topk(metric_kwargs, preds, labels): """Test scores inputs with topk.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -97,7 +98,7 @@ def test_metric_interface_topk(metric_kwargs, preds, labels): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_interface_torch_topk(metric_kwargs, preds, labels): """Test scores inputs with topk in torch.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -127,7 +128,7 @@ def test_metric_interface_oneflow_topk(metric_kwargs, preds, labels): labels = flow.tensor(labels) else: labels = list(flow.tensor(label) for label in labels) - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -150,7 +151,7 @@ def test_metric_interface_oneflow_topk(metric_kwargs, preds, labels): ]) def test_metric_interface(metric_kwargs, preds, labels): """Test all kinds of inputs.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert 
isinstance(results, dict) @@ -174,7 +175,7 @@ def test_metric_interface(metric_kwargs, preds, labels): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_interface_torch(metric_kwargs, preds, labels): """Test all kinds of inputs in torch.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -208,7 +209,7 @@ def test_metric_interface_oneflow(metric_kwargs, preds, labels): labels = flow.tensor(labels) else: labels = list(flow.tensor(label) for label in labels) - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -243,14 +244,14 @@ def test_metric_interface_oneflow(metric_kwargs, preds, labels): ) def test_metric_accurate(metric_kwargs, preds, labels, results): """Test accurate.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) assert multi_label_metric( np.asarray(preds), np.asarray(labels)) == results def test_metric_accurate_is_onehot(): """Test ambiguous cases when num_classes=2.""" - multi_label_metric = MultiLabelMetric(num_classes=2, items=('precision', 'recall')) # noqa + multi_label_metric = MultiLabelPrecsionRecallF1score(num_classes=2, items=('precision', 'recall')) # noqa assert multi_label_metric.pred_is_onehot is False assert multi_label_metric.label_is_onehot is False assert multi_label_metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]) == {'precision': 100.0, 'recall': 100.0} # noqa @@ -272,7 +273,7 @@ def test_metric_accurate_is_onehot(): ) def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and PyTorch implementation.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) preds = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) @@ -304,7 +305,7 @@ def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): ) def test_metamorphic_numpy_oneflow(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and OneFlow implementation.""" - multi_label_metric = MultiLabelMetric(**metric_kwargs) + multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) preds = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) diff --git a/tests/test_metrics/test_precision_recall_f1score.py b/tests/test_metrics/test_precision_recall_f1score.py new file mode 100644 index 00000000..00956a03 --- /dev/null +++ b/tests/test_metrics/test_precision_recall_f1score.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +# yapf: disable + +import numpy as np +import pytest + +from mmeval.core.base_metric import BaseMetric +from mmeval.metrics import PrecsionRecallF1score + + +def test_metric_init_assertion(): + with pytest.raises( + AssertionError, match='`num_classes` is necessary'): + PrecsionRecallF1score(task='multilabel', num_classes=None) + with pytest.raises( + AssertionError, match="task `'multilabel'` only supports"): + PrecsionRecallF1score(task='multilabel', thrs=(0., 0.1, 0.4)) + with pytest.raises( + ValueError, match='Expected argument `task` to either be'): + PrecsionRecallF1score(task='threeclasses') + + +def test_metric_interface(): + preds = np.array([2, 0, 1, 1]) + labels = np.array([2, 1, 2, 0]) + + # test predictions with labels + metric = PrecsionRecallF1score(task='singlelabel', num_classes=3) + assert isinstance(metric, BaseMetric) + results = metric(preds, labels) + assert isinstance(results, dict) + + # test predictions with pred_scores + metric = PrecsionRecallF1score(task='multilabel', num_classes=3) + assert isinstance(metric, BaseMetric) + results = metric(preds, labels) + assert isinstance(results, dict) diff --git a/tests/test_metrics/test_single_label_metric.py b/tests/test_metrics/test_single_label_precision_recall_f1score.py similarity index 85% rename from tests/test_metrics/test_single_label_metric.py rename to tests/test_metrics/test_single_label_precision_recall_f1score.py index d52985b4..f8e70506 100644 --- a/tests/test_metrics/test_single_label_metric.py +++ b/tests/test_metrics/test_single_label_precision_recall_f1score.py @@ -7,7 +7,7 @@ from distutils.version import LooseVersion from mmeval.core.base_metric import BaseMetric -from mmeval.metrics import SingleLabelMetric +from mmeval.metrics import SingleLabelPrecsionRecallF1score from mmeval.utils import try_import torch = try_import('torch') @@ -17,20 +17,20 @@ def test_metric_init_assertion(): with pytest.raises(AssertionError, match='Invalid `average` argument'): - SingleLabelMetric(average='mean') + SingleLabelPrecsionRecallF1score(average='mean') with pytest.raises(AssertionError, match='The metric map is not supported'): - SingleLabelMetric(items=('map',)) + SingleLabelPrecsionRecallF1score(items=('map',)) def test_metric_assertion(): - single_label_metric = SingleLabelMetric() + single_label_metric = SingleLabelPrecsionRecallF1score() with pytest.raises(AssertionError, match='Please specify `num_classes`'): single_label_metric( np.asarray([1, 2, 3]), np.asarray([3, 2, 1])) - single_label_metric = SingleLabelMetric(num_classes=2) + single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=2) with pytest.raises(AssertionError, match='Number of classes does not match'): single_label_metric( @@ -39,12 +39,12 @@ def test_metric_assertion(): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_torch_assertion(): - single_label_metric = SingleLabelMetric() + single_label_metric = SingleLabelPrecsionRecallF1score() with pytest.raises(AssertionError, match='Please specify `num_classes`'): single_label_metric( torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1])) - single_label_metric = SingleLabelMetric(num_classes=2) + single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=2) with pytest.raises(AssertionError, match='Number of classes does not match'): single_label_metric( @@ -55,12 +55,12 @@ def test_metric_torch_assertion(): LooseVersion(flow.__version__) < '0.8.1', reason='OneFlow >= 0.8.1 is required!') def test_metric_oneflow_assertion(): - 
single_label_metric = SingleLabelMetric() + single_label_metric = SingleLabelPrecsionRecallF1score() with pytest.raises(AssertionError, match='Please specify `num_classes`'): single_label_metric( flow.Tensor([1, 2, 3]), flow.Tensor([3, 2, 1])) - single_label_metric = SingleLabelMetric(num_classes=2) + single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=2) with pytest.raises(AssertionError, match='Number of classes does not match'): single_label_metric( @@ -82,14 +82,15 @@ def test_metric_oneflow_assertion(): ) def test_metric_interface(metric_kwargs): # test predictions with labels - single_label_metric = SingleLabelMetric(**metric_kwargs) + single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) assert isinstance(single_label_metric, BaseMetric) assert isinstance(single_label_metric.thrs, tuple) results = single_label_metric( np.asarray([[0.1, 0.9], [0.5, 0.5]]), np.asarray([0, 1])) # test predictions with pred_scores - single_label_metric = SingleLabelMetric(**metric_kwargs, num_classes=4) + single_label_metric = SingleLabelPrecsionRecallF1score( + **metric_kwargs, num_classes=4) assert isinstance(single_label_metric, BaseMetric) assert isinstance(single_label_metric.thrs, tuple) results = single_label_metric( @@ -100,13 +101,13 @@ def test_metric_interface(metric_kwargs): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_input_torch(): # test predictions with labels - single_label_metric = SingleLabelMetric() + single_label_metric = SingleLabelPrecsionRecallF1score() results = single_label_metric( torch.Tensor([[0.1, 0.9], [0.5, 0.5]]), torch.Tensor([0, 1])) assert isinstance(results, dict) # test predictions with pred_scores - single_label_metric = SingleLabelMetric(num_classes=4) + single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=4) results = single_label_metric( torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1])) assert isinstance(results, dict) @@ -117,13 +118,13 @@ def test_metric_input_torch(): reason='OneFlow >= 0.8.1 is required!') def test_metric_input_oneflow(): # test predictions with labels - single_label_metric = SingleLabelMetric() + single_label_metric = SingleLabelPrecsionRecallF1score() results = single_label_metric( flow.Tensor([[0.1, 0.9], [0.5, 0.5]]), flow.Tensor([0, 1])) assert isinstance(results, dict) # test predictions with pred_scores - single_label_metric = SingleLabelMetric(num_classes=4) + single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=4) results = single_label_metric( flow.Tensor([1, 2, 3]), flow.Tensor([3, 2, 1])) assert isinstance(results, dict) @@ -132,13 +133,13 @@ def test_metric_input_oneflow(): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_input_builtin(): # test predictions with labels - single_label_metric = SingleLabelMetric() + single_label_metric = SingleLabelPrecsionRecallF1score() results = single_label_metric( [[0.1, 0.9], [0.5, 0.5]], [0, 1]) assert isinstance(results, dict) # test predictions with pred_scores - single_label_metric = SingleLabelMetric(num_classes=4) + single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=4) results = single_label_metric( [1, 2, 3], [3, 2, 1]) assert isinstance(results, dict) @@ -173,7 +174,7 @@ def test_metric_input_builtin(): ] ) def test_metric_accurate(metric_kwargs, predictions, labels, results): - single_label_metric = SingleLabelMetric(**metric_kwargs) + single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) assert 
single_label_metric( np.asarray(predictions), np.asarray(labels)) == results @@ -190,7 +191,7 @@ def test_metric_accurate(metric_kwargs, predictions, labels, results): ) def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and PyTorch implementation.""" - single_label_metric = SingleLabelMetric(**metric_kwargs) + single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) predictions = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) @@ -223,7 +224,7 @@ def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): ) def test_metamorphic_numpy_oneflow(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and OneFlow implementation.""" - single_label_metric = SingleLabelMetric(**metric_kwargs) + single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) predictions = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) From 2a6efd14795a91dc3c4bd25c6a1e10cd285d4779 Mon Sep 17 00:00:00 2001 From: huyingfan Date: Wed, 15 Feb 2023 14:15:39 +0800 Subject: [PATCH 2/7] fix according to comments --- mmeval/metrics/__init__.py | 10 +++--- mmeval/metrics/average_precision.py | 6 ++-- mmeval/metrics/precision_recall_f1score.py | 36 +++++++++---------- mmeval/metrics/utils/multi_label.py | 5 +++ ...st_multi_label_precision_recall_f1score.py | 30 ++++++++-------- .../test_precision_recall_f1score.py | 14 ++++---- 6 files changed, 55 insertions(+), 46 deletions(-) diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 533a54a1..bccc09bd 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -24,9 +24,9 @@ from .oid_map import OIDMeanAP from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy from .perplexity import Perplexity -from .precision_recall_f1score import (MultiLabelPrecsionRecallF1score, - PrecsionRecallF1score, - SingleLabelPrecsionRecallF1score) +from .precision_recall_f1score import (MultiLabelPrecisionRecallF1score, + PrecisionRecallF1score, + SingleLabelPrecisionRecallF1score) from .proposal_recall import ProposalRecall from .psnr import PeakSignalNoiseRatio from .rouge import ROUGE @@ -45,8 +45,8 @@ 'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError', 'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError', 'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator', - 'WordAccuracy', 'PrecsionRecallF1score', - 'SingleLabelPrecsionRecallF1score', 'MultiLabelPrecsionRecallF1score' + 'WordAccuracy', 'PrecisionRecallF1score', + 'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score' ] _deprecated_msg = ( diff --git a/mmeval/metrics/average_precision.py b/mmeval/metrics/average_precision.py index 496d02e9..031c5149 100644 --- a/mmeval/metrics/average_precision.py +++ b/mmeval/metrics/average_precision.py @@ -269,7 +269,8 @@ def add(self, preds: Sequence, labels: Sequence) -> None: # type: ignore # yapf Args: preds (Sequence): Predictions from the model. It should be scores of every class (N, C). - labels (Sequence): The ground truth labels. It should be (N, ). + labels (Sequence): The ground truth labels. It should be (N, ) for + label-format, or (N, C) for one-hot encoding. """ for pred, target in zip(preds, labels): self._results.append((pred, target)) @@ -278,7 +279,8 @@ def _format_metric_results(self, ap): """Format the given metric results into a dictionary. 
Args: - results (list): Results of precision, recall, f1 and support. + ap (list): Results of average precision for each categories + or the single marco result. Returns: dict: The formatted dictionary. diff --git a/mmeval/metrics/precision_recall_f1score.py b/mmeval/metrics/precision_recall_f1score.py index 83e9776f..a9cb912e 100644 --- a/mmeval/metrics/precision_recall_f1score.py +++ b/mmeval/metrics/precision_recall_f1score.py @@ -29,23 +29,23 @@ Union[int, Sequence[int]]] -class PrecsionRecallF1score: - """Wrapper to get different task of PrecsionRecallF1score calculation, by +class PrecisionRecallF1score: + """Wrapper to get different task of PrecisionRecallF1score calculation, by setting the ``task`` argument to either ``'singlelabel'`` or ``multilabel``. - See the documentation of :mod:`SingleLabelPrecsionRecallF1score` and - :mod:`MultiLabelPrecsionRecallF1score` for the detailed usages and + See the documentation of :mod:`SingleLabelPrecisionRecallF1score` and + :mod:`MultiLabelPrecisionRecallF1score` for the detailed usages and examples. Examples: >>> import torch >>> preds = torch.tensor([2, 0, 1, 1]) >>> labels = torch.tensor([2, 1, 2, 0]) - >>> metric = PrecsionRecallF1score(num_classes=3) + >>> metric = PrecisionRecallF1score(num_classes=3) >>> metric(preds, labels) {'precision': 33.3333, 'recall': 16.6667, 'f1-score': 22.2222} - >>> metric = PrecsionRecallF1score( + >>> metric = PrecisionRecallF1score( task="multilabel", average='micro', num_classes=3) >>> metric(preds, labels) {'precision_micro': 25.0, 'recall_micro': 25.0, 'f1-score_micro': 25.0} @@ -61,7 +61,7 @@ def __new__(cls, **kwargs): if task == 'singlelabel': - return SingleLabelPrecsionRecallF1score( + return SingleLabelPrecisionRecallF1score( num_classes=num_classes, thrs=thrs, items=items, @@ -72,7 +72,7 @@ def __new__(cls, "task `'multilabel'` only supports single threshold or None." assert isinstance(num_classes, int), \ '`num_classes` is necessary for multi-label metrics.' - return MultiLabelPrecsionRecallF1score( + return MultiLabelPrecisionRecallF1score( num_classes=num_classes, thr=thrs, topk=topk, @@ -169,7 +169,7 @@ def _precision_recall_f1_support(pred_positive: Union[np.ndarray, return precision, recall, f1_score, support -class SingleLabelPrecsionRecallF1score(BaseMetric): +class SingleLabelPrecisionRecallF1score(BaseMetric): """A collection of metrics for single-label multi-class classification task based on confusion matrix. 
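The `__new__` dispatch shown above never yields a `PrecisionRecallF1score` instance itself: it hands back one of the two task-specific metrics, so callers always receive a regular `BaseMetric`. A minimal doctest-style sketch of that behaviour follows; the expected outputs are the ones documented in this series' docstrings and unit tests, not recomputed here.

>>> import torch
>>> from mmeval.metrics import (PrecisionRecallF1score,
...                             SingleLabelPrecisionRecallF1score,
...                             MultiLabelPrecisionRecallF1score)
>>> preds = torch.tensor([2, 0, 1, 1])
>>> labels = torch.tensor([2, 1, 2, 0])
>>> metric = PrecisionRecallF1score(task='singlelabel', num_classes=3)
>>> isinstance(metric, SingleLabelPrecisionRecallF1score)
True
>>> metric(preds, labels)
{'precision': 33.3333, 'recall': 16.6667, 'f1-score': 22.2222}
>>> metric = PrecisionRecallF1score(task='multilabel', num_classes=3)
>>> isinstance(metric, MultiLabelPrecisionRecallF1score)
True
>>> metric(preds, labels)  # label-format inputs are converted to one-hot
{'precision': 33.3333, 'recall': 16.6667, 'f1-score': 22.2222}
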
@@ -210,8 +210,8 @@ class SingleLabelPrecsionRecallF1score(BaseMetric): Examples: - >>> from mmeval import SingleLabelPrecsionRecallF1score - >>> single_lable_metic = SingleLabelPrecsionRecallF1score(num_classes=4) + >>> from mmeval import SingleLabelPrecisionRecallF1score + >>> single_lable_metic = SingleLabelPrecisionRecallF1score(num_classes=4) Use NumPy implementation: @@ -237,7 +237,7 @@ class SingleLabelPrecsionRecallF1score(BaseMetric): [0.3, 0.4, 0.2, 0.1], [0.0, 0.0, 0.1, 0.9]]) >>> labels = np.asarray([0, 1, 2, 3]) - >>> single_lable_metic = SingleLabelPrecsionRecallF1score(average='micro') + >>> single_lable_metic = SingleLabelPrecisionRecallF1score(average='micro') >>> single_lable_metic(preds, labels) {'precision_micro': 50.0, 'recall_micro': 50.0, 'f1-score_micro': 50.0} # noqa @@ -532,7 +532,7 @@ def compute_metric( return self._format_metric_results(metric_results) -class MultiLabelPrecsionRecallF1score(MultiLabelMixin, BaseMetric): +class MultiLabelPrecisionRecallF1score(MultiLabelMixin, BaseMetric): """A collection of metrics for multi-label multi-class classification task based on confusion matrix. @@ -571,15 +571,15 @@ class MultiLabelPrecsionRecallF1score(MultiLabelMixin, BaseMetric): Defaults to "macro". .. note:: - MultiLabelPrecsionRecallF1score supports different kinds of inputs. Such as: + MultiLabelPrecisionRecallF1score supports different kinds of inputs. Such as: 1. Each sample has scores for every classes. (Only for predictions) 2. Each sample has one-hot indices for every classes. 3. Each sample has label-format indices. Examples: - >>> from mmeval import MultiLabelPrecsionRecallF1score - >>> multi_lable_metic = MultiLabelPrecsionRecallF1score(num_classes=4) + >>> from mmeval import MultiLabelPrecisionRecallF1score + >>> multi_lable_metic = MultiLabelPrecisionRecallF1score(num_classes=4) Use Builtin implementation with raw indices: @@ -638,7 +638,7 @@ class MultiLabelPrecsionRecallF1score(MultiLabelMixin, BaseMetric): [0.3, 0.4, 0.2, 0.1], [0.0, 0.0, 0.1, 0.9]]) >>> labels = np.array([0, 1, 2, 3]) - >>> multi_lable_metic = MultiLabelPrecsionRecallF1score(4, average='micro', topk=2) + >>> multi_lable_metic = MultiLabelPrecisionRecallF1score(4, average='micro', topk=2) >>> multi_lable_metic(preds, labels) {'precision_top2_micro': 37.5, 'recall_top2_micro': 75.0, 'f1-score_top2_micro': 50.0} # noqa @@ -674,7 +674,7 @@ def __init__(self, for item in items: assert item in ['precision', 'recall', 'f1-score', 'support'], \ f'The metric {item} is not supported by' \ - ' `MultiLabelPrecsionRecallF1score`,' \ + ' `MultiLabelPrecisionRecallF1score`,' \ ' please specify from "precision", "recall", "f1-score" and ' \ '"support".' 
self.items = tuple(items) diff --git a/mmeval/metrics/utils/multi_label.py b/mmeval/metrics/utils/multi_label.py index 73b30eb4..042308f6 100644 --- a/mmeval/metrics/utils/multi_label.py +++ b/mmeval/metrics/utils/multi_label.py @@ -108,11 +108,16 @@ def format_data( 'parms related to `is_onehot` if use one-hot ' 'encoding data to compute metrics.') assert is_onehot + is_onehot = True # Error corresponds to np, torch, oneflow, stack_func respectively except (ValueError, RuntimeError, AssertionError): + is_onehot = False + + if not is_onehot: # convert label-format inputs to one-hot encodings formated_data = stack_func( [label_to_onehot(sample, num_classes) for sample in data]) + return formated_data diff --git a/tests/test_metrics/test_multi_label_precision_recall_f1score.py b/tests/test_metrics/test_multi_label_precision_recall_f1score.py index 8bdd31b6..e72d057f 100644 --- a/tests/test_metrics/test_multi_label_precision_recall_f1score.py +++ b/tests/test_metrics/test_multi_label_precision_recall_f1score.py @@ -7,7 +7,7 @@ from distutils.version import LooseVersion from mmeval.core.base_metric import BaseMetric -from mmeval.metrics import MultiLabelPrecsionRecallF1score +from mmeval.metrics import MultiLabelPrecisionRecallF1score from mmeval.utils import try_import torch = try_import('torch') @@ -17,10 +17,10 @@ def test_metric_init_assertion(): with pytest.raises(AssertionError, match='Invalid `average` argument'): - MultiLabelPrecsionRecallF1score(3, average='mean') + MultiLabelPrecisionRecallF1score(3, average='mean') with pytest.raises(AssertionError, match='The metric map is not supported'): - MultiLabelPrecsionRecallF1score(3, items=('map',)) + MultiLabelPrecisionRecallF1score(3, items=('map',)) @pytest.mark.parametrize( @@ -40,7 +40,7 @@ def test_metric_init_assertion(): ) def test_metric_inputs(metric_kwargs): # test predictions with labels - multi_label_metric = MultiLabelPrecsionRecallF1score( + multi_label_metric = MultiLabelPrecisionRecallF1score( num_classes=3, **metric_kwargs) assert isinstance(multi_label_metric, BaseMetric) results = multi_label_metric( @@ -59,7 +59,7 @@ def test_metric_inputs(metric_kwargs): ]) def test_metric_interface_builtin(metric_kwargs, preds, labels): """Test builtin inputs.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -78,7 +78,7 @@ def test_metric_interface_builtin(metric_kwargs, preds, labels): ]) def test_metric_interface_topk(metric_kwargs, preds, labels): """Test scores inputs with topk.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -98,7 +98,7 @@ def test_metric_interface_topk(metric_kwargs, preds, labels): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_interface_torch_topk(metric_kwargs, preds, labels): """Test scores inputs with topk in torch.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -128,7 +128,7 @@ def test_metric_interface_oneflow_topk(metric_kwargs, preds, labels): labels = flow.tensor(labels) else: labels = list(flow.tensor(label) for label in labels) - 
multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -151,7 +151,7 @@ def test_metric_interface_oneflow_topk(metric_kwargs, preds, labels): ]) def test_metric_interface(metric_kwargs, preds, labels): """Test all kinds of inputs.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -175,7 +175,7 @@ def test_metric_interface(metric_kwargs, preds, labels): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_interface_torch(metric_kwargs, preds, labels): """Test all kinds of inputs in torch.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -209,7 +209,7 @@ def test_metric_interface_oneflow(metric_kwargs, preds, labels): labels = flow.tensor(labels) else: labels = list(flow.tensor(label) for label in labels) - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) results = multi_label_metric(preds, labels) assert isinstance(results, dict) @@ -244,14 +244,14 @@ def test_metric_interface_oneflow(metric_kwargs, preds, labels): ) def test_metric_accurate(metric_kwargs, preds, labels, results): """Test accurate.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) assert multi_label_metric( np.asarray(preds), np.asarray(labels)) == results def test_metric_accurate_is_onehot(): """Test ambiguous cases when num_classes=2.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(num_classes=2, items=('precision', 'recall')) # noqa + multi_label_metric = MultiLabelPrecisionRecallF1score(num_classes=2, items=('precision', 'recall')) # noqa assert multi_label_metric.pred_is_onehot is False assert multi_label_metric.label_is_onehot is False assert multi_label_metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]) == {'precision': 100.0, 'recall': 100.0} # noqa @@ -273,7 +273,7 @@ def test_metric_accurate_is_onehot(): ) def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and PyTorch implementation.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) preds = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) @@ -305,7 +305,7 @@ def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): ) def test_metamorphic_numpy_oneflow(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and OneFlow implementation.""" - multi_label_metric = MultiLabelPrecsionRecallF1score(**metric_kwargs) + multi_label_metric = MultiLabelPrecisionRecallF1score(**metric_kwargs) preds = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) diff --git a/tests/test_metrics/test_precision_recall_f1score.py b/tests/test_metrics/test_precision_recall_f1score.py index 00956a03..543cac7b 100644 --- a/tests/test_metrics/test_precision_recall_f1score.py +++ 
b/tests/test_metrics/test_precision_recall_f1score.py @@ -6,19 +6,19 @@ import pytest from mmeval.core.base_metric import BaseMetric -from mmeval.metrics import PrecsionRecallF1score +from mmeval.metrics import PrecisionRecallF1score def test_metric_init_assertion(): with pytest.raises( AssertionError, match='`num_classes` is necessary'): - PrecsionRecallF1score(task='multilabel', num_classes=None) + PrecisionRecallF1score(task='multilabel', num_classes=None) with pytest.raises( AssertionError, match="task `'multilabel'` only supports"): - PrecsionRecallF1score(task='multilabel', thrs=(0., 0.1, 0.4)) + PrecisionRecallF1score(task='multilabel', thrs=(0., 0.1, 0.4)) with pytest.raises( ValueError, match='Expected argument `task` to either be'): - PrecsionRecallF1score(task='threeclasses') + PrecisionRecallF1score(task='threeclasses') def test_metric_interface(): @@ -26,13 +26,15 @@ def test_metric_interface(): labels = np.array([2, 1, 2, 0]) # test predictions with labels - metric = PrecsionRecallF1score(task='singlelabel', num_classes=3) + metric = PrecisionRecallF1score(task='singlelabel', num_classes=3) assert isinstance(metric, BaseMetric) results = metric(preds, labels) assert isinstance(results, dict) + assert results == {'precision': 33.3333, 'recall': 16.6667, 'f1-score': 22.2222} # noqa # test predictions with pred_scores - metric = PrecsionRecallF1score(task='multilabel', num_classes=3) + metric = PrecisionRecallF1score(task='multilabel', num_classes=3) assert isinstance(metric, BaseMetric) results = metric(preds, labels) assert isinstance(results, dict) + assert results == {'precision': 33.3333, 'recall': 16.6667, 'f1-score': 22.2222} # noqa From b021a52d0cc8e6b527c08d55d5ce5133d442b60d Mon Sep 17 00:00:00 2001 From: huyingfan Date: Fri, 17 Feb 2023 15:43:07 +0800 Subject: [PATCH 3/7] minor fix --- mmeval/metrics/average_precision.py | 1 + mmeval/metrics/utils/multi_label.py | 55 ++++++++++--------- ...t_single_label_precision_recall_f1score.py | 40 +++++++------- 3 files changed, 49 insertions(+), 47 deletions(-) diff --git a/mmeval/metrics/average_precision.py b/mmeval/metrics/average_precision.py index 031c5149..f04616a9 100644 --- a/mmeval/metrics/average_precision.py +++ b/mmeval/metrics/average_precision.py @@ -368,6 +368,7 @@ def compute_metric( This method would be invoked in `BaseMetric.compute` after distributed synchronization. + Args: results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS, ONEFLOW_IMPL_HINTS]]): A list of tuples that consisting the diff --git a/mmeval/metrics/utils/multi_label.py b/mmeval/metrics/utils/multi_label.py index 042308f6..b6cc9b22 100644 --- a/mmeval/metrics/utils/multi_label.py +++ b/mmeval/metrics/utils/multi_label.py @@ -83,34 +83,35 @@ def format_data( raise NotImplementedError(f'Data type of {type(data[0])}' 'is not supported.') - try: - # try stack scores or one-hot indices directly + shapes = {d.shape for d in data} + if len(shapes) == 1: + # stack scores or one-hot indices directly if have same shapes formated_data = stack_func(data) - # all assertions below is to find labels that are - # raw indices which should be caught in exception - # to convert to one-hot indices. - # - # 1. all raw indices has only 1 dims - assert formated_data.ndim == 2 - # 2. all raw indices has the same dims - assert formated_data.shape[1] == num_classes - # 3. all raw indices has the same dims as num_classes - # then max indices should greater than 1 for num_classes > 2 - assert formated_data.max() <= 1 - # 4. 
corner case, num_classes=2, then one-hot indices - # and raw indices are undistinguishable, for instance: - # [[0, 1], [0, 1]] can be one-hot indices of 2 positives - # or raw indices of 4 positives. - # Extra induction is needed. - if num_classes == 2: - warnings.warn('Ambiguous data detected, reckoned as scores' - ' or label-format data as defaults. Please set ' - 'parms related to `is_onehot` if use one-hot ' - 'encoding data to compute metrics.') - assert is_onehot - is_onehot = True - # Error corresponds to np, torch, oneflow, stack_func respectively - except (ValueError, RuntimeError, AssertionError): + # all the conditions below is to find whether labels that are + # raw indices which should be converted to one-hot indices. + # 1. one-hot indices should has 2 dims; + # 2. one-hot indices should has num_classes as the second dim; + # 3. one-hot indices values should always smaller than 2. + if formated_data.ndim == 2 \ + and formated_data.shape[1] == num_classes \ + and formated_data.max() <= 1: + if num_classes > 2: + is_onehot = True + elif num_classes == 2: + # 4. corner case, num_classes=2, then one-hot indices + # and raw indices are undistinguishable, for instance: + # [[0, 1], [0, 1]] can be one-hot indices of 2 positives + # or raw indices of 4 positives. + # Extra induction is needed. + warnings.warn('Ambiguous data detected, reckoned as scores' + ' or label-format data as defaults. Please set ' + 'parms related to `is_onehot` if use one-hot ' + 'encoding data to compute metrics.') + else: + raise ValueError( + 'num_classes should greater than 2 in multi label metrics.' + ) + else: is_onehot = False if not is_onehot: diff --git a/tests/test_metrics/test_single_label_precision_recall_f1score.py b/tests/test_metrics/test_single_label_precision_recall_f1score.py index f8e70506..d0a6a6bc 100644 --- a/tests/test_metrics/test_single_label_precision_recall_f1score.py +++ b/tests/test_metrics/test_single_label_precision_recall_f1score.py @@ -7,7 +7,7 @@ from distutils.version import LooseVersion from mmeval.core.base_metric import BaseMetric -from mmeval.metrics import SingleLabelPrecsionRecallF1score +from mmeval.metrics import SingleLabelPrecisionRecallF1score from mmeval.utils import try_import torch = try_import('torch') @@ -17,20 +17,20 @@ def test_metric_init_assertion(): with pytest.raises(AssertionError, match='Invalid `average` argument'): - SingleLabelPrecsionRecallF1score(average='mean') + SingleLabelPrecisionRecallF1score(average='mean') with pytest.raises(AssertionError, match='The metric map is not supported'): - SingleLabelPrecsionRecallF1score(items=('map',)) + SingleLabelPrecisionRecallF1score(items=('map',)) def test_metric_assertion(): - single_label_metric = SingleLabelPrecsionRecallF1score() + single_label_metric = SingleLabelPrecisionRecallF1score() with pytest.raises(AssertionError, match='Please specify `num_classes`'): single_label_metric( np.asarray([1, 2, 3]), np.asarray([3, 2, 1])) - single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=2) + single_label_metric = SingleLabelPrecisionRecallF1score(num_classes=2) with pytest.raises(AssertionError, match='Number of classes does not match'): single_label_metric( @@ -39,12 +39,12 @@ def test_metric_assertion(): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_torch_assertion(): - single_label_metric = SingleLabelPrecsionRecallF1score() + single_label_metric = SingleLabelPrecisionRecallF1score() with pytest.raises(AssertionError, match='Please specify 
`num_classes`'): single_label_metric( torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1])) - single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=2) + single_label_metric = SingleLabelPrecisionRecallF1score(num_classes=2) with pytest.raises(AssertionError, match='Number of classes does not match'): single_label_metric( @@ -55,12 +55,12 @@ def test_metric_torch_assertion(): LooseVersion(flow.__version__) < '0.8.1', reason='OneFlow >= 0.8.1 is required!') def test_metric_oneflow_assertion(): - single_label_metric = SingleLabelPrecsionRecallF1score() + single_label_metric = SingleLabelPrecisionRecallF1score() with pytest.raises(AssertionError, match='Please specify `num_classes`'): single_label_metric( flow.Tensor([1, 2, 3]), flow.Tensor([3, 2, 1])) - single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=2) + single_label_metric = SingleLabelPrecisionRecallF1score(num_classes=2) with pytest.raises(AssertionError, match='Number of classes does not match'): single_label_metric( @@ -82,14 +82,14 @@ def test_metric_oneflow_assertion(): ) def test_metric_interface(metric_kwargs): # test predictions with labels - single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) + single_label_metric = SingleLabelPrecisionRecallF1score(**metric_kwargs) assert isinstance(single_label_metric, BaseMetric) assert isinstance(single_label_metric.thrs, tuple) results = single_label_metric( np.asarray([[0.1, 0.9], [0.5, 0.5]]), np.asarray([0, 1])) # test predictions with pred_scores - single_label_metric = SingleLabelPrecsionRecallF1score( + single_label_metric = SingleLabelPrecisionRecallF1score( **metric_kwargs, num_classes=4) assert isinstance(single_label_metric, BaseMetric) assert isinstance(single_label_metric.thrs, tuple) @@ -101,13 +101,13 @@ def test_metric_interface(metric_kwargs): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_input_torch(): # test predictions with labels - single_label_metric = SingleLabelPrecsionRecallF1score() + single_label_metric = SingleLabelPrecisionRecallF1score() results = single_label_metric( torch.Tensor([[0.1, 0.9], [0.5, 0.5]]), torch.Tensor([0, 1])) assert isinstance(results, dict) # test predictions with pred_scores - single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=4) + single_label_metric = SingleLabelPrecisionRecallF1score(num_classes=4) results = single_label_metric( torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1])) assert isinstance(results, dict) @@ -118,13 +118,13 @@ def test_metric_input_torch(): reason='OneFlow >= 0.8.1 is required!') def test_metric_input_oneflow(): # test predictions with labels - single_label_metric = SingleLabelPrecsionRecallF1score() + single_label_metric = SingleLabelPrecisionRecallF1score() results = single_label_metric( flow.Tensor([[0.1, 0.9], [0.5, 0.5]]), flow.Tensor([0, 1])) assert isinstance(results, dict) # test predictions with pred_scores - single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=4) + single_label_metric = SingleLabelPrecisionRecallF1score(num_classes=4) results = single_label_metric( flow.Tensor([1, 2, 3]), flow.Tensor([3, 2, 1])) assert isinstance(results, dict) @@ -133,13 +133,13 @@ def test_metric_input_oneflow(): @pytest.mark.skipif(torch is None, reason='PyTorch is not available!') def test_metric_input_builtin(): # test predictions with labels - single_label_metric = SingleLabelPrecsionRecallF1score() + single_label_metric = SingleLabelPrecisionRecallF1score() results = single_label_metric( [[0.1, 0.9], 
[0.5, 0.5]], [0, 1]) assert isinstance(results, dict) # test predictions with pred_scores - single_label_metric = SingleLabelPrecsionRecallF1score(num_classes=4) + single_label_metric = SingleLabelPrecisionRecallF1score(num_classes=4) results = single_label_metric( [1, 2, 3], [3, 2, 1]) assert isinstance(results, dict) @@ -174,7 +174,7 @@ def test_metric_input_builtin(): ] ) def test_metric_accurate(metric_kwargs, predictions, labels, results): - single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) + single_label_metric = SingleLabelPrecisionRecallF1score(**metric_kwargs) assert single_label_metric( np.asarray(predictions), np.asarray(labels)) == results @@ -191,7 +191,7 @@ def test_metric_accurate(metric_kwargs, predictions, labels, results): ) def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and PyTorch implementation.""" - single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) + single_label_metric = SingleLabelPrecisionRecallF1score(**metric_kwargs) predictions = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) @@ -224,7 +224,7 @@ def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): ) def test_metamorphic_numpy_oneflow(metric_kwargs, classes_num, length): """Metamorphic testing for NumPy and OneFlow implementation.""" - single_label_metric = SingleLabelPrecsionRecallF1score(**metric_kwargs) + single_label_metric = SingleLabelPrecisionRecallF1score(**metric_kwargs) predictions = np.random.rand(length, classes_num) labels = np.random.randint(0, classes_num, length) From 48b88bc13910dcc66078dc38fb81341ce1318bf1 Mon Sep 17 00:00:00 2001 From: huyingfan Date: Mon, 20 Feb 2023 11:33:29 +0800 Subject: [PATCH 4/7] minor fix --- mmeval/metrics/average_precision.py | 1 + mmeval/metrics/utils/multi_label.py | 94 ++++++++++--------- ...st_multi_label_precision_recall_f1score.py | 8 +- 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/mmeval/metrics/average_precision.py b/mmeval/metrics/average_precision.py index f04616a9..62efcc0d 100644 --- a/mmeval/metrics/average_precision.py +++ b/mmeval/metrics/average_precision.py @@ -262,6 +262,7 @@ def __init__(self, average: Optional[str] = 'macro', **kwargs) -> None: assert average in average_options, 'Invalid `average` argument, ' \ f'please specify from {average_options}.' self.average = average + self.pred_is_onehot = False def add(self, preds: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 """Add the intermediate results to `self._results`. diff --git a/mmeval/metrics/utils/multi_label.py b/mmeval/metrics/utils/multi_label.py index b6cc9b22..a46aa9d1 100644 --- a/mmeval/metrics/utils/multi_label.py +++ b/mmeval/metrics/utils/multi_label.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import numpy as np import warnings -from typing import TYPE_CHECKING, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union from mmeval.utils import try_import @@ -56,7 +56,7 @@ def format_data( data: Union[Sequence[Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']], np.ndarray, 'torch.Tensor', 'oneflow.Tensor'], num_classes: int, - is_onehot: bool = False + is_onehot: Optional[bool] = None ) -> Union[np.ndarray, 'torch.Tensor', 'oneflow.Tensor']: """Format data from different inputs such as prediction scores, label- format data and one-hot encodings into the same output shape of `(N, @@ -83,41 +83,51 @@ def format_data( raise NotImplementedError(f'Data type of {type(data[0])}' 'is not supported.') - shapes = {d.shape for d in data} - if len(shapes) == 1: - # stack scores or one-hot indices directly if have same shapes - formated_data = stack_func(data) - # all the conditions below is to find whether labels that are - # raw indices which should be converted to one-hot indices. - # 1. one-hot indices should has 2 dims; - # 2. one-hot indices should has num_classes as the second dim; - # 3. one-hot indices values should always smaller than 2. - if formated_data.ndim == 2 \ - and formated_data.shape[1] == num_classes \ - and formated_data.max() <= 1: - if num_classes > 2: - is_onehot = True - elif num_classes == 2: - # 4. corner case, num_classes=2, then one-hot indices - # and raw indices are undistinguishable, for instance: - # [[0, 1], [0, 1]] can be one-hot indices of 2 positives - # or raw indices of 4 positives. - # Extra induction is needed. - warnings.warn('Ambiguous data detected, reckoned as scores' - ' or label-format data as defaults. Please set ' - 'parms related to `is_onehot` if use one-hot ' - 'encoding data to compute metrics.') - else: - raise ValueError( - 'num_classes should greater than 2 in multi label metrics.' - ) - else: - is_onehot = False + def _induct_is_onehot(inferred_data): + """Conduct the input data format.""" + shapes = {d.shape for d in inferred_data} + if len(shapes) == 1: + # stack scores or one-hot indices directly if have same shapes + cand_formated_data = stack_func(inferred_data) + # all the conditions below is to find whether labels that are + # raw indices which should be converted to one-hot indices. + # 1. one-hot indices should has 2 dims; + # 2. one-hot indices should has num_classes as the second dim; + # 3. one-hot indices values should always smaller than 2. + if cand_formated_data.ndim == 2 \ + and cand_formated_data.shape[1] == num_classes \ + and cand_formated_data.max() <= 1: + if num_classes > 2: + return True, cand_formated_data + elif num_classes == 2: + # 4. corner case, num_classes=2, then one-hot indices + # and raw indices are undistinguishable, for instance: + # [[0, 1], [0, 1]] can be one-hot indices of 2 positives + # or raw indices of 4 positives. + # Extra induction is needed. + warnings.warn( + 'Ambiguous data detected, reckoned as scores' + ' or label-format data as defaults. 
Please set ' + 'parms related to `is_onehot` to `True` if ' + 'use one-hot encoding data to compute metrics.') + return False, None + else: + raise ValueError( + 'num_classes should greater than 2 in multi label' + 'metrics.') + return False, None + + formated_data = None + if is_onehot is None: + is_onehot, formated_data = _induct_is_onehot(data) if not is_onehot: # convert label-format inputs to one-hot encodings formated_data = stack_func( [label_to_onehot(sample, num_classes) for sample in data]) + elif is_onehot and formated_data is None: + # directly stack data if `is_onehot` is set to True without induction + formated_data = stack_func(data) return formated_data @@ -130,41 +140,41 @@ class MultiLabelMixin: def __init__(self, *args, **kwargs) -> None: # pass arguments for multiple inheritances super().__init__(*args, **kwargs) # type: ignore - self._pred_is_onehot = False - self._label_is_onehot = False + self._pred_is_onehot: Optional[bool] = None + self._label_is_onehot: Optional[bool] = None @property - def pred_is_onehot(self) -> bool: + def pred_is_onehot(self) -> Optional[bool]: """Whether prediction is one-hot encodings. - Only works for corner case when num_classes=2 to distinguish one-hot + Only needed for corner case when num_classes=2 to distinguish one-hot encodings or label-format. """ return self._pred_is_onehot @pred_is_onehot.setter - def pred_is_onehot(self, is_onehot: bool): + def pred_is_onehot(self, is_onehot: Optional[bool]): """Set a flag of whether prediction is one-hot encodings. - Only works for corner case when num_classes=2 to distinguish one-hot + Only needed for corner case when num_classes=2 to distinguish one-hot encodings or label-format. """ self._pred_is_onehot = is_onehot @property - def label_is_onehot(self) -> bool: + def label_is_onehot(self) -> Optional[bool]: """Whether label is one-hot encodings. - Only works for corner case when num_classes=2 to distinguish one-hot + Only needed for corner case when num_classes=2 to distinguish one-hot encodings or label-format. """ return self._label_is_onehot @label_is_onehot.setter - def label_is_onehot(self, is_onehot: bool): + def label_is_onehot(self, is_onehot: Optional[bool]): """Set a flag of whether label is one-hot encodings. - Only works for corner case when num_classes=2 to distinguish one-hot + Only needed for corner case when num_classes=2 to distinguish one-hot encodings or label-format. 
""" self._label_is_onehot = is_onehot diff --git a/tests/test_metrics/test_multi_label_precision_recall_f1score.py b/tests/test_metrics/test_multi_label_precision_recall_f1score.py index e72d057f..3ae482b8 100644 --- a/tests/test_metrics/test_multi_label_precision_recall_f1score.py +++ b/tests/test_metrics/test_multi_label_precision_recall_f1score.py @@ -252,8 +252,12 @@ def test_metric_accurate(metric_kwargs, preds, labels, results): def test_metric_accurate_is_onehot(): """Test ambiguous cases when num_classes=2.""" multi_label_metric = MultiLabelPrecisionRecallF1score(num_classes=2, items=('precision', 'recall')) # noqa - assert multi_label_metric.pred_is_onehot is False - assert multi_label_metric.label_is_onehot is False + assert multi_label_metric.pred_is_onehot is None + assert multi_label_metric.label_is_onehot is None + assert multi_label_metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]) == {'precision': 100.0, 'recall': 100.0} # noqa + multi_label_metric.pred_is_onehot = False + assert multi_label_metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]) == {'precision': 100.0, 'recall': 100.0} # noqa + multi_label_metric.pred_is_onehot = False assert multi_label_metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]) == {'precision': 100.0, 'recall': 100.0} # noqa multi_label_metric.pred_is_onehot = True assert multi_label_metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]) == {'precision': 100.0, 'recall': 50.0} # noqa From 6d6828cdc4252b887eaacc925d233623ae8dd59d Mon Sep 17 00:00:00 2001 From: huyingfan Date: Mon, 20 Feb 2023 11:49:48 +0800 Subject: [PATCH 5/7] minor fix --- tests/test_metrics/test_average_precision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_metrics/test_average_precision.py b/tests/test_metrics/test_average_precision.py index 760e7479..e6d513c1 100644 --- a/tests/test_metrics/test_average_precision.py +++ b/tests/test_metrics/test_average_precision.py @@ -170,7 +170,7 @@ def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): for key in np_acc_results: # numpy use float64 however torch use float32 np.testing.assert_allclose( - np_acc_results[key], torch_acc_results[key], rtol=1e-5) + np_acc_results[key], torch_acc_results[key], atol=1e-4) @pytest.mark.skipif(flow is None, reason='OneFlow is not available!') @@ -199,4 +199,4 @@ def test_metamorphic_numpy_oneflow(metric_kwargs, classes_num, length): for key in np_acc_results: # numpy use float64 however oneflow use float32 np.testing.assert_allclose( - np_acc_results[key], oneflow_acc_results[key], rtol=1e-5) + np_acc_results[key], oneflow_acc_results[key], atol=1e-4) From beefd37f4624509cdd0834d4a0dc62ffb20969f4 Mon Sep 17 00:00:00 2001 From: huyingfan Date: Mon, 20 Feb 2023 15:46:01 +0800 Subject: [PATCH 6/7] minor fix --- mmeval/metrics/utils/multi_label.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmeval/metrics/utils/multi_label.py b/mmeval/metrics/utils/multi_label.py index a46aa9d1..1034f49b 100644 --- a/mmeval/metrics/utils/multi_label.py +++ b/mmeval/metrics/utils/multi_label.py @@ -68,6 +68,8 @@ def format_data( The input data of prediction or labels. num_classes (int): The number of classes. is_onehot (bool): Whether the data is one-hot encodings. + If `None`, this will be automatically inducted. + Defaults to `None`. 
     Return:
         torch.Tensor or oneflow.Tensor or np.ndarray:

From 540633b63be6b97d42c101f270abdc70caeb73c9 Mon Sep 17 00:00:00 2001
From: yingfhu
Date: Wed, 8 Mar 2023 19:39:30 +0800
Subject: [PATCH 7/7] fix bc-breaking

---
 mmeval/metrics/__init__.py                 | 4 +++-
 mmeval/metrics/precision_recall_f1score.py | 6 ++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index bccc09bd..9d7a66aa 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -63,7 +63,9 @@
     'SNR': 'SignalNoiseRatio',
     'SSIM': 'StructuralSimilarity',
     'SAD': 'SumAbsoluteDifferences',
-    'MattingMSE': 'MattingMeanSquaredError'
+    'MattingMSE': 'MattingMeanSquaredError',
+    'SingleLabelMetric': 'SingleLabelPrecisionRecallF1score',
+    'MultiLabelMetric': 'MultiLabelPrecisionRecallF1score'
 }
 
 
diff --git a/mmeval/metrics/precision_recall_f1score.py b/mmeval/metrics/precision_recall_f1score.py
index a9cb912e..3ca661b7 100644
--- a/mmeval/metrics/precision_recall_f1score.py
+++ b/mmeval/metrics/precision_recall_f1score.py
@@ -872,3 +872,9 @@ def compute_metric(
         labels = [res[1] for res in results]
         metric_results = self._compute_metric(preds, labels)
         return self._format_metric_results(metric_results)
+
+
+# Keep the deprecated metric names as aliases.
+# The deprecated metric names will be removed in 1.0.0!
+SingleLabelMetric = SingleLabelPrecisionRecallF1score
+MultiLabelMetric = MultiLabelPrecisionRecallF1score
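
To make the bc-breaking fix easier to review, below is a small usage sketch of the renamed metric and its deprecated alias; it is not part of the patch. The inputs and expected outputs are copied from `test_metric_accurate_is_onehot` above, and the import path mirrors where the aliases are defined in `precision_recall_f1score.py`; running it against an installed build of this branch is assumed.

    from mmeval.metrics.precision_recall_f1score import (
        MultiLabelMetric, MultiLabelPrecisionRecallF1score)

    # The deprecated name is kept as a plain alias of the new class.
    assert MultiLabelMetric is MultiLabelPrecisionRecallF1score

    metric = MultiLabelPrecisionRecallF1score(
        num_classes=2, items=('precision', 'recall'))

    # With the new defaults both flags start as `None`, so the data format is
    # inferred; for num_classes=2 the inputs are ambiguous, a warning is
    # emitted and they are treated as scores or label-format data.
    assert metric.pred_is_onehot is None and metric.label_is_onehot is None
    print(metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]))
    # -> {'precision': 100.0, 'recall': 100.0} per the updated test

    # Explicitly declaring the predictions as one-hot changes the result.
    metric.pred_is_onehot = True
    print(metric([[0, 1], [1, 0]], [[0, 1], [0, 1]]))
    # -> {'precision': 100.0, 'recall': 50.0} per the updated test

Importing the old names from `mmeval.metrics` should also keep working through the deprecated-name mapping extended in `__init__.py`.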