diff --git a/mmtrack/evaluation/metrics/coco_video_metric.py b/mmtrack/evaluation/metrics/coco_video_metric.py
index 158917cbd..aa858a2e6 100644
--- a/mmtrack/evaluation/metrics/coco_video_metric.py
+++ b/mmtrack/evaluation/metrics/coco_video_metric.py
@@ -1,15 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import Optional, Sequence
+from typing import Sequence

-from mmdet.datasets.api_wrappers import COCO
 from mmdet.evaluation import CocoMetric
 from mmdet.structures.mask import encode_mask_results
-from mmengine.dist import broadcast_object_list, is_main_process
-from mmengine.fileio import FileClient

 from mmtrack.registry import METRICS
-from .base_video_metrics import collect_tracking_results


 @METRICS.register_module()
@@ -19,22 +14,20 @@ class CocoVideoMetric(CocoMetric):
     Evaluate AR, AP, and mAP for detection tasks including proposal/box
     detection and instance segmentation. Please refer to
     https://cocodataset.org/#detection-eval for more details.
+
+    dist_collect_mode (str, optional): The method of concatenating the
+        collected synchronization results. This depends on how the
+        distributed data is split. Currently only 'unzip' and 'cat' are
+        supported. For samplers in MMTracking, 'cat' should
+        be used. Defaults to 'cat'.
     """

-    def __init__(self, ann_file: Optional[str] = None, **kwargs) -> None:
-        super().__init__(**kwargs)
-        # if ann_file is not specified,
-        # initialize coco api with the converted dataset
-        if ann_file:
-            file_client = FileClient.infer_client(uri=ann_file)
-            with file_client.get_local_path(ann_file) as local_path:
-                self._coco_api = COCO(local_path)
-        else:
-            self._coco_api = None
+    def __init__(self, dist_collect_mode='cat', **kwargs) -> None:
+        super().__init__(dist_collect_mode=dist_collect_mode, **kwargs)

     def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
         """Process one batch of data samples and predictions. The processed
-        results should be stored in ``self.results``, which will be used to
+        results should be stored in ``self._results``, which will be used to
         compute the metrics when all batches have been processed.

         Note that we only modify ``pred['pred_instances']`` in ``CocoMetric``
@@ -45,67 +38,28 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
             data_samples (Sequence[dict]): A batch of data samples that
                 contain annotations and predictions.
         """
+        predictions, groundtruths = [], []
         for data_sample in data_samples:
-            result = dict()
-            pred = data_sample['pred_det_instances']
-            result['img_id'] = data_sample['img_id']
-            result['bboxes'] = pred['bboxes'].cpu().numpy()
-            result['scores'] = pred['scores'].cpu().numpy()
-            result['labels'] = pred['labels'].cpu().numpy()
-            # encode mask to RLE
-            if 'masks' in pred:
-                result['masks'] = encode_mask_results(
-                    pred['masks'].detach().cpu().numpy())
+            pred = dict()
+            pred_instances = data_sample['pred_det_instances']
+            pred['img_id'] = data_sample['img_id']
+            pred['bboxes'] = pred_instances['bboxes'].cpu().numpy()
+            pred['scores'] = pred_instances['scores'].cpu().numpy()
+            pred['labels'] = pred_instances['labels'].cpu().numpy()
+            if 'masks' in pred_instances:
+                pred['masks'] = encode_mask_results(
+                    pred_instances['masks'].detach().cpu().numpy())
             # some detectors use different scores for bbox and mask
-            if 'mask_scores' in pred:
-                result['mask_scores'] = pred['mask_scores'].cpu().numpy()
+            if 'mask_scores' in pred_instances:
+                pred['mask_scores'] = \
+                    pred_instances['mask_scores'].cpu().numpy()
+            predictions.append(pred)

             # parse gt
-            gt = dict()
-            gt['width'] = data_sample['ori_shape'][1]
-            gt['height'] = data_sample['ori_shape'][0]
-            gt['img_id'] = data_sample['img_id']
             if self._coco_api is None:
-                assert 'instances' in data_sample, \
-                    'ground truth is required for evaluation when ' \
-                    '`ann_file` is not provided'
-                gt['anns'] = data_sample['instances']
-            # add converted result to the results list
-            self.results.append((gt, result))
-
-    def evaluate(self, size: int) -> dict:
-        """Evaluate the model performance of the whole dataset after processing
-        all batches.
-
-        Args:
-            size (int): Length of the entire validation dataset.
-
-        Returns:
-            dict: Evaluation metrics dict on the val dataset. The keys are the
-                names of the metrics, and the values are corresponding results.
-        """
-        if len(self.results) == 0:
-            warnings.warn(
-                f'{self.__class__.__name__} got empty `self.results`. Please '
-                'ensure that the processed results are properly added into '
-                '`self.results` in `process` method.')
-
-        results = collect_tracking_results(self.results, self.collect_device)
-
-        if is_main_process():
-            _metrics = self.compute_metrics(results)  # type: ignore
-            # Add prefix to metric names
-            if self.prefix:
-                _metrics = {
-                    '/'.join((self.prefix, k)): v
-                    for k, v in _metrics.items()
-                }
-            metrics = [_metrics]
-        else:
-            metrics = [None]  # type: ignore
-
-        broadcast_object_list(metrics)
+                ann = self.add_gt(data_sample)
+            else:
+                ann = dict()
+            groundtruths.append(ann)

-        # reset the results list
-        self.results.clear()
-        return metrics[0]
+        self.add(predictions, groundtruths)