From f3aff732690a9f275fd2cc0cd8f2f57d01c5d7be Mon Sep 17 00:00:00 2001
From: Qikai Li <87690686+liqikai9@users.noreply.github.com>
Date: Fri, 9 Dec 2022 12:19:49 +0800
Subject: [PATCH] [Refactor] Update CrowdPose evaluation results (#1868)
---
.../crowdpose/hrnet_crowdpose.md | 6 +-
.../crowdpose/resnet_crowdpose.md | 12 +-
.../wholebody/coco_wholebody_dataset.py | 5 +
mmpose/evaluation/metrics/coco_metric.py | 9 +-
.../metrics/coco_wholebody_metric.py | 118 ++++++-
.../evaluation/metrics/posetrack18_metric.py | 22 +-
.../test_coco_wholebody_metric.py | 294 ++++++++++++++++++
.../test_metrics/test_posetrack18_metric.py | 2 +-
8 files changed, 427 insertions(+), 41 deletions(-)
create mode 100644 tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py
diff --git a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md
index 8fd9e6820b..c0d24d4717 100644
--- a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md
+++ b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/hrnet_crowdpose.md
@@ -33,6 +33,6 @@
Results on CrowdPose test with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) human detector
-| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
-| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
-| [pose_hrnet_w32](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_hrnet-w32_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.675 | 0.825 | 0.729 | 0.816 | 0.769 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192-960be101_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192_20201227.log.json) |
+| Arch | Input Size | AP | AP50 | AP75 | AP (E) | AP (M) | AP (H) | ckpt | log |
+| :--------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :----: | :----: | :----: | :--------------------------------------------: | :-------------------------------------------: |
+| [pose_hrnet_w32](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_hrnet-w32_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.675 | 0.825 | 0.729 | 0.770 | 0.687 | 0.553 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192-960be101_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192_20201227.log.json) |
diff --git a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md
index bc509800ed..56a771806d 100644
--- a/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md
+++ b/configs/body_2d_keypoint/topdown_heatmap/crowdpose/resnet_crowdpose.md
@@ -50,9 +50,9 @@
Results on CrowdPose test with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) human detector
-| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
-| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
-| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res50_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.637 | 0.808 | 0.692 | 0.785 | 0.738 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192-c6a526b6_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192_20201227.log.json) |
-| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.647 | 0.810 | 0.703 | 0.796 | 0.746 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192-8f5870f4_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192_20201227.log.json) |
-| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-320x256.py) | 320x256 | 0.661 | 0.821 | 0.714 | 0.800 | 0.759 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256-c88c512a_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256_20201227.log.json) |
-| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res152_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.656 | 0.818 | 0.712 | 0.803 | 0.754 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192-dbd49aba_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192_20201227.log.json) |
+| Arch | Input Size | AP | AP50 | AP75 | AP (E) | AP (M) | AP (H) | ckpt | log |
+| :--------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :----: | :----: | :----: | :--------------------------------------------: | :-------------------------------------------: |
+| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res50_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.637 | 0.808 | 0.692 | 0.738 | 0.650 | 0.506 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192-c6a526b6_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res50_crowdpose_256x192_20201227.log.json) |
+| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.647 | 0.810 | 0.703 | 0.745 | 0.658 | 0.521 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192-8f5870f4_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_256x192_20201227.log.json) |
+| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res101_8xb64-210e_crowdpose-320x256.py) | 320x256 | 0.661 | 0.821 | 0.714 | 0.759 | 0.672 | 0.534 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256-c88c512a_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256_20201227.log.json) |
+| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/crowdpose/td-hm_res152_8xb64-210e_crowdpose-256x192.py) | 256x192 | 0.656 | 0.818 | 0.712 | 0.754 | 0.666 | 0.533 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192-dbd49aba_20201227.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192_20201227.log.json) |
diff --git a/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py
index 29e7c8dfbc..00a2ea418f 100644
--- a/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py
+++ b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import copy
import os.path as osp
from typing import Optional
@@ -117,6 +118,10 @@ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
'iscrowd': ann['iscrowd'],
'segmentation': ann['segmentation'],
'id': ann['id'],
+ 'category_id': ann['category_id'],
+ # store the raw annotation of the instance
+ # it is useful for evaluation without providing ann_file
+ 'raw_ann_info': copy.deepcopy(ann),
}
return data_info
diff --git a/mmpose/evaluation/metrics/coco_metric.py b/mmpose/evaluation/metrics/coco_metric.py
index 9be6b55734..4a97919966 100644
--- a/mmpose/evaluation/metrics/coco_metric.py
+++ b/mmpose/evaluation/metrics/coco_metric.py
@@ -18,7 +18,7 @@
@METRICS.register_module()
class CocoMetric(BaseMetric):
- """COCO evaluation metric.
+ """COCO pose estimation task evaluation metric.
Evaluate AR, AP, and mAP for keypoint detection tasks. Support COCO
dataset and other datasets in COCO format. Please refer to
@@ -32,7 +32,7 @@ class CocoMetric(BaseMetric):
use_area (bool): Whether to use ``'area'`` message in the annotations.
If the ground truth annotations (e.g. CrowdPose, AIC) do not have
the field ``'area'``, please set ``use_area=False``.
- Default: ``True``
+ Defaults to ``True``
iou_type (str): The same parameter as `iouType` in
:class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or
``'keypoints_crowd'`` (used in CrowdPose dataset).
@@ -72,10 +72,9 @@ class CocoMetric(BaseMetric):
test submission when the ground truth annotations are absent. If
set to ``True``, ``outfile_prefix`` should specify the path to
store the output results. Defaults to ``False``
- outfile_prefix (str, optional): The prefix of json files. It includes
+ outfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
- If not specified, a temp file will be created.
- Defaults to ``None``
+ If not specified, a temp file will be created. Defaults to ``None``
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be ``'cpu'`` or
``'gpu'``. Defaults to ``'cpu'``
diff --git a/mmpose/evaluation/metrics/coco_wholebody_metric.py b/mmpose/evaluation/metrics/coco_wholebody_metric.py
index 34d81aed20..c5675f54c8 100644
--- a/mmpose/evaluation/metrics/coco_wholebody_metric.py
+++ b/mmpose/evaluation/metrics/coco_wholebody_metric.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Optional
+import datetime
+from typing import Dict, Optional, Sequence
import numpy as np
from mmengine.fileio import dump
@@ -19,15 +20,17 @@ class CocoWholeBodyMetric(CocoMetric):
for more details.
Args:
- ann_file (str): Path to the coco format annotation file.
+ ann_file (str, optional): Path to the coco format annotation file.
+ If not specified, ground truth annotations from the dataset will
+ be converted to coco format. Defaults to None
use_area (bool): Whether to use ``'area'`` message in the annotations.
If the ground truth annotations (e.g. CrowdPose, AIC) do not have
the field ``'area'``, please set ``use_area=False``.
- Default: ``True``.
+ Defaults to ``True``
iou_type (str): The same parameter as `iouType` in
:class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or
``'keypoints_crowd'`` (used in CrowdPose dataset).
- Defaults to ``'keypoints'``.
+ Defaults to ``'keypoints'``
score_mode (str): The mode to score the prediction results which
should be one of the following options:
@@ -62,17 +65,11 @@ class CocoWholeBodyMetric(CocoMetric):
doing quantitative evaluation. This is designed for the need of
test submission when the ground truth annotations are absent. If
set to ``True``, ``outfile_prefix`` should specify the path to
- store the output results. Default: ``False``.
+ store the output results. Defaults to ``False``
outfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
- If not specified, a temp file will be created. Default: ``None``.
- collect_device (str): Device name used for collecting results from
- different ranks during distributed training. Must be ``'cpu'`` or
- ``'gpu'``. Default: ``'cpu'``.
- prefix (str, optional): The prefix that will be added in the metric
- names to disambiguate homonymous metrics of different evaluators.
- If prefix is not provided in the argument, ``self.default_prefix``
- will be used instead. Default: ``None``.
+ If not specified, a temp file will be created. Defaults to ``None``
+ **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric`
"""
default_prefix: Optional[str] = 'coco-wholebody'
body_num = 17
@@ -81,6 +78,101 @@ class CocoWholeBodyMetric(CocoMetric):
left_hand_num = 21
right_hand_num = 21
+ def gt_to_coco_json(self, gt_dicts: Sequence[dict],
+ outfile_prefix: str) -> str:
+ """Convert ground truth to coco format json file.
+
+ Args:
+ gt_dicts (Sequence[dict]): Ground truth of the dataset. Each dict
+ contains the ground truth information about the data sample.
+ Required keys of the each `gt_dict` in `gt_dicts`:
+ - `img_id`: image id of the data sample
+ - `width`: original image width
+ - `height`: original image height
+ - `raw_ann_info`: the raw annotation information
+ Optional keys:
+ - `crowd_index`: measure the crowding level of an image,
+ defined in CrowdPose dataset
+ It is worth mentioning that, in order to compute `CocoMetric`,
+ there are some required keys in the `raw_ann_info`:
+ - `id`: the id to distinguish different annotations
+ - `image_id`: the image id of this annotation
+ - `category_id`: the category of the instance.
+ - `bbox`: the object bounding box
+ - `keypoints`: the keypoints cooridinates along with their
+ visibilities. Note that it need to be aligned
+ with the official COCO format, e.g., a list with length
+ N * 3, in which N is the number of keypoints. And each
+ triplet represent the [x, y, visible] of the keypoint.
+ - 'keypoints'
+ - `iscrowd`: indicating whether the annotation is a crowd.
+ It is useful when matching the detection results to
+ the ground truth.
+ There are some optional keys as well:
+ - `area`: it is necessary when `self.use_area` is `True`
+ - `num_keypoints`: it is necessary when `self.iou_type`
+ is set as `keypoints_crowd`.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json file will be named
+ "somepath/xxx.gt.json".
+ Returns:
+ str: The filename of the json file.
+ """
+ image_infos = []
+ annotations = []
+ img_ids = []
+ ann_ids = []
+
+ for gt_dict in gt_dicts:
+ # filter duplicate image_info
+ if gt_dict['img_id'] not in img_ids:
+ image_info = dict(
+ id=gt_dict['img_id'],
+ width=gt_dict['width'],
+ height=gt_dict['height'],
+ )
+ if self.iou_type == 'keypoints_crowd':
+ image_info['crowdIndex'] = gt_dict['crowd_index']
+
+ image_infos.append(image_info)
+ img_ids.append(gt_dict['img_id'])
+
+ # filter duplicate annotations
+ for ann in gt_dict['raw_ann_info']:
+ annotation = dict(
+ id=ann['id'],
+ image_id=ann['image_id'],
+ category_id=ann['category_id'],
+ bbox=ann['bbox'],
+ keypoints=ann['keypoints'],
+ foot_kpts=ann['foot_kpts'],
+ face_kpts=ann['face_kpts'],
+ lefthand_kpts=ann['lefthand_kpts'],
+ righthand_kpts=ann['righthand_kpts'],
+ iscrowd=ann['iscrowd'],
+ )
+ if self.use_area:
+ assert 'area' in ann, \
+ '`area` is required when `self.use_area` is `True`'
+ annotation['area'] = ann['area']
+
+ annotations.append(annotation)
+ ann_ids.append(ann['id'])
+
+ info = dict(
+ date_created=str(datetime.datetime.now()),
+ description='Coco json file converted by mmpose CocoMetric.')
+ coco_json: dict = dict(
+ info=info,
+ images=image_infos,
+ categories=self.dataset_meta['CLASSES'],
+ licenses=None,
+ annotations=annotations,
+ )
+ converted_json_path = f'{outfile_prefix}.gt.json'
+ dump(coco_json, converted_json_path, sort_keys=True, indent=4)
+ return converted_json_path
+
def results2json(self, keypoints: Dict[int, list],
outfile_prefix: str) -> str:
"""Dump the keypoint detection results to a COCO style json file.
diff --git a/mmpose/evaluation/metrics/posetrack18_metric.py b/mmpose/evaluation/metrics/posetrack18_metric.py
index 85d06c37ee..86f801455a 100644
--- a/mmpose/evaluation/metrics/posetrack18_metric.py
+++ b/mmpose/evaluation/metrics/posetrack18_metric.py
@@ -28,7 +28,9 @@ class PoseTrack18Metric(CocoMetric):
for more details.
Args:
- ann_file (str): Path to the annotation file.
+ ann_file (str, optional): Path to the coco format annotation file.
+ If not specified, ground truth annotations from the dataset will
+ be converted to coco format. Defaults to None
score_mode (str): The mode to score the prediction results which
should be one of the following options:
@@ -37,7 +39,7 @@ class PoseTrack18Metric(CocoMetric):
- ``'bbox_keypoint'``: Use keypoint score to rescore the
prediction results.
- Defaults to ``'bbox'`
+ Defaults to ``'bbox_keypoint'`
keypoint_score_thr (float): The threshold of keypoint score. The
keypoints with score lower than it will not be included to
rescore the prediction results. Valid only when ``score_mode`` is
@@ -61,22 +63,16 @@ class PoseTrack18Metric(CocoMetric):
doing quantitative evaluation. This is designed for the need of
test submission when the ground truth annotations are absent. If
set to ``True``, ``outfile_prefix`` should specify the path to
- store the output results. Default: ``False``.
+ store the output results. Defaults to ``False``
outfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
- If not specified, a temp file will be created. Default: ``None``.
- collect_device (str): Device name used for collecting results from
- different ranks during distributed training. Must be ``'cpu'`` or
- ``'gpu'``. Default: ``'cpu'``.
- prefix (str, optional): The prefix that will be added in the metric
- names to disambiguate homonymous metrics of different evaluators.
- If prefix is not provided in the argument, ``self.default_prefix``
- will be used instead. Default: ``None``.
+ If not specified, a temp file will be created. Defaults to ``None``
+ **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric`
"""
default_prefix: Optional[str] = 'posetrack18'
def __init__(self,
- ann_file: str,
+ ann_file: Optional[str] = None,
score_mode: str = 'bbox_keypoint',
keypoint_score_thr: float = 0.2,
nms_mode: str = 'oks_nms',
@@ -216,7 +212,7 @@ def _do_python_keypoint_eval(self, outfile_prefix: str) -> List[tuple]:
stats_names = [
'Head AP', 'Shou AP', 'Elb AP', 'Wri AP', 'Hip AP', 'Knee AP',
- 'Ankl AP', 'Total AP'
+ 'Ankl AP', 'AP'
]
info_str = list(zip(stats_names, stats))
diff --git a/tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py b/tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py
new file mode 100644
index 0000000000..46e8498851
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_coco_wholebody_metric.py
@@ -0,0 +1,294 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+import tempfile
+from collections import defaultdict
+from unittest import TestCase
+
+import numpy as np
+from mmengine.fileio import dump, load
+from xtcocotools.coco import COCO
+
+from mmpose.datasets.datasets.utils import parse_pose_metainfo
+from mmpose.evaluation.metrics import CocoWholeBodyMetric
+
+
+class TestCocoWholeBodyMetric(TestCase):
+
+ def setUp(self):
+ """Setup some variables which are used in every test method.
+
+ TestCase calls functions in this order: setUp() -> testMethod() ->
+ tearDown() -> cleanUp()
+ """
+ self.tmp_dir = tempfile.TemporaryDirectory()
+
+ self.ann_file_coco = 'tests/data/coco/test_coco_wholebody.json'
+ meta_info_coco = dict(
+ from_file='configs/_base_/datasets/coco_wholebody.py')
+ self.dataset_meta_coco = parse_pose_metainfo(meta_info_coco)
+ self.coco = COCO(self.ann_file_coco)
+ self.dataset_meta_coco['CLASSES'] = self.coco.loadCats(
+ self.coco.getCatIds())
+
+ self.topdown_data_coco = self._convert_ann_to_topdown_batch_data(
+ self.ann_file_coco)
+ assert len(self.topdown_data_coco) == 14
+ self.bottomup_data_coco = self._convert_ann_to_bottomup_batch_data(
+ self.ann_file_coco)
+ assert len(self.bottomup_data_coco) == 4
+ self.target_coco = {
+ 'coco-wholebody/AP': 1.0,
+ 'coco-wholebody/AP .5': 1.0,
+ 'coco-wholebody/AP .75': 1.0,
+ 'coco-wholebody/AP (M)': 1.0,
+ 'coco-wholebody/AP (L)': 1.0,
+ 'coco-wholebody/AR': 1.0,
+ 'coco-wholebody/AR .5': 1.0,
+ 'coco-wholebody/AR .75': 1.0,
+ 'coco-wholebody/AR (M)': 1.0,
+ 'coco-wholebody/AR (L)': 1.0,
+ }
+
+ def _convert_ann_to_topdown_batch_data(self, ann_file):
+ """Convert annotations to topdown-style batch data."""
+ topdown_data = []
+ db = load(ann_file)
+ imgid2info = dict()
+ for img in db['images']:
+ imgid2info[img['id']] = img
+ for ann in db['annotations']:
+ w, h = ann['bbox'][2], ann['bbox'][3]
+ bboxes = np.array(ann['bbox'], dtype=np.float32).reshape(-1, 4)
+ bbox_scales = np.array([w * 1.25, h * 1.25]).reshape(-1, 2)
+ _keypoints = np.array(ann['keypoints'] + ann['foot_kpts'] +
+ ann['face_kpts'] + ann['lefthand_kpts'] +
+ ann['righthand_kpts']).reshape(1, -1, 3)
+
+ gt_instances = {
+ 'bbox_scales': bbox_scales,
+ 'bbox_scores': np.ones((1, ), dtype=np.float32),
+ 'bboxes': bboxes,
+ }
+ pred_instances = {
+ 'keypoints': _keypoints[..., :2],
+ 'keypoint_scores': _keypoints[..., -1],
+ }
+
+ data = {'inputs': None}
+ data_sample = {
+ 'id': ann['id'],
+ 'img_id': ann['image_id'],
+ 'category_id': ann.get('category_id', 1),
+ 'gt_instances': gt_instances,
+ 'pred_instances': pred_instances,
+ # dummy image_shape for testing
+ 'ori_shape': [640, 480],
+ # store the raw annotation info to test without ann_file
+ 'raw_ann_info': copy.deepcopy(ann),
+ }
+
+ # batch size = 1
+ data_batch = [data]
+ data_samples = [data_sample]
+ topdown_data.append((data_batch, data_samples))
+
+ return topdown_data
+
+ def _convert_ann_to_bottomup_batch_data(self, ann_file):
+ """Convert annotations to bottomup-style batch data."""
+ img2ann = defaultdict(list)
+ db = load(ann_file)
+ for ann in db['annotations']:
+ img2ann[ann['image_id']].append(ann)
+
+ bottomup_data = []
+ for img_id, anns in img2ann.items():
+ _keypoints = []
+ for ann in anns:
+ _keypoints.append(ann['keypoints'] + ann['foot_kpts'] +
+ ann['face_kpts'] + ann['lefthand_kpts'] +
+ ann['righthand_kpts'])
+ keypoints = np.array(_keypoints).reshape((len(anns), -1, 3))
+
+ gt_instances = {
+ 'bbox_scores': np.ones((len(anns)), dtype=np.float32)
+ }
+
+ pred_instances = {
+ 'keypoints': keypoints[..., :2],
+ 'keypoint_scores': keypoints[..., -1],
+ }
+
+ data = {'inputs': None}
+ data_sample = {
+ 'id': [ann['id'] for ann in anns],
+ 'img_id': img_id,
+ 'gt_instances': gt_instances,
+ 'pred_instances': pred_instances
+ }
+
+ # batch size = 1
+ data_batch = [data]
+ data_samples = [data_sample]
+ bottomup_data.append((data_batch, data_samples))
+ return bottomup_data
+
+ def tearDown(self):
+ self.tmp_dir.cleanup()
+
+ def test_init(self):
+ """test metric init method."""
+ # test score_mode option
+ with self.assertRaisesRegex(ValueError,
+ '`score_mode` should be one of'):
+ _ = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco, score_mode='invalid')
+
+ # test nms_mode option
+ with self.assertRaisesRegex(ValueError, '`nms_mode` should be one of'):
+ _ = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco, nms_mode='invalid')
+
+ # test format_only option
+ with self.assertRaisesRegex(
+ AssertionError,
+ '`outfile_prefix` can not be None when `format_only` is True'):
+ _ = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco,
+ format_only=True,
+ outfile_prefix=None)
+
+ def test_other_methods(self):
+ """test other useful methods."""
+ # test `_sort_and_unique_bboxes` method
+ metric_coco = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco, score_mode='bbox', nms_mode='none')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+ # process samples
+ for data_batch, data_samples in self.topdown_data_coco:
+ metric_coco.process(data_batch, data_samples)
+ # process one extra sample
+ data_batch, data_samples = self.topdown_data_coco[0]
+ metric_coco.process(data_batch, data_samples)
+ # an extra sample
+ eval_results = metric_coco.evaluate(
+ size=len(self.topdown_data_coco) + 1)
+ self.assertDictEqual(eval_results, self.target_coco)
+
+ def test_format_only(self):
+ """test `format_only` option."""
+ metric_coco = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco,
+ format_only=True,
+ outfile_prefix=f'{self.tmp_dir.name}/test',
+ score_mode='bbox_keypoint',
+ nms_mode='oks_nms')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+ # process one sample
+ data_batch, data_samples = self.topdown_data_coco[0]
+ metric_coco.process(data_batch, data_samples)
+ eval_results = metric_coco.evaluate(size=1)
+ self.assertDictEqual(eval_results, {})
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test.keypoints.json')))
+
+ # test when gt annotations are absent
+ db_ = load(self.ann_file_coco)
+ del db_['annotations']
+ tmp_ann_file = osp.join(self.tmp_dir.name, 'temp_ann.json')
+ dump(db_, tmp_ann_file, sort_keys=True, indent=4)
+ with self.assertRaisesRegex(
+ AssertionError,
+ 'Ground truth annotations are required for evaluation'):
+ _ = CocoWholeBodyMetric(ann_file=tmp_ann_file, format_only=False)
+
+ def test_bottomup_evaluate(self):
+ """test bottomup-style COCO metric evaluation."""
+ # case1: score_mode='bbox', nms_mode='none'
+ metric_coco = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco,
+ outfile_prefix=f'{self.tmp_dir.name}/test',
+ score_mode='bbox',
+ nms_mode='none')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+
+ # process samples
+ for data_batch, data_samples in self.bottomup_data_coco:
+ metric_coco.process(data_batch, data_samples)
+
+ eval_results = metric_coco.evaluate(size=len(self.bottomup_data_coco))
+ self.assertDictEqual(eval_results, self.target_coco)
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test.keypoints.json')))
+
+ def test_topdown_evaluate(self):
+ """test topdown-style COCO metric evaluation."""
+ # case 1: score_mode='bbox', nms_mode='none'
+ metric_coco = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco,
+ outfile_prefix=f'{self.tmp_dir.name}/test1',
+ score_mode='bbox',
+ nms_mode='none')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+
+ # process samples
+ for data_batch, data_samples in self.topdown_data_coco:
+ metric_coco.process(data_batch, data_samples)
+
+ eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco))
+
+ self.assertDictEqual(eval_results, self.target_coco)
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test1.keypoints.json')))
+
+ # case 2: score_mode='bbox_keypoint', nms_mode='oks_nms'
+ metric_coco = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco,
+ outfile_prefix=f'{self.tmp_dir.name}/test2',
+ score_mode='bbox_keypoint',
+ nms_mode='oks_nms')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+
+ # process samples
+ for data_batch, data_samples in self.topdown_data_coco:
+ metric_coco.process(data_batch, data_samples)
+
+ eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco))
+
+ self.assertDictEqual(eval_results, self.target_coco)
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test2.keypoints.json')))
+
+ # case 3: score_mode='bbox_rle', nms_mode='soft_oks_nms'
+ metric_coco = CocoWholeBodyMetric(
+ ann_file=self.ann_file_coco,
+ outfile_prefix=f'{self.tmp_dir.name}/test3',
+ score_mode='bbox_rle',
+ nms_mode='soft_oks_nms')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+
+ # process samples
+ for data_batch, data_samples in self.topdown_data_coco:
+ metric_coco.process(data_batch, data_samples)
+
+ eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco))
+
+ self.assertDictEqual(eval_results, self.target_coco)
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test3.keypoints.json')))
+
+ # case 4: test without providing ann_file
+ metric_coco = CocoWholeBodyMetric(
+ outfile_prefix=f'{self.tmp_dir.name}/test4')
+ metric_coco.dataset_meta = self.dataset_meta_coco
+ # process samples
+ for data_batch, data_samples in self.topdown_data_coco:
+ metric_coco.process(data_batch, data_samples)
+ eval_results = metric_coco.evaluate(size=len(self.topdown_data_coco))
+ self.assertDictEqual(eval_results, self.target_coco)
+ # test whether convert the annotation to COCO format
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test4.gt.json')))
+ self.assertTrue(
+ osp.isfile(osp.join(self.tmp_dir.name, 'test4.keypoints.json')))
diff --git a/tests/test_evaluation/test_metrics/test_posetrack18_metric.py b/tests/test_evaluation/test_metrics/test_posetrack18_metric.py
index 09a6c56435..fe44047e31 100644
--- a/tests/test_evaluation/test_metrics/test_posetrack18_metric.py
+++ b/tests/test_evaluation/test_metrics/test_posetrack18_metric.py
@@ -41,7 +41,7 @@ def setUp(self):
'posetrack18/Hip AP': 100.0,
'posetrack18/Knee AP': 100.0,
'posetrack18/Ankl AP': 100.0,
- 'posetrack18/Total AP': 100.0,
+ 'posetrack18/AP': 100.0,
}
def _convert_ann_to_topdown_batch_data(self):