[Feature] Add RTMPose-Wholebody #2721

Merged · 12 commits · Sep 26, 2023

@@ -0,0 +1,76 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-030-58580-8_27">RTMPose (arXiv'2023)</a></summary>

```bibtex
@misc{https://doi.org/10.48550/arxiv.2303.07399,
  doi = {10.48550/ARXIV.2303.07399},
  url = {https://arxiv.org/abs/2303.07399},
  author = {Jiang, Tao and Lu, Peng and Zhang, Li and Ma, Ningsheng and Han, Rui and Lyu, Chengqi and Li, Yining and Chen, Kai},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {RTMPose: Real-Time Multi-Person Pose Estimation based on MMPose},
  publisher = {arXiv},
  year = {2023},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

</details>

<!-- [BACKBONE] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2212.07784">RTMDet (arXiv'2022)</a></summary>

```bibtex
@misc{lyu2022rtmdet,
  title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors},
  author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen},
  year={2022},
  eprint={2212.07784},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-030-58545-7_12">COCO-WholeBody (ECCV'2020)</a></summary>

```bibtex
@inproceedings{jin2020whole,
  title={Whole-Body Human Pose Estimation in the Wild},
  author={Jin, Sheng and Xu, Lumin and Xu, Jin and Wang, Can and Liu, Wentao and Qian, Chen and Ouyang, Wanli and Luo, Ping},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  year={2020}
}
```

</details>

- `Cocktail13` denotes a model trained on the following 13 public datasets:
- [AI Challenger](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#aic)
- [CrowdPose](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#crowdpose)
- [MPII](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#mpii)
- [sub-JHMDB](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#sub-jhmdb-dataset)
- [Halpe](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_wholebody_keypoint.html#halpe)
- [PoseTrack18](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#posetrack18)
- [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody/)
- [UBody](https://github.com/IDEA-Research/OSX)
- [Human-Art](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#human-art-dataset)
- [WFLW](https://wywu.github.io/projects/LAB/WFLW.html)
- [300W](https://ibug.doc.ic.ac.uk/resources/300-W/)
- [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/)
- [LaPa](https://github.com/JDAI-CV/lapa-dataset)

Results on the COCO-WholeBody v1.0 val set, using a human detector with 56.4 AP on the COCO val2017 dataset

| Arch | Input Size | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR | ckpt | log |
| :-------------------------------------- | :--------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: | :------: | :--------------------------------------: | :-------------------------------------: |
| [rtmw-x](/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb320-270e_cocktail13-256x192.py) | 256x192 | 0.753 | 0.815 | 0.773 | 0.869 | 0.843 | 0.894 | 0.602 | 0.703 | 0.672 | 0.754 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.json) |
| [rtmw-x](/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb320-270e_cocktail13-384x288.py) | 384x288 | 0.764 | 0.825 | 0.791 | 0.883 | 0.882 | 0.922 | 0.654 | 0.744 | 0.702 | 0.779 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-384x288-0949e3a9_20230925.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-384x288-0949e3a9_20230925.json) |
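
To try the released checkpoints, a top-down inference sketch along these lines should work (the `MMPoseInferencer` usage pattern and the `person.jpg` input are illustrative assumptions, not part of this PR):

```python
# Hedged sketch: run the 256x192 RTMW checkpoint from the table above
# with MMPose's high-level inferencer. 'person.jpg' is a placeholder.
from mmpose.apis import MMPoseInferencer

inferencer = MMPoseInferencer(
    pose2d='configs/wholebody_2d_keypoint/rtmpose/cocktail13/'
    'rtmw-x_8xb320-270e_cocktail13-256x192.py',
    pose2d_weights='https://download.openmmlab.com/mmpose/v1/projects/'
    'rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.pth')

# The inferencer returns a generator; each item carries the predicted
# wholebody keypoints per detected instance.
result = next(inferencer('person.jpg', show=False))
```
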
5 changes: 3 additions & 2 deletions mmpose/datasets/transforms/__init__.py
@@ -6,7 +6,7 @@
GenerateTarget, GetBBoxCenterScale,
PhotometricDistortion, RandomBBoxTransform,
RandomFlip, RandomHalfBody, YOLOXHSVRandomAug)
-from .converting import KeypointConverter
+from .converting import KeypointConverter, SingleHandConverter
from .formatting import PackPoseInputs
from .hand_transforms import HandRandomFlip
from .loading import LoadImage
@@ -21,5 +21,6 @@
'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize',
'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot',
'FilterAnnotations', 'YOLOXHSVRandomAug', 'YOLOXMixUp', 'Mosaic',
-'BottomupRandomCrop', 'BottomupRandomChoiceResize', 'HandRandomFlip'
+'BottomupRandomCrop', 'BottomupRandomChoiceResize', 'HandRandomFlip',
+'SingleHandConverter'
]
11 changes: 8 additions & 3 deletions mmpose/datasets/transforms/common_transforms.py
@@ -503,21 +503,26 @@ def _get_transform_params(self, num_bboxes: int) -> Tuple:
- scale (np.ndarray): Scaling factor of each bbox in shape (n, 1)
- rotate (np.ndarray): Rotation degree of each bbox in shape (n,)
"""
+random_v = self._truncnorm(size=(num_bboxes, 4))
+offset_v = random_v[:, :2]
+scale_v = random_v[:, 2:3]
+rotate_v = random_v[:, 3]
+
# Get shift parameters
-offset = self._truncnorm(size=(num_bboxes, 2)) * self.shift_factor
+offset = offset_v * self.shift_factor
offset = np.where(
np.random.rand(num_bboxes, 1) < self.shift_prob, offset, 0.)

# Get scaling parameters
scale_min, scale_max = self.scale_factor
mu = (scale_max + scale_min) * 0.5
sigma = (scale_max - scale_min) * 0.5
-scale = self._truncnorm(size=(num_bboxes, 1)) * sigma + mu
+scale = scale_v * sigma + mu
scale = np.where(
np.random.rand(num_bboxes, 1) < self.scale_prob, scale, 1.)

# Get rotation parameters
-rotate = self._truncnorm(size=(num_bboxes, )) * self.rotate_factor
+rotate = rotate_v * self.rotate_factor
rotate = np.where(
np.random.rand(num_bboxes) < self.rotate_prob, rotate, 0.)

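For reference, here is a minimal standalone sketch of the single-draw sampling introduced above (the truncation bounds and default factors are illustrative, not necessarily the exact `RandomBBoxTransform` defaults):

```python
# Draw one truncated-normal sample of shape (n, 4) and slice it into
# shift / scale / rotation components, instead of sampling three times.
from scipy.stats import truncnorm


def sample_bbox_params(n, shift_factor=0.16, scale_factor=(0.5, 1.5),
                       rotate_factor=80.0):
    random_v = truncnorm.rvs(-1., 1., size=(n, 4))  # assumed bounds
    offset = random_v[:, :2] * shift_factor  # (n, 2) bbox shift
    mu = (scale_factor[1] + scale_factor[0]) * 0.5
    sigma = (scale_factor[1] - scale_factor[0]) * 0.5
    scale = random_v[:, 2:3] * sigma + mu    # (n, 1) bbox scale
    rotate = random_v[:, 3] * rotate_factor  # (n,) rotation in degrees
    return offset, scale, rotate
```

A single draw replaces three separate sampler calls; the four columns are then consumed by the shift, scale, and rotation branches respectively.
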
88 changes: 87 additions & 1 deletion mmpose/datasets/transforms/converting.py
@@ -62,7 +62,10 @@ def __init__(self, num_keypoints: int,
int]]]):
self.num_keypoints = num_keypoints
self.mapping = mapping
-source_index, target_index = zip(*mapping)
+if len(mapping):
+    source_index, target_index = zip(*mapping)
+else:
+    source_index, target_index = [], []

src1, src2 = [], []
interpolation = False
@@ -89,6 +92,9 @@ def __init__(self, num_keypoints: int,
def transform(self, results: dict) -> dict:
"""Transforms the keypoint results to match the target keypoints."""
num_instances = results['keypoints'].shape[0]
+if len(results['keypoints_visible'].shape) > 2:
+    results['keypoints_visible'] = results['keypoints_visible'][:, :, 0]

# Initialize output arrays
keypoints = np.zeros((num_instances, self.num_keypoints, 3))
@@ -186,3 +192,83 @@ def __repr__(self) -> str:
repr_str += f'(num_keypoints={self.num_keypoints}, '\
f'mapping={self.mapping})'
return repr_str


@TRANSFORMS.register_module()
class SingleHandConverter(BaseTransform):
"""Mapping a single hand keypoints into double hands according to the given
mapping and hand type.

Required Keys:

- keypoints
- keypoints_visible
- hand_type

Modified Keys:

- keypoints
- keypoints_visible

Args:
num_keypoints (int): The number of keypoints in target dataset.
left_hand_mapping (list): A list containing mapping indexes. Each
element has format (source_index, target_index)
right_hand_mapping (list): A list containing mapping indexes. Each
element has format (source_index, target_index)

Example:
>>> import numpy as np
>>> self = SingleHandConverter(
>>> num_keypoints=42,
>>> left_hand_mapping=[
>>> (0, 0), (1, 1), (2, 2), (3, 3)
>>> ],
>>> right_hand_mapping=[
>>> (0, 21), (1, 22), (2, 23), (3, 24)
>>> ])
>>> results = dict(
>>> keypoints=np.arange(84).reshape(2, 21, 2),
>>> keypoints_visible=np.arange(84).reshape(2, 21, 2) % 2,
>>> hand_type=np.array([[0, 1], [0, 1]]))
>>> results = self(results)
"""

def __init__(self, num_keypoints: int,
left_hand_mapping: Union[List[Tuple[int, int]],
List[Tuple[Tuple, int]]],
right_hand_mapping: Union[List[Tuple[int, int]],
List[Tuple[Tuple, int]]]):
self.num_keypoints = num_keypoints
self.left_hand_converter = KeypointConverter(num_keypoints,
left_hand_mapping)
self.right_hand_converter = KeypointConverter(num_keypoints,
right_hand_mapping)

def transform(self, results: dict) -> dict:
"""Transforms the keypoint results to match the target keypoints."""
assert 'hand_type' in results, (
'hand_type should be provided in results')
hand_type = results['hand_type']

if np.sum(np.abs(hand_type - [[0, 1]])) <= 1e-6:
    # left hand
    results = self.left_hand_converter(results)
elif np.sum(np.abs(hand_type - [[1, 0]])) <= 1e-6:
    # right hand
    results = self.right_hand_converter(results)
else:
    raise ValueError('hand_type should be left or right')

return results

def __repr__(self) -> str:
"""print the basic information of the transform.

Returns:
str: Formatted string.
"""
repr_str = self.__class__.__name__
repr_str += f'(num_keypoints={self.num_keypoints}, '\
f'left_hand_converter={self.left_hand_converter}, '\
f'right_hand_converter={self.right_hand_converter})'
return repr_str
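
A hedged sketch of how the new transform might appear in a hand-dataset pipeline; the 21-to-42 index mapping below is illustrative, not necessarily the exact mapping used by the cocktail13 configs:

```python
# Illustrative mmengine-style pipeline: lift 21-keypoint single-hand
# annotations into a 42-keypoint two-hand layout (left -> slots 0-20,
# right -> slots 21-41; the slot assignment is an assumption).
pipeline = [
    dict(type='LoadImage'),
    dict(
        type='SingleHandConverter',
        num_keypoints=42,
        left_hand_mapping=[(i, i) for i in range(21)],
        right_hand_mapping=[(i, i + 21) for i in range(21)]),
    dict(type='PackPoseInputs'),
]
```
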
4 changes: 2 additions & 2 deletions mmpose/models/heads/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_head import BaseHead
-from .coord_cls_heads import RTMCCHead, SimCCHead
+from .coord_cls_heads import RTMCCHead, RTMWHead, SimCCHead
from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead,
HeatmapHead, InternetHead, MSPNHead, ViPNASHead)
from .hybrid_heads import DEKRHead, VisPredictHead
@@ -16,5 +16,5 @@
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead',
'CIDHead', 'RTMCCHead', 'TemporalRegressionHead',
'TrajectoryRegressionHead', 'MotionRegressionHead', 'EDPoseHead',
-'InternetHead'
+'InternetHead', 'RTMWHead'
]
3 changes: 2 additions & 1 deletion mmpose/models/heads/coord_cls_heads/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .rtmcc_head import RTMCCHead
+from .rtmw_head import RTMWHead
from .simcc_head import SimCCHead

-__all__ = ['SimCCHead', 'RTMCCHead']
+__all__ = ['SimCCHead', 'RTMCCHead', 'RTMWHead']
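
With these exports in place, the new head resolves from the usual package path, e.g.:

```python
# RTMWHead is now importable alongside the other coord-cls heads.
from mmpose.models.heads import RTMCCHead, RTMWHead, SimCCHead
```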