Refactor CocoVideoMetric to use MMEval #742

Open · wants to merge 4 commits into base: dev-1.x
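For context, here is a minimal sketch of how the refactored metric might be wired into an MMTracking config as a validation evaluator. The annotation path is hypothetical, and the `ann_file`/`metric` keywords are assumptions carried over from MMDetection's `CocoMetric`; after this refactor the exact keyword set comes from the MMEval-based parent class, so treat this as illustrative only.

# Hypothetical val_evaluator config; keywords other than `type` and
# `dist_collect_mode` are assumed from mmdet's CocoMetric, not confirmed by this PR.
val_evaluator = dict(
    type='CocoVideoMetric',
    ann_file='data/coco/annotations/instances_val2017.json',  # hypothetical path
    metric='bbox',
    dist_collect_mode='cat')  # new argument, defaults to 'cat' after this PR
test_evaluator = val_evaluator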
104 changes: 29 additions & 75 deletions mmtrack/evaluation/metrics/coco_video_metric.py
@@ -1,15 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional, Sequence
from typing import Sequence

from mmdet.datasets.api_wrappers import COCO
from mmdet.evaluation import CocoMetric
from mmdet.structures.mask import encode_mask_results
from mmengine.dist import broadcast_object_list, is_main_process
from mmengine.fileio import FileClient

from mmtrack.registry import METRICS
from .base_video_metrics import collect_tracking_results


@METRICS.register_module()
@@ -19,22 +14,20 @@ class CocoVideoMetric(CocoMetric):
Evaluate AR, AP, and mAP for detection tasks including proposal/box
detection and instance segmentation. Please refer to
https://cocodataset.org/#detection-eval for more details.

dist_collect_mode (str, optional): The method of concatenating the
collected synchronization results. This depends on how the
distributed data is split. Currently only 'unzip' and 'cat' are
supported. For samplers in MMTracking, 'cat' should
be used. Defaults to 'cat'.
"""

def __init__(self, ann_file: Optional[str] = None, **kwargs) -> None:
super().__init__(**kwargs)
# if ann_file is not specified,
# initialize coco api with the converted dataset
if ann_file:
file_client = FileClient.infer_client(uri=ann_file)
with file_client.get_local_path(ann_file) as local_path:
self._coco_api = COCO(local_path)
else:
self._coco_api = None
def __init__(self, dist_collect_mode='cat', **kwargs) -> None:
super().__init__(dist_collect_mode=dist_collect_mode, **kwargs)

def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
results should be stored in ``self._results``, which will be used to
compute the metrics when all batches have been processed.

Note that we only modify ``pred['pred_instances']`` in ``CocoMetric``
@@ -45,67 +38,28 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_samples (Sequence[dict]): A batch of data samples that
contain annotations and predictions.
"""
predictions, groundtruths = [], []
for data_sample in data_samples:
result = dict()
pred = data_sample['pred_det_instances']
result['img_id'] = data_sample['img_id']
result['bboxes'] = pred['bboxes'].cpu().numpy()
result['scores'] = pred['scores'].cpu().numpy()
result['labels'] = pred['labels'].cpu().numpy()
# encode mask to RLE
if 'masks' in pred:
result['masks'] = encode_mask_results(
pred['masks'].detach().cpu().numpy())
pred = dict()
pred_instances = data_sample['pred_det_instances']
pred['img_id'] = data_sample['img_id']
pred['bboxes'] = pred_instances['bboxes'].cpu().numpy()
pred['scores'] = pred_instances['scores'].cpu().numpy()
pred['labels'] = pred_instances['labels'].cpu().numpy()
if 'masks' in pred_instances:
pred['masks'] = encode_mask_results(
pred_instances['masks'].detach().cpu().numpy())
# some detectors use different scores for bbox and mask
if 'mask_scores' in pred:
result['mask_scores'] = pred['mask_scores'].cpu().numpy()
if 'mask_scores' in pred_instances:
pred['mask_scores'] = \
pred_instances['mask_scores'].cpu().numpy()
predictions.append(pred)

# parse gt
gt = dict()
gt['width'] = data_sample['ori_shape'][1]
gt['height'] = data_sample['ori_shape'][0]
gt['img_id'] = data_sample['img_id']
if self._coco_api is None:
assert 'instances' in data_sample, \
'ground truth is required for evaluation when ' \
'`ann_file` is not provided'
gt['anns'] = data_sample['instances']
# add converted result to the results list
self.results.append((gt, result))

def evaluate(self, size: int) -> dict:
"""Evaluate the model performance of the whole dataset after processing
all batches.

Args:
size (int): Length of the entire validation dataset.

Returns:
dict: Evaluation metrics dict on the val dataset. The keys are the
names of the metrics, and the values are corresponding results.
"""
if len(self.results) == 0:
warnings.warn(
f'{self.__class__.__name__} got empty `self.results`. Please '
'ensure that the processed results are properly added into '
'`self.results` in `process` method.')

results = collect_tracking_results(self.results, self.collect_device)

if is_main_process():
_metrics = self.compute_metrics(results) # type: ignore
# Add prefix to metric names
if self.prefix:
_metrics = {
'/'.join((self.prefix, k)): v
for k, v in _metrics.items()
}
metrics = [_metrics]
else:
metrics = [None] # type: ignore

broadcast_object_list(metrics)
ann = self.add_gt(data_sample)
else:
ann = dict()
groundtruths.append(ann)

# reset the results list
self.results.clear()
return metrics[0]
self.add(predictions, groundtruths)
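
The diff drops the hand-rolled `evaluate()` (with `collect_tracking_results` and `broadcast_object_list`) and instead delegates accumulation to the parent's `add()`, following the MMEval pattern: `process()` converts one batch into prediction/ground-truth dicts, `add()` buffers them in `self._results`, and distributed collection plus `compute_metric()` happen in the base class according to `dist_collect_mode`. Below is a toy sketch of that pattern; it assumes nothing about MMEval's real internals beyond the `add`/`compute_metric`/`_results` names visible in this diff, so it is illustrative only.

# Toy illustration of the accumulate-then-compute pattern this PR moves to.
# This is NOT mmeval's implementation; it only mirrors the method names
# (`add`, `compute_metric`, `_results`) that appear in the diff above.
from typing import Dict, List, Sequence, Tuple


class ToyAccumulatingMetric:

    def __init__(self, dist_collect_mode: str = 'cat') -> None:
        # Only 'cat' and 'unzip' are mentioned in the docstring above.
        assert dist_collect_mode in ('cat', 'unzip')
        self.dist_collect_mode = dist_collect_mode
        self._results: List[Tuple[dict, dict]] = []

    def add(self, predictions: Sequence[dict],
            groundtruths: Sequence[dict]) -> None:
        # Pair up per-image predictions and ground truths and buffer them.
        self._results.extend(zip(predictions, groundtruths))

    def compute_metric(
            self, results: Sequence[Tuple[dict, dict]]) -> Dict[str, float]:
        # A real metric would run COCO evaluation here; we just count images.
        return {'num_images': float(len(results))}

    def compute(self) -> Dict[str, float]:
        # In MMEval, the buffered results would first be gathered across
        # processes (concatenated per `dist_collect_mode`) before this call.
        return self.compute_metric(self._results)


metric = ToyAccumulatingMetric()
metric.add([{'img_id': 0, 'bboxes': []}], [{'img_id': 0, 'anns': []}])
print(metric.compute())  # {'num_images': 1.0}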