From 0735267908efe9b69f2e0a380db97c59a5763855 Mon Sep 17 00:00:00 2001 From: liqikai Date: Thu, 8 Dec 2022 11:22:25 +0800 Subject: [PATCH 1/2] update CrowdPose evaluation results --- .../topdown_heatmap/crowdpose/hrnet_crowdpose.md | 6 +++--- .../topdown_heatmap/crowdpose/resnet_crowdpose.md | 12 ++++++------ mmpose/evaluation/metrics/coco_metric.py | 9 ++++----- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md index 8fd9e6820b..c0d24d4717 100644 --- a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md +++ b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md @@ -33,6 +33,6 @@ Results on CrowdPose test with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) human detector -| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log | -| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: | -| [pose_hrnet_w32](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_hrnet-w32_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.675 | 0.825 | 0.729 | 0.816 | 0.769 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192-960be101_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192_20201227.log.json) | +| Arch | Input Size | AP | AP50 | AP75 | AP (E) | AP (M) | AP (H) | ckpt | log | +| :--------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :----: | :----: | :----: | :--------------------------------------------: | :-------------------------------------------: | +| [pose_hrnet_w32](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_hrnet-w32_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.675 | 0.825 | 0.729 | 0.770 | 0.687 | 0.553 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192-960be101_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192_20201227.log.json) | diff --git a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md index bc509800ed..56a771806d 100644 --- a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md +++ b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md @@ -50,9 +50,9 @@ Results on CrowdPose test with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) human detector -| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log | -| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: | -| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res50_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.637 | 0.808 | 0.692 | 0.785 | 0.738 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192-c6a526b6_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192_20201227.log.json) | -| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.647 | 0.810 | 0.703 | 0.796 | 0.746 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192-8f5870f4_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192_20201227.log.json) | -| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-320x256.py) | 320x256 | 0.661 | 0.821 | 0.714 | 0.800 | 0.759 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256-c88c512a_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256_20201227.log.json) | -| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res152_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.656 | 0.818 | 0.712 | 0.803 | 0.754 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192-dbd49aba_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192_20201227.log.json) | +| Arch | Input Size | AP | AP50 | AP75 | AP (E) | AP (M) | AP (H) | ckpt | log | +| :--------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :----: | :----: | :----: | :--------------------------------------------: | :-------------------------------------------: | +| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res50_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.637 | 0.808 | 0.692 | 0.738 | 0.650 | 0.506 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192-c6a526b6_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192_20201227.log.json) | +| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.647 | 0.810 | 0.703 | 0.745 | 0.658 | 0.521 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192-8f5870f4_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192_20201227.log.json) | +| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-320x256.py) | 320x256 | 0.661 | 0.821 | 0.714 | 0.759 | 0.672 | 0.534 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256-c88c512a_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256_20201227.log.json) | +| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res152_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.656 | 0.818 | 0.712 | 0.754 | 0.666 | 0.533 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192-dbd49aba_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192_20201227.log.json) | diff --git a/mmpose/evaluation/metrics/coco_metric.py b/mmpose/evaluation/metrics/coco_metric.py index 9be6b55734..4a97919966 100644 --- a/mmpose/evaluation/metrics/coco_metric.py +++ b/mmpose/evaluation/metrics/coco_metric.py @@ -18,7 +18,7 @@ @METRICS.register_module() class CocoMetric(BaseMetric): - """COCO evaluation metric. + """COCO pose estimation task evaluation metric. Evaluate AR, AP, and mAP for keypoint detection tasks. Support COCO dataset and other datasets in COCO format. Please refer to @@ -32,7 +32,7 @@ class CocoMetric(BaseMetric): use_area (bool): Whether to use ``'area'`` message in the annotations. If the ground truth annotations (e.g. CrowdPose, AIC) do not have the field ``'area'``, please set ``use_area=False``. - Default: ``True`` + Defaults to ``True`` iou_type (str): The same parameter as `iouType` in :class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or ``'keypoints_crowd'`` (used in CrowdPose dataset). @@ -72,10 +72,9 @@ class CocoMetric(BaseMetric): test submission when the ground truth annotations are absent. If set to ``True``, ``outfile_prefix`` should specify the path to store the output results. Defaults to ``False`` - outfile_prefix (str, optional): The prefix of json files. It includes + outfile_prefix (str | None): The prefix of json files. It includes the file path and the prefix of filename, e.g., ``'a/b/prefix'``. - If not specified, a temp file will be created. - Defaults to ``None`` + If not specified, a temp file will be created. Defaults to ``None`` collect_device (str): Device name used for collecting results from different ranks during distributed training. Must be ``'cpu'`` or ``'gpu'``. Defaults to ``'cpu'`` From 137704312357b9672356eddfb51d6746b693e7af Mon Sep 17 00:00:00 2001 From: liqikai Date: Thu, 8 Dec 2022 11:22:39 +0800 Subject: [PATCH 2/2] improve metrics --- .../wholebody/coco_wholebody_dataset.py | 5 + .../metrics/coco_wholebody_metric.py | 118 ++++++- .../evaluation/metrics/posetrack18_metric.py | 22 +- .../test_coco_wholebody_metric.py | 294 ++++++++++++++++++ .../test_metrics/test_posetrack18_metric.py | 2 +- 5 files changed, 414 insertions(+), 27 deletions(-) create mode 100644 tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py diff --git a/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py index 29e7c8dfbc..00a2ea418f 100644 --- a/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py +++ b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os.path as osp from typing import Optional @@ -117,6 +118,10 @@ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: 'iscrowd': ann['iscrowd'], 'segmentation': ann['segmentation'], 'id': ann['id'], + 'category_id': ann['category_id'], + # store the raw annotation of the instance + # it is useful for evaluation without providing ann_file + 'raw_ann_info': copy.deepcopy(ann), } return data_info diff --git a/mmpose/evaluation/metrics/coco_wholebody_metric.py b/mmpose/evaluation/metrics/coco_wholebody_metric.py index 34d81aed20..c5675f54c8 100644 --- a/mmpose/evaluation/metrics/coco_wholebody_metric.py +++ b/mmpose/evaluation/metrics/coco_wholebody_metric.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Optional +import datetime +from typing import Dict, Optional, Sequence import numpy as np from mmengine.fileio import dump @@ -19,15 +20,17 @@ class CocoWholeBodyMetric(CocoMetric): for more details. Args: - ann_file (str): Path to the coco format annotation file. + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None use_area (bool): Whether to use ``'area'`` message in the annotations. If the ground truth annotations (e.g. CrowdPose, AIC) do not have the field ``'area'``, please set ``use_area=False``. - Default: ``True``. + Defaults to ``True`` iou_type (str): The same parameter as `iouType` in :class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or ``'keypoints_crowd'`` (used in CrowdPose dataset). - Defaults to ``'keypoints'``. + Defaults to ``'keypoints'`` score_mode (str): The mode to score the prediction results which should be one of the following options: @@ -62,17 +65,11 @@ class CocoWholeBodyMetric(CocoMetric): doing quantitative evaluation. This is designed for the need of test submission when the ground truth annotations are absent. If set to ``True``, ``outfile_prefix`` should specify the path to - store the output results. Default: ``False``. + store the output results. Defaults to ``False`` outfile_prefix (str | None): The prefix of json files. It includes the file path and the prefix of filename, e.g., ``'a/b/prefix'``. - If not specified, a temp file will be created. Default: ``None``. - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be ``'cpu'`` or - ``'gpu'``. Default: ``'cpu'``. - prefix (str, optional): The prefix that will be added in the metric - names to disambiguate homonymous metrics of different evaluators. - If prefix is not provided in the argument, ``self.default_prefix`` - will be used instead. Default: ``None``. + If not specified, a temp file will be created. Defaults to ``None`` + **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric` """ default_prefix: Optional[str] = 'coco-wholebody' body_num = 17 @@ -81,6 +78,101 @@ class CocoWholeBodyMetric(CocoMetric): left_hand_num = 21 right_hand_num = 21 + def gt_to_coco_json(self, gt_dicts: Sequence[dict], + outfile_prefix: str) -> str: + """Convert ground truth to coco format json file. + + Args: + gt_dicts (Sequence[dict]): Ground truth of the dataset. Each dict + contains the ground truth information about the data sample. + Required keys of the each `gt_dict` in `gt_dicts`: + - `img_id`: image id of the data sample + - `width`: original image width + - `height`: original image height + - `raw_ann_info`: the raw annotation information + Optional keys: + - `crowd_index`: measure the crowding level of an image, + defined in CrowdPose dataset + It is worth mentioning that, in order to compute `CocoMetric`, + there are some required keys in the `raw_ann_info`: + - `id`: the id to distinguish different annotations + - `image_id`: the image id of this annotation + - `category_id`: the category of the instance. + - `bbox`: the object bounding box + - `keypoints`: the keypoints cooridinates along with their + visibilities. Note that it need to be aligned + with the official COCO format, e.g., a list with length + N * 3, in which N is the number of keypoints. And each + triplet represent the [x, y, visible] of the keypoint. + - 'keypoints' + - `iscrowd`: indicating whether the annotation is a crowd. + It is useful when matching the detection results to + the ground truth. + There are some optional keys as well: + - `area`: it is necessary when `self.use_area` is `True` + - `num_keypoints`: it is necessary when `self.iou_type` + is set as `keypoints_crowd`. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json file will be named + "somepath/xxx.gt.json". + Returns: + str: The filename of the json file. + """ + image_infos = [] + annotations = [] + img_ids = [] + ann_ids = [] + + for gt_dict in gt_dicts: + # filter duplicate image_info + if gt_dict['img_id'] not in img_ids: + image_info = dict( + id=gt_dict['img_id'], + width=gt_dict['width'], + height=gt_dict['height'], + ) + if self.iou_type == 'keypoints_crowd': + image_info['crowdIndex'] = gt_dict['crowd_index'] + + image_infos.append(image_info) + img_ids.append(gt_dict['img_id']) + + # filter duplicate annotations + for ann in gt_dict['raw_ann_info']: + annotation = dict( + id=ann['id'], + image_id=ann['image_id'], + category_id=ann['category_id'], + bbox=ann['bbox'], + keypoints=ann['keypoints'], + foot_kpts=ann['foot_kpts'], + face_kpts=ann['face_kpts'], + lefthand_kpts=ann['lefthand_kpts'], + righthand_kpts=ann['righthand_kpts'], + iscrowd=ann['iscrowd'], + ) + if self.use_area: + assert 'area' in ann, \ + '`area` is required when `self.use_area` is `True`' + annotation['area'] = ann['area'] + + annotations.append(annotation) + ann_ids.append(ann['id']) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmpose CocoMetric.') + coco_json: dict = dict( + info=info, + images=image_infos, + categories=self.dataset_meta['CLASSES'], + licenses=None, + annotations=annotations, + ) + converted_json_path = f'{outfile_prefix}.gt.json' + dump(coco_json, converted_json_path, sort_keys=True, indent=4) + return converted_json_path + def results2json(self, keypoints: Dict[int, list], outfile_prefix: str) -> str: """Dump the keypoint detection results to a COCO style json file. diff --git a/mmpose/evaluation/metrics/posetrack18_metric.py b/mmpose/evaluation/metrics/posetrack18_metric.py index 85d06c37ee..86f801455a 100644 --- a/mmpose/evaluation/metrics/posetrack18_metric.py +++ b/mmpose/evaluation/metrics/posetrack18_metric.py @@ -28,7 +28,9 @@ class PoseTrack18Metric(CocoMetric): for more details. Args: - ann_file (str): Path to the annotation file. + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None score_mode (str): The mode to score the prediction results which should be one of the following options: @@ -37,7 +39,7 @@ class PoseTrack18Metric(CocoMetric): - ``'bbox_keypoint'``: Use keypoint score to rescore the prediction results. - Defaults to ``'bbox'` + Defaults to ``'bbox_keypoint'` keypoint_score_thr (float): The threshold of keypoint score. The keypoints with score lower than it will not be included to rescore the prediction results. Valid only when ``score_mode`` is @@ -61,22 +63,16 @@ class PoseTrack18Metric(CocoMetric): doing quantitative evaluation. This is designed for the need of test submission when the ground truth annotations are absent. If set to ``True``, ``outfile_prefix`` should specify the path to - store the output results. Default: ``False``. + store the output results. Defaults to ``False`` outfile_prefix (str | None): The prefix of json files. It includes the file path and the prefix of filename, e.g., ``'a/b/prefix'``. - If not specified, a temp file will be created. Default: ``None``. - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be ``'cpu'`` or - ``'gpu'``. Default: ``'cpu'``. - prefix (str, optional): The prefix that will be added in the metric - names to disambiguate homonymous metrics of different evaluators. - If prefix is not provided in the argument, ``self.default_prefix`` - will be used instead. Default: ``None``. + If not specified, a temp file will be created. Defaults to ``None`` + **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric` """ default_prefix: Optional[str] = 'posetrack18' def __init__(self, - ann_file: str, + ann_file: Optional[str] = None, score_mode: str = 'bbox_keypoint', keypoint_score_thr: float = 0.2, nms_mode: str = 'oks_nms', @@ -216,7 +212,7 @@ def _do_python_keypoint_eval(self, outfile_prefix: str) -> List[tuple]: stats_names = [ 'Head AP', 'Shou AP', 'Elb AP', 'Wri AP', 'Hip AP', 'Knee AP', - 'Ankl AP', 'Total AP' + 'Ankl AP', 'AP' ] info_str = list(zip(stats_names, stats)) diff --git a/tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py b/tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py new file mode 100644 index 0000000000..46e8498851 --- /dev/null +++ b/tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py @@ -0,0 +1,294 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +import tempfile +from collections import defaultdict +from unittest import TestCase + +import numpy as np +from mmengine.fileio import dump, load +from xtcocotools.coco import COCO + +from mmpose.datasets.datasets.utils import parse_pose_metainfo +from mmpose.evaluation.metrics import CocoWholeBodyMetric + + +class TestCocoWholeBodyMetric(TestCase): + + def setUp(self): + """Setup some variables which are used in every test method. + + TestCase calls functions in this order: setUp() -> testMethod() -> + tearDown() -> cleanUp() + """ + self.tmp_dir = tempfile.TemporaryDirectory() + + self.ann_file_coco = 'tests/data/coco/test_coco_wholebody.json' + meta_info_coco = dict( + from_file='configs/_base_/datasets/coco_wholebody.py') + self.dataset_meta_coco = parse_pose_metainfo(meta_info_coco) + self.coco = COCO(self.ann_file_coco) + self.dataset_meta_coco['CLASSES'] = self.coco.loadCats( + self.coco.getCatIds()) + + self.topdown_data_coco = self._convert_ann_to_topdown_batch_data( + self.ann_file_coco) + assert len(self.topdown_data_coco) == 14 + self.bottomup_data_coco = self._convert_ann_to_bottomup_batch_data( + self.ann_file_coco) + assert len(self.bottomup_data_coco) == 4 + self.target_coco = { + 'coco-wholebody/AP': 1.0, + 'coco-wholebody/AP .5': 1.0, + 'coco-wholebody/AP .75': 1.0, + 'coco-wholebody/AP (M)': 1.0, + 'coco-wholebody/AP (L)': 1.0, + 'coco-wholebody/AR': 1.0, + 'coco-wholebody/AR .5': 1.0, + 'coco-wholebody/AR .75': 1.0, + 'coco-wholebody/AR (M)': 1.0, + 'coco-wholebody/AR (L)': 1.0, + } + + def _convert_ann_to_topdown_batch_data(self, ann_file): + """Convert annotations to topdown-style batch data.""" + topdown_data = [] + db = load(ann_file) + imgid2info = dict() + for img in db['images']: + imgid2info[img['id']] = img + for ann in db['annotations']: + w, h = ann['bbox'][2], ann['bbox'][3] + bboxes = np.array(ann['bbox'], dtype=np.float32).reshape(-1, 4) + bbox_scales = np.array([w * 1.25, h * 1.25]).reshape(-1, 2) + _keypoints = np.array(ann['keypoints'] + ann['foot_kpts'] + + ann['face_kpts'] + ann['lefthand_kpts'] + + ann['righthand_kpts']).reshape(1, -1, 3) + + gt_instances = { + 'bbox_scales': bbox_scales, + 'bbox_scores': np.ones((1, ), dtype=np.float32), + 'bboxes': bboxes, + } + pred_instances = { + 'keypoints': _keypoints[..., :2], + 'keypoint_scores': _keypoints[..., -1], + } + + data = {'inputs': None} + data_sample = { + 'id': ann['id'], + 'img_id': ann['image_id'], + 'category_id': ann.get('category_id', 1), + 'gt_instances': gt_instances, + 'pred_instances': pred_instances, + # dummy image_shape for testing + 'ori_shape': [640, 480], + # store the raw annotation info to test without ann_file + 'raw_ann_info': copy.deepcopy(ann), + } + + # batch size = 1 + data_batch = [data] + data_samples = [data_sample] + topdown_data.append((data_batch, data_samples)) + + return topdown_data + + def _convert_ann_to_bottomup_batch_data(self, ann_file): + """Convert annotations to bottomup-style batch data.""" + img2ann = defaultdict(list) + db = load(ann_file) + for ann in db['annotations']: + img2ann[ann['image_id']].append(ann) + + bottomup_data = [] + for img_id, anns in img2ann.items(): + _keypoints = [] + for ann in anns: + _keypoints.append(ann['keypoints'] + ann['foot_kpts'] + + ann['face_kpts'] + ann['lefthand_kpts'] + + ann['righthand_kpts']) + keypoints = np.array(_keypoints).reshape((len(anns), -1, 3)) + + gt_instances = { + 'bbox_scores': np.ones((len(anns)), dtype=np.float32) + } + + pred_instances = { + 'keypoints': keypoints[..., :2], + 'keypoint_scores': keypoints[..., -1], + } + + data = {'inputs': None} + data_sample = { + 'id': [ann['id'] for ann in anns], + 'img_id': img_id, + 'gt_instances': gt_instances, + 'pred_instances': pred_instances + } + + # batch size = 1 + data_batch = [data] + data_samples = [data_sample] + bottomup_data.append((data_batch, data_samples)) + return bottomup_data + + def tearDown(self): + self.tmp_dir.cleanup() + + def test_init(self): + """test metric init method.""" + # test score_mode option + with self.assertRaisesRegex(ValueError, + '`score_mode` should be one of'): + _ = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, score_mode='invalid') + + # test nms_mode option + with self.assertRaisesRegex(ValueError, '`nms_mode` should be one of'): + _ = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, nms_mode='invalid') + + # test format_only option + with self.assertRaisesRegex( + AssertionError, + '`outfile_prefix` can not be None when `format_only` is True'): + _ = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, + format_only=True, + outfile_prefix=None) + + def test_other_methods(self): + """test other useful methods.""" + # test `_sort_and_unique_bboxes` method + metric_coco = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, score_mode='bbox', nms_mode='none') + metric_coco.dataset_meta = self.dataset_meta_coco + # process samples + for data_batch, data_samples in self.topdown_data_coco: + metric_coco.process(data_batch, data_samples) + # process one extra sample + data_batch, data_samples = self.topdown_data_coco[0] + metric_coco.process(data_batch, data_samples) + # an extra sample + eval_results = metric_coco.evaluate( + size=len(self.topdown_data_coco) + 1) + self.assertDictEqual(eval_results, self.target_coco) + + def test_format_only(self): + """test `format_only` option.""" + metric_coco = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, + format_only=True, + outfile_prefix=f'{self.tmp_dir.name}/test', + score_mode='bbox_keypoint', + nms_mode='oks_nms') + metric_coco.dataset_meta = self.dataset_meta_coco + # process one sample + data_batch, data_samples = self.topdown_data_coco[0] + metric_coco.process(data_batch, data_samples) + eval_results = metric_coco.evaluate(size=1) + self.assertDictEqual(eval_results, {}) + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test.keypoints.json'))) + + # test when gt annotations are absent + db_ = load(self.ann_file_coco) + del db_['annotations'] + tmp_ann_file = osp.join(self.tmp_dir.name, 'temp_ann.json') + dump(db_, tmp_ann_file, sort_keys=True, indent=4) + with self.assertRaisesRegex( + AssertionError, + 'Ground truth annotations are required for evaluation'): + _ = CocoWholeBodyMetric(ann_file=tmp_ann_file, format_only=False) + + def test_bottomup_evaluate(self): + """test bottomup-style COCO metric evaluation.""" + # case1: score_mode='bbox', nms_mode='none' + metric_coco = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, + outfile_prefix=f'{self.tmp_dir.name}/test', + score_mode='bbox', + nms_mode='none') + metric_coco.dataset_meta = self.dataset_meta_coco + + # process samples + for data_batch, data_samples in self.bottomup_data_coco: + metric_coco.process(data_batch, data_samples) + + eval_results = metric_coco.evaluate(size=len(self.bottomup_data_coco)) + self.assertDictEqual(eval_results, self.target_coco) + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test.keypoints.json'))) + + def test_topdown_evaluate(self): + """test topdown-style COCO metric evaluation.""" + # case 1: score_mode='bbox', nms_mode='none' + metric_coco = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, + outfile_prefix=f'{self.tmp_dir.name}/test1', + score_mode='bbox', + nms_mode='none') + metric_coco.dataset_meta = self.dataset_meta_coco + + # process samples + for data_batch, data_samples in self.topdown_data_coco: + metric_coco.process(data_batch, data_samples) + + eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco)) + + self.assertDictEqual(eval_results, self.target_coco) + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test1.keypoints.json'))) + + # case 2: score_mode='bbox_keypoint', nms_mode='oks_nms' + metric_coco = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, + outfile_prefix=f'{self.tmp_dir.name}/test2', + score_mode='bbox_keypoint', + nms_mode='oks_nms') + metric_coco.dataset_meta = self.dataset_meta_coco + + # process samples + for data_batch, data_samples in self.topdown_data_coco: + metric_coco.process(data_batch, data_samples) + + eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco)) + + self.assertDictEqual(eval_results, self.target_coco) + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test2.keypoints.json'))) + + # case 3: score_mode='bbox_rle', nms_mode='soft_oks_nms' + metric_coco = CocoWholeBodyMetric( + ann_file=self.ann_file_coco, + outfile_prefix=f'{self.tmp_dir.name}/test3', + score_mode='bbox_rle', + nms_mode='soft_oks_nms') + metric_coco.dataset_meta = self.dataset_meta_coco + + # process samples + for data_batch, data_samples in self.topdown_data_coco: + metric_coco.process(data_batch, data_samples) + + eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco)) + + self.assertDictEqual(eval_results, self.target_coco) + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test3.keypoints.json'))) + + # case 4: test without providing ann_file + metric_coco = CocoWholeBodyMetric( + outfile_prefix=f'{self.tmp_dir.name}/test4') + metric_coco.dataset_meta = self.dataset_meta_coco + # process samples + for data_batch, data_samples in self.topdown_data_coco: + metric_coco.process(data_batch, data_samples) + eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco)) + self.assertDictEqual(eval_results, self.target_coco) + # test whether convert the annotation to COCO format + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test4.gt.json'))) + self.assertTrue( + osp.isfile(osp.join(self.tmp_dir.name, 'test4.keypoints.json'))) diff --git a/tests/test_evaluation/test_metrics/test_posetrack18_metric.py b/tests/test_evaluation/test_metrics/test_posetrack18_metric.py index 09a6c56435..fe44047e31 100644 --- a/tests/test_evaluation/test_metrics/test_posetrack18_metric.py +++ b/tests/test_evaluation/test_metrics/test_posetrack18_metric.py @@ -41,7 +41,7 @@ def setUp(self): 'posetrack18/Hip AP': 100.0, 'posetrack18/Knee AP': 100.0, 'posetrack18/Ankl AP': 100.0, - 'posetrack18/Total AP': 100.0, + 'posetrack18/AP': 100.0, } def _convert_ann_to_topdown_batch_data(self):