diff --git a/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb320-270e_cocktail13-384x288.py b/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb320-270e_cocktail13-384x288.py
new file mode 100644
index 0000000000..d525df746b
--- /dev/null
+++ b/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb320-270e_cocktail13-384x288.py
@@ -0,0 +1,607 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.config import read_base
+
+with read_base():
+ from mmpose.configs._base_.default_runtime import * # noqa
+
+from albumentations.augmentations import Blur, CoarseDropout, MedianBlur
+from mmdet.engine.hooks import PipelineSwitchHook
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import EMAHook
+from mmengine.model import PretrainedInit
+from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper
+from torch.nn import SiLU, SyncBatchNorm
+from torch.optim import AdamW
+
+from mmpose.codecs import SimCCLabel
+from mmpose.datasets import (AicDataset, CocoWholeBodyDataset, COFWDataset,
+ CombinedDataset, CrowdPoseDataset,
+ Face300WDataset, GenerateTarget,
+ GetBBoxCenterScale, HalpeDataset,
+ HumanArt21Dataset, JhmdbDataset,
+ KeypointConverter, LapaDataset, LoadImage,
+ MpiiDataset, PackPoseInputs, PoseTrack18Dataset,
+ RandomFlip, RandomHalfBody, TopdownAffine,
+ UBody2dDataset, WFLWDataset)
+from mmpose.datasets.transforms.common_transforms import (
+ Albumentation, PhotometricDistortion, RandomBBoxTransform)
+from mmpose.engine.hooks import ExpMomentumEMA
+from mmpose.evaluation import CocoWholeBodyMetric
+from mmpose.models import (CSPNeXt, CSPNeXtPAFPN, KLDiscretLoss,
+ PoseDataPreprocessor, RTMWHead,
+ TopdownPoseEstimator)
+
+# common setting
+num_keypoints = 133
+input_size = (288, 384)
+
+# runtime
+max_epochs = 270
+stage2_num_epochs = 10
+base_lr = 5e-4
+train_batch_size = 320
+val_batch_size = 32
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=21)
+
+# optimizer
+optim_wrapper = dict(
+ type=OptimWrapper,
+ optimizer=dict(type=AdamW, lr=base_lr, weight_decay=0.05),
+ clip_grad=dict(max_norm=35, norm_type=2),
+ paramwise_cfg=dict(
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type=LinearLR, start_factor=1.0e-5, by_epoch=False, begin=0, end=1000),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=base_lr * 0.05,
+ begin=max_epochs // 2,
+ end=max_epochs,
+ T_max=max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=5632)
+
+# codec settings
+codec = dict(
+ type=SimCCLabel,
+ input_size=input_size,
+ sigma=(6., 6.93),
+ simcc_split_ratio=2.0,
+ normalize=False,
+ use_dark=False)
+
+# model settings
+model = dict(
+ type=TopdownPoseEstimator,
+ data_preprocessor=dict(
+ type=PoseDataPreprocessor,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ type=CSPNeXt,
+ arch='P5',
+ expand_ratio=0.5,
+ deepen_factor=1.33,
+ widen_factor=1.25,
+ channel_attention=True,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type=SiLU),
+ init_cfg=dict(
+ type=PretrainedInit,
+ prefix='backbone.',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/'
+ 'wholebody_2d_keypoint/rtmpose/ubody/rtmpose-x_simcc-ucoco_pt-aic-coco_270e-384x288-f5b50679_20230822.pth' # noqa
+ )),
+ neck=dict(
+ type=CSPNeXtPAFPN,
+ in_channels=[320, 640, 1280],
+ out_channels=None,
+ out_indices=(
+ 1,
+ 2,
+ ),
+ num_csp_blocks=2,
+ expand_ratio=0.5,
+ norm_cfg=dict(type=SyncBatchNorm),
+ act_cfg=dict(type=SiLU, inplace=True)),
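+    # the neck returns two pyramid levels (strides 16 and 32); RTMWHead fuses both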
+ head=dict(
+ type=RTMWHead,
+ in_channels=1280,
+ out_channels=num_keypoints,
+ input_size=input_size,
+ in_featuremap_size=tuple([s // 32 for s in input_size]),
+ simcc_split_ratio=codec['simcc_split_ratio'],
+ final_layer_kernel_size=7,
+ gau_cfg=dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn=SiLU,
+ use_rel_bias=False,
+ pos_enc=False),
+ loss=dict(
+ type=KLDiscretLoss,
+ use_target_weight=True,
+ beta=10.,
+ label_softmax=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = CocoWholeBodyDataset
+data_mode = 'topdown'
+data_root = 'data/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(type=RandomFlip, direction='horizontal'),
+ dict(type=RandomHalfBody),
+ dict(type=RandomBBoxTransform, scale_factor=[0.5, 1.5], rotate_factor=90),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+ dict(type=PhotometricDistortion),
+ dict(
+ type=Albumentation,
+ transforms=[
+ dict(type=Blur, p=0.1),
+ dict(type=MedianBlur, p=0.1),
+ dict(
+ type=CoarseDropout,
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=1.0),
+ ]),
+ dict(
+ type=GenerateTarget, encoder=codec, use_dataset_keypoint_weights=True),
+ dict(type=PackPoseInputs)
+]
+val_pipeline = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+ dict(type=PackPoseInputs)
+]
+
+train_pipeline_stage2 = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(type=RandomFlip, direction='horizontal'),
+ dict(type=RandomHalfBody),
+ dict(
+ type=RandomBBoxTransform,
+ shift_factor=0.,
+ scale_factor=[0.5, 1.5],
+ rotate_factor=90),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+ dict(
+ type=Albumentation,
+ transforms=[
+ dict(type=Blur, p=0.1),
+ dict(type=MedianBlur, p=0.1),
+ dict(
+ type=CoarseDropout,
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=0.5),
+ ]),
+ dict(
+ type=GenerateTarget, encoder=codec, use_dataset_keypoint_weights=True),
+ dict(type=PackPoseInputs)
+]
+
+# mapping
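+# Each pair below maps a source-dataset keypoint index to the corresponding
+# COCO-WholeBody (133 keypoints) index used by KeypointConverter.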
+
+aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
+ (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)]
+
+crowdpose_coco133 = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (6, 11),
+ (7, 12), (8, 13), (9, 14), (10, 15), (11, 16)]
+
+mpii_coco133 = [
+ (0, 16),
+ (1, 14),
+ (2, 12),
+ (3, 11),
+ (4, 13),
+ (5, 15),
+ (8, 18),
+ (9, 17),
+ (10, 10),
+ (11, 8),
+ (12, 6),
+ (13, 5),
+ (14, 7),
+ (15, 9),
+]
+
+jhmdb_coco133 = [
+ (0, 18),
+ (2, 17),
+ (3, 6),
+ (4, 5),
+ (5, 12),
+ (6, 11),
+ (7, 8),
+ (8, 7),
+ (9, 14),
+ (10, 13),
+ (11, 10),
+ (12, 9),
+ (13, 16),
+ (14, 15),
+]
+
+halpe_coco133 = [(i, i)
+ for i in range(17)] + [(20, 17), (21, 20), (22, 18), (23, 21),
+ (24, 19),
+ (25, 22)] + [(i, i - 3)
+ for i in range(26, 136)]
+
+posetrack_coco133 = [
+ (0, 0),
+ (2, 17),
+ (3, 3),
+ (4, 4),
+ (5, 5),
+ (6, 6),
+ (7, 7),
+ (8, 8),
+ (9, 9),
+ (10, 10),
+ (11, 11),
+ (12, 12),
+ (13, 13),
+ (14, 14),
+ (15, 15),
+ (16, 16),
+]
+
+humanart_coco133 = [(i, i) for i in range(17)] + [(17, 99), (18, 120),
+ (19, 17), (20, 20)]
+
+# train datasets
+dataset_coco = dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
+ data_prefix=dict(img='detection/coco/train2017/'),
+ pipeline=[],
+)
+
+dataset_aic = dict(
+ type=AicDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='aic/annotations/aic_train.json',
+ data_prefix=dict(img='pose/ai_challenge/ai_challenger_keypoint'
+ '_train_20170902/keypoint_train_images_20170902/'),
+ pipeline=[
+ dict(
+            type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=aic_coco133)
+ ],
+)
+
+dataset_crowdpose = dict(
+ type=CrowdPoseDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='crowdpose/annotations/mmpose_crowdpose_trainval.json',
+ data_prefix=dict(img='pose/CrowdPose/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=crowdpose_coco133)
+ ],
+)
+
+dataset_mpii = dict(
+ type=MpiiDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='mpii/annotations/mpii_train.json',
+ data_prefix=dict(img='pose/MPI/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=mpii_coco133)
+ ],
+)
+
+dataset_jhmdb = dict(
+ type=JhmdbDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='jhmdb/annotations/Sub1_train.json',
+ data_prefix=dict(img='pose/JHMDB/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=jhmdb_coco133)
+ ],
+)
+
+dataset_halpe = dict(
+ type=HalpeDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='halpe/annotations/halpe_train_v1.json',
+ data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=halpe_coco133)
+ ],
+)
+
+dataset_posetrack = dict(
+ type=PoseTrack18Dataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='posetrack18/annotations/posetrack18_train.json',
+ data_prefix=dict(img='pose/PoseChallenge2018/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=posetrack_coco133)
+ ],
+)
+
+dataset_humanart = dict(
+ type=HumanArt21Dataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='HumanArt/annotations/training_humanart.json',
+ filter_cfg=dict(scenes=['real_human']),
+ data_prefix=dict(img='pose/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=humanart_coco133)
+ ])
+
+ubody_scenes = [
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
+]
+
+ubody_datasets = []
+for scene in ubody_scenes:
+ each = dict(
+ type=UBody2dDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file=f'Ubody/annotations/{scene}/train_annotations.json',
+ data_prefix=dict(img='pose/UBody/images/'),
+ pipeline=[],
+ sample_interval=10)
+ ubody_datasets.append(each)
+
+dataset_ubody = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/ubody2d.py'),
+ datasets=ubody_datasets,
+ pipeline=[],
+ test_mode=False,
+)
+
+face_pipeline = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(
+ type=RandomBBoxTransform,
+ shift_factor=0.,
+ scale_factor=[0.3, 0.5],
+ rotate_factor=0),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+]
+
+wflw_coco133 = [(i * 2, 20 + i)
+ for i in range(17)] + [(33 + i, 41 + i) for i in range(5)] + [
+ (42 + i, 46 + i) for i in range(5)
+ ] + [(51 + i, 50 + i)
+ for i in range(9)] + [(60, 59), (61, 60), (63, 61),
+ (64, 62), (65, 63), (67, 64),
+ (68, 65), (69, 66), (71, 67),
+ (72, 68), (73, 69),
+ (75, 70)] + [(76 + i, 71 + i)
+ for i in range(20)]
+dataset_wflw = dict(
+ type=WFLWDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='wflw/annotations/face_landmarks_wflw_train.json',
+ data_prefix=dict(img='pose/WFLW/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=wflw_coco133), *face_pipeline
+ ],
+)
+
+mapping_300w_coco133 = [(i, 20 + i) for i in range(68)]
+dataset_300w = dict(
+ type=Face300WDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='300w/annotations/face_landmarks_300w_train.json',
+ data_prefix=dict(img='pose/300w/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=mapping_300w_coco133), *face_pipeline
+ ],
+)
+
+cofw_coco133 = [(0, 41), (2, 45), (4, 43), (1, 47), (3, 49), (6, 45), (8, 59),
+ (10, 62), (9, 68), (11, 65), (18, 54), (19, 58), (20, 53),
+ (21, 56), (22, 71), (23, 77), (24, 74), (25, 85), (26, 89),
+ (27, 80), (28, 28)]
+dataset_cofw = dict(
+ type=COFWDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='cofw/annotations/cofw_train.json',
+ data_prefix=dict(img='pose/COFW/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=cofw_coco133), *face_pipeline
+ ],
+)
+
+lapa_coco133 = [(i * 2, 20 + i) for i in range(17)] + [
+ (33 + i, 41 + i) for i in range(5)
+] + [(42 + i, 46 + i) for i in range(5)] + [
+ (51 + i, 50 + i) for i in range(4)
+] + [(58 + i, 54 + i) for i in range(5)] + [(66, 59), (67, 60), (69, 61),
+ (70, 62), (71, 63), (73, 64),
+ (75, 65), (76, 66), (78, 67),
+ (79, 68), (80, 69),
+ (82, 70)] + [(84 + i, 71 + i)
+ for i in range(20)]
+dataset_lapa = dict(
+ type=LapaDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='LaPa/annotations/lapa_trainval.json',
+ data_prefix=dict(img='pose/LaPa/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=lapa_coco133), *face_pipeline
+ ],
+)
+
+dataset_wb = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[dataset_coco, dataset_halpe, dataset_ubody],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_body = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_aic,
+ dataset_crowdpose,
+ dataset_mpii,
+ dataset_jhmdb,
+ dataset_posetrack,
+ dataset_humanart,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_face = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_wflw,
+ dataset_300w,
+ dataset_cofw,
+ dataset_lapa,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+train_datasets = [
+ dataset_wb,
+ dataset_body,
+ dataset_face,
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=train_batch_size,
+ num_workers=10,
+ pin_memory=True,
+ persistent_workers=True,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ dataset=dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=train_datasets,
+ pipeline=train_pipeline,
+ test_mode=False,
+ ))
+
+val_dataloader = dict(
+ batch_size=val_batch_size,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type=DefaultSampler, shuffle=False, round_up=False),
+ dataset=dict(
+ type=CocoWholeBodyDataset,
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
+ data_prefix=dict(img='data/detection/coco/val2017/'),
+ pipeline=val_pipeline,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ test_mode=True))
+
+test_dataloader = val_dataloader
+
+# hooks
+default_hooks = dict(
+ checkpoint=dict(
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
+custom_hooks = [
+ dict(
+ type=EMAHook,
+ ema_type=ExpMomentumEMA,
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type=PipelineSwitchHook,
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline=train_pipeline_stage2)
+]
+
+# evaluators
+val_evaluator = dict(
+ type=CocoWholeBodyMetric,
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
+test_evaluator = val_evaluator
diff --git a/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb704-270e_cocktail13-256x192.py b/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb704-270e_cocktail13-256x192.py
new file mode 100644
index 0000000000..02cdfd4ee3
--- /dev/null
+++ b/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb704-270e_cocktail13-256x192.py
@@ -0,0 +1,607 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.config import read_base
+
+with read_base():
+ from mmpose.configs._base_.default_runtime import * # noqa
+
+from albumentations.augmentations import Blur, CoarseDropout, MedianBlur
+from mmdet.engine.hooks import PipelineSwitchHook
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import EMAHook
+from mmengine.model import PretrainedInit
+from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper
+from torch.nn import SiLU, SyncBatchNorm
+from torch.optim import AdamW
+
+from mmpose.codecs import SimCCLabel
+from mmpose.datasets import (AicDataset, CocoWholeBodyDataset, COFWDataset,
+ CombinedDataset, CrowdPoseDataset,
+ Face300WDataset, GenerateTarget,
+ GetBBoxCenterScale, HalpeDataset,
+ HumanArt21Dataset, JhmdbDataset,
+ KeypointConverter, LapaDataset, LoadImage,
+ MpiiDataset, PackPoseInputs, PoseTrack18Dataset,
+ RandomFlip, RandomHalfBody, TopdownAffine,
+ UBody2dDataset, WFLWDataset)
+from mmpose.datasets.transforms.common_transforms import (
+ Albumentation, PhotometricDistortion, RandomBBoxTransform)
+from mmpose.engine.hooks import ExpMomentumEMA
+from mmpose.evaluation import CocoWholeBodyMetric
+from mmpose.models import (CSPNeXt, CSPNeXtPAFPN, KLDiscretLoss,
+ PoseDataPreprocessor, RTMWHead,
+ TopdownPoseEstimator)
+
+# common setting
+num_keypoints = 133
+input_size = (192, 256)
+
+# runtime
+max_epochs = 270
+stage2_num_epochs = 10
+base_lr = 5e-4
+train_batch_size = 704
+val_batch_size = 32
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=21)
+
+# optimizer
+optim_wrapper = dict(
+ type=OptimWrapper,
+ optimizer=dict(type=AdamW, lr=base_lr, weight_decay=0.05),
+ clip_grad=dict(max_norm=35, norm_type=2),
+ paramwise_cfg=dict(
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type=LinearLR, start_factor=1.0e-5, by_epoch=False, begin=0, end=1000),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=base_lr * 0.05,
+ begin=max_epochs // 2,
+ end=max_epochs,
+ T_max=max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=5632)
+
+# codec settings
+codec = dict(
+ type=SimCCLabel,
+ input_size=input_size,
+ sigma=(4.9, 5.66),
+ simcc_split_ratio=2.0,
+ normalize=False,
+ use_dark=False)
+
+# model settings
+model = dict(
+ type=TopdownPoseEstimator,
+ data_preprocessor=dict(
+ type=PoseDataPreprocessor,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ type=CSPNeXt,
+ arch='P5',
+ expand_ratio=0.5,
+ deepen_factor=1.33,
+ widen_factor=1.25,
+ channel_attention=True,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type=SiLU),
+ init_cfg=dict(
+ type=PretrainedInit,
+ prefix='backbone.',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/'
+ 'wholebody_2d_keypoint/rtmpose/ubody/rtmpose-x_simcc-ucoco_pt-aic-coco_270e-256x192-05f5bcb7_20230822.pth' # noqa
+ )),
+ neck=dict(
+ type=CSPNeXtPAFPN,
+ in_channels=[320, 640, 1280],
+ out_channels=None,
+ out_indices=(
+ 1,
+ 2,
+ ),
+ num_csp_blocks=2,
+ expand_ratio=0.5,
+ norm_cfg=dict(type=SyncBatchNorm),
+ act_cfg=dict(type=SiLU, inplace=True)),
+ head=dict(
+ type=RTMWHead,
+ in_channels=1280,
+ out_channels=num_keypoints,
+ input_size=input_size,
+ in_featuremap_size=tuple([s // 32 for s in input_size]),
+ simcc_split_ratio=codec['simcc_split_ratio'],
+ final_layer_kernel_size=7,
+ gau_cfg=dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn=SiLU,
+ use_rel_bias=False,
+ pos_enc=False),
+ loss=dict(
+ type=KLDiscretLoss,
+ use_target_weight=True,
+ beta=10.,
+ label_softmax=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = CocoWholeBodyDataset
+data_mode = 'topdown'
+data_root = 'data/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(type=RandomFlip, direction='horizontal'),
+ dict(type=RandomHalfBody),
+ dict(type=RandomBBoxTransform, scale_factor=[0.5, 1.5], rotate_factor=90),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+ dict(type=PhotometricDistortion),
+ dict(
+ type=Albumentation,
+ transforms=[
+ dict(type=Blur, p=0.1),
+ dict(type=MedianBlur, p=0.1),
+ dict(
+ type=CoarseDropout,
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=1.0),
+ ]),
+ dict(
+ type=GenerateTarget, encoder=codec, use_dataset_keypoint_weights=True),
+ dict(type=PackPoseInputs)
+]
+val_pipeline = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+ dict(type=PackPoseInputs)
+]
+
+train_pipeline_stage2 = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(type=RandomFlip, direction='horizontal'),
+ dict(type=RandomHalfBody),
+ dict(
+ type=RandomBBoxTransform,
+ shift_factor=0.,
+ scale_factor=[0.5, 1.5],
+ rotate_factor=90),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+ dict(
+ type=Albumentation,
+ transforms=[
+ dict(type=Blur, p=0.1),
+ dict(type=MedianBlur, p=0.1),
+ dict(
+ type=CoarseDropout,
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=0.5),
+ ]),
+ dict(
+ type=GenerateTarget, encoder=codec, use_dataset_keypoint_weights=True),
+ dict(type=PackPoseInputs)
+]
+
+# mapping
+
+aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
+ (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)]
+
+crowdpose_coco133 = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (6, 11),
+ (7, 12), (8, 13), (9, 14), (10, 15), (11, 16)]
+
+mpii_coco133 = [
+ (0, 16),
+ (1, 14),
+ (2, 12),
+ (3, 11),
+ (4, 13),
+ (5, 15),
+ (8, 18),
+ (9, 17),
+ (10, 10),
+ (11, 8),
+ (12, 6),
+ (13, 5),
+ (14, 7),
+ (15, 9),
+]
+
+jhmdb_coco133 = [
+ (0, 18),
+ (2, 17),
+ (3, 6),
+ (4, 5),
+ (5, 12),
+ (6, 11),
+ (7, 8),
+ (8, 7),
+ (9, 14),
+ (10, 13),
+ (11, 10),
+ (12, 9),
+ (13, 16),
+ (14, 15),
+]
+
+halpe_coco133 = [(i, i)
+ for i in range(17)] + [(20, 17), (21, 20), (22, 18), (23, 21),
+ (24, 19),
+ (25, 22)] + [(i, i - 3)
+ for i in range(26, 136)]
+
+posetrack_coco133 = [
+ (0, 0),
+ (2, 17),
+ (3, 3),
+ (4, 4),
+ (5, 5),
+ (6, 6),
+ (7, 7),
+ (8, 8),
+ (9, 9),
+ (10, 10),
+ (11, 11),
+ (12, 12),
+ (13, 13),
+ (14, 14),
+ (15, 15),
+ (16, 16),
+]
+
+humanart_coco133 = [(i, i) for i in range(17)] + [(17, 99), (18, 120),
+ (19, 17), (20, 20)]
+
+# train datasets
+dataset_coco = dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
+ data_prefix=dict(img='detection/coco/train2017/'),
+ pipeline=[],
+)
+
+dataset_aic = dict(
+ type=AicDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='aic/annotations/aic_train.json',
+ data_prefix=dict(img='pose/ai_challenge/ai_challenger_keypoint'
+ '_train_20170902/keypoint_train_images_20170902/'),
+ pipeline=[
+ dict(
+            type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=aic_coco133)
+ ],
+)
+
+dataset_crowdpose = dict(
+ type=CrowdPoseDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='crowdpose/annotations/mmpose_crowdpose_trainval.json',
+ data_prefix=dict(img='pose/CrowdPose/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=crowdpose_coco133)
+ ],
+)
+
+dataset_mpii = dict(
+ type=MpiiDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='mpii/annotations/mpii_train.json',
+ data_prefix=dict(img='pose/MPI/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=mpii_coco133)
+ ],
+)
+
+dataset_jhmdb = dict(
+ type=JhmdbDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='jhmdb/annotations/Sub1_train.json',
+ data_prefix=dict(img='pose/JHMDB/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=jhmdb_coco133)
+ ],
+)
+
+dataset_halpe = dict(
+ type=HalpeDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='halpe/annotations/halpe_train_v1.json',
+ data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=halpe_coco133)
+ ],
+)
+
+dataset_posetrack = dict(
+ type=PoseTrack18Dataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='posetrack18/annotations/posetrack18_train.json',
+ data_prefix=dict(img='pose/PoseChallenge2018/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=posetrack_coco133)
+ ],
+)
+
+dataset_humanart = dict(
+ type=HumanArt21Dataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='HumanArt/annotations/training_humanart.json',
+ filter_cfg=dict(scenes=['real_human']),
+ data_prefix=dict(img='pose/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=humanart_coco133)
+ ])
+
+ubody_scenes = [
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
+]
+
+ubody_datasets = []
+for scene in ubody_scenes:
+ each = dict(
+ type=UBody2dDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file=f'Ubody/annotations/{scene}/train_annotations.json',
+ data_prefix=dict(img='pose/UBody/images/'),
+ pipeline=[],
+ sample_interval=10)
+ ubody_datasets.append(each)
+
+dataset_ubody = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/ubody2d.py'),
+ datasets=ubody_datasets,
+ pipeline=[],
+ test_mode=False,
+)
+
+face_pipeline = [
+ dict(type=LoadImage, backend_args=backend_args),
+ dict(type=GetBBoxCenterScale),
+ dict(
+ type=RandomBBoxTransform,
+ shift_factor=0.,
+ scale_factor=[0.3, 0.5],
+ rotate_factor=0),
+ dict(type=TopdownAffine, input_size=codec['input_size']),
+]
+
+wflw_coco133 = [(i * 2, 20 + i)
+ for i in range(17)] + [(33 + i, 41 + i) for i in range(5)] + [
+ (42 + i, 46 + i) for i in range(5)
+ ] + [(51 + i, 50 + i)
+ for i in range(9)] + [(60, 59), (61, 60), (63, 61),
+ (64, 62), (65, 63), (67, 64),
+ (68, 65), (69, 66), (71, 67),
+ (72, 68), (73, 69),
+ (75, 70)] + [(76 + i, 71 + i)
+ for i in range(20)]
+dataset_wflw = dict(
+ type=WFLWDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='wflw/annotations/face_landmarks_wflw_train.json',
+ data_prefix=dict(img='pose/WFLW/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=wflw_coco133), *face_pipeline
+ ],
+)
+
+mapping_300w_coco133 = [(i, 20 + i) for i in range(68)]
+dataset_300w = dict(
+ type=Face300WDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='300w/annotations/face_landmarks_300w_train.json',
+ data_prefix=dict(img='pose/300w/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=mapping_300w_coco133), *face_pipeline
+ ],
+)
+
+cofw_coco133 = [(0, 41), (2, 45), (4, 43), (1, 47), (3, 49), (6, 45), (8, 59),
+ (10, 62), (9, 68), (11, 65), (18, 54), (19, 58), (20, 53),
+ (21, 56), (22, 71), (23, 77), (24, 74), (25, 85), (26, 89),
+ (27, 80), (28, 28)]
+dataset_cofw = dict(
+ type=COFWDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='cofw/annotations/cofw_train.json',
+ data_prefix=dict(img='pose/COFW/images/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=cofw_coco133), *face_pipeline
+ ],
+)
+
+lapa_coco133 = [(i * 2, 20 + i) for i in range(17)] + [
+ (33 + i, 41 + i) for i in range(5)
+] + [(42 + i, 46 + i) for i in range(5)] + [
+ (51 + i, 50 + i) for i in range(4)
+] + [(58 + i, 54 + i) for i in range(5)] + [(66, 59), (67, 60), (69, 61),
+ (70, 62), (71, 63), (73, 64),
+ (75, 65), (76, 66), (78, 67),
+ (79, 68), (80, 69),
+ (82, 70)] + [(84 + i, 71 + i)
+ for i in range(20)]
+dataset_lapa = dict(
+ type=LapaDataset,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='LaPa/annotations/lapa_trainval.json',
+ data_prefix=dict(img='pose/LaPa/'),
+ pipeline=[
+ dict(
+ type=KeypointConverter,
+ num_keypoints=num_keypoints,
+ mapping=lapa_coco133), *face_pipeline
+ ],
+)
+
+dataset_wb = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[dataset_coco, dataset_halpe, dataset_ubody],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_body = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_aic,
+ dataset_crowdpose,
+ dataset_mpii,
+ dataset_jhmdb,
+ dataset_posetrack,
+ dataset_humanart,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_face = dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_wflw,
+ dataset_300w,
+ dataset_cofw,
+ dataset_lapa,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+train_datasets = [
+ dataset_wb,
+ dataset_body,
+ dataset_face,
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=train_batch_size,
+ num_workers=10,
+ pin_memory=True,
+ persistent_workers=True,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ dataset=dict(
+ type=CombinedDataset,
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=train_datasets,
+ pipeline=train_pipeline,
+ test_mode=False,
+ ))
+
+val_dataloader = dict(
+ batch_size=val_batch_size,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type=DefaultSampler, shuffle=False, round_up=False),
+ dataset=dict(
+ type=CocoWholeBodyDataset,
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
+ data_prefix=dict(img='data/detection/coco/val2017/'),
+ pipeline=val_pipeline,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ test_mode=True))
+
+test_dataloader = val_dataloader
+
+# hooks
+default_hooks = dict(
+ checkpoint=dict(
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
+custom_hooks = [
+ dict(
+ type=EMAHook,
+ ema_type=ExpMomentumEMA,
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type=PipelineSwitchHook,
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline=train_pipeline_stage2)
+]
+
+# evaluators
+val_evaluator = dict(
+ type=CocoWholeBodyMetric,
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
+test_evaluator = val_evaluator
diff --git a/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw_cocktail13.md b/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw_cocktail13.md
new file mode 100644
index 0000000000..54e75383ba
--- /dev/null
+++ b/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw_cocktail13.md
@@ -0,0 +1,76 @@
+
+
+
+RTMPose (arXiv'2023)
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2303.07399,
+ doi = {10.48550/ARXIV.2303.07399},
+ url = {https://arxiv.org/abs/2303.07399},
+ author = {Jiang, Tao and Lu, Peng and Zhang, Li and Ma, Ningsheng and Han, Rui and Lyu, Chengqi and Li, Yining and Chen, Kai},
+ keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {RTMPose: Real-Time Multi-Person Pose Estimation based on MMPose},
+ publisher = {arXiv},
+ year = {2023},
+ copyright = {Creative Commons Attribution 4.0 International}
+}
+
+```
+
+
+
+
+
+
+RTMDet (arXiv'2022)
+
+```bibtex
+@misc{lyu2022rtmdet,
+ title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors},
+ author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen},
+ year={2022},
+ eprint={2212.07784},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+
+
+
+
+
+COCO-WholeBody (ECCV'2020)
+
+```bibtex
+@inproceedings{jin2020whole,
+ title={Whole-Body Human Pose Estimation in the Wild},
+ author={Jin, Sheng and Xu, Lumin and Xu, Jin and Wang, Can and Liu, Wentao and Qian, Chen and Ouyang, Wanli and Luo, Ping},
+ booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
+ year={2020}
+}
+```
+
+
+
+- `Cocktail13` denotes that the model is trained on 13 public datasets:
+ - [AI Challenger](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#aic)
+ - [CrowdPose](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#crowdpose)
+ - [MPII](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#mpii)
+ - [sub-JHMDB](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#sub-jhmdb-dataset)
+ - [Halpe](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_wholebody_keypoint.html#halpe)
+ - [PoseTrack18](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#posetrack18)
+ - [COCO-Wholebody](https://github.com/jin-s13/COCO-WholeBody/)
+ - [UBody](https://github.com/IDEA-Research/OSX)
+ - [Human-Art](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#human-art-dataset)
+ - [WFLW](https://wywu.github.io/projects/LAB/WFLW.html)
+ - [300W](https://ibug.doc.ic.ac.uk/resources/300-W/)
+ - [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/)
+ - [LaPa](https://github.com/JDAI-CV/lapa-dataset)
+
+Results on COCO-WholeBody v1.0 val with detector having human AP of 56.4 on COCO val2017 dataset
+
+| Arch | Input Size | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR | ckpt | log |
+| :-------------------------------------- | :--------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: | :------: | :--------------------------------------: | :-------------------------------------: |
+| [rtmw-x](/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb704-270e_cocktail13-256x192.py) | 256x192 | 0.753 | 0.815 | 0.773 | 0.869 | 0.843 | 0.894 | 0.602 | 0.703 | 0.672 | 0.754 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.json) |
+| [rtmw-x](/configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb320-270e_cocktail13-384x288.py) | 384x288 | 0.764 | 0.825 | 0.791 | 0.883 | 0.882 | 0.922 | 0.654 | 0.744 | 0.702 | 0.779 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-384x288-0949e3a9_20230925.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-384x288-0949e3a9_20230925.json) |
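+
+The released checkpoints can be used with MMPose's high-level inference APIs.
+The snippet below is a minimal sketch: `demo.jpg` and the local checkpoint path
+are placeholders for your own image and a downloaded copy of the checkpoint
+linked above.
+
+```python
+from mmpose.apis import inference_topdown, init_model
+
+# config from this folder; the checkpoint is a locally downloaded copy
+cfg = 'configs/wholebody_2d_keypoint/rtmpose/cocktail13/rtmw-x_8xb704-270e_cocktail13-256x192.py'
+ckpt = 'rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.pth'
+
+model = init_model(cfg, ckpt, device='cpu')
+# with no explicit bboxes, the whole image is treated as a single person bbox
+results = inference_topdown(model, 'demo.jpg')
+keypoints = results[0].pred_instances.keypoints  # (num_instances, 133, 2)
+```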
diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py
index b07a18885c..54ad7f3159 100644
--- a/mmpose/datasets/transforms/__init__.py
+++ b/mmpose/datasets/transforms/__init__.py
@@ -6,7 +6,7 @@
GenerateTarget, GetBBoxCenterScale,
PhotometricDistortion, RandomBBoxTransform,
RandomFlip, RandomHalfBody, YOLOXHSVRandomAug)
-from .converting import KeypointConverter
+from .converting import KeypointConverter, SingleHandConverter
from .formatting import PackPoseInputs
from .hand_transforms import HandRandomFlip
from .loading import LoadImage
@@ -21,5 +21,6 @@
'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize',
'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot',
'FilterAnnotations', 'YOLOXHSVRandomAug', 'YOLOXMixUp', 'Mosaic',
- 'BottomupRandomCrop', 'BottomupRandomChoiceResize', 'HandRandomFlip'
+ 'BottomupRandomCrop', 'BottomupRandomChoiceResize', 'HandRandomFlip',
+ 'SingleHandConverter'
]
diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py
index 98aed11683..33f9c560c0 100644
--- a/mmpose/datasets/transforms/common_transforms.py
+++ b/mmpose/datasets/transforms/common_transforms.py
@@ -503,8 +503,13 @@ def _get_transform_params(self, num_bboxes: int) -> Tuple:
- scale (np.ndarray): Scaling factor of each bbox in shape (n, 1)
- rotate (np.ndarray): Rotation degree of each bbox in shape (n,)
"""
+ random_v = self._truncnorm(size=(num_bboxes, 4))
+ offset_v = random_v[:, :2]
+ scale_v = random_v[:, 2:3]
+ rotate_v = random_v[:, 3]
+
# Get shift parameters
- offset = self._truncnorm(size=(num_bboxes, 2)) * self.shift_factor
+ offset = offset_v * self.shift_factor
offset = np.where(
np.random.rand(num_bboxes, 1) < self.shift_prob, offset, 0.)
@@ -512,12 +517,12 @@ def _get_transform_params(self, num_bboxes: int) -> Tuple:
scale_min, scale_max = self.scale_factor
mu = (scale_max + scale_min) * 0.5
sigma = (scale_max - scale_min) * 0.5
- scale = self._truncnorm(size=(num_bboxes, 1)) * sigma + mu
+ scale = scale_v * sigma + mu
scale = np.where(
np.random.rand(num_bboxes, 1) < self.scale_prob, scale, 1.)
# Get rotation parameters
- rotate = self._truncnorm(size=(num_bboxes, )) * self.rotate_factor
+ rotate = rotate_v * self.rotate_factor
rotate = np.where(
np.random.rand(num_bboxes) < self.rotate_prob, rotate, 0.)
diff --git a/mmpose/datasets/transforms/converting.py b/mmpose/datasets/transforms/converting.py
index 1906f16972..fc8b335e26 100644
--- a/mmpose/datasets/transforms/converting.py
+++ b/mmpose/datasets/transforms/converting.py
@@ -62,7 +62,10 @@ def __init__(self, num_keypoints: int,
int]]]):
self.num_keypoints = num_keypoints
self.mapping = mapping
- source_index, target_index = zip(*mapping)
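+        # an empty mapping is allowed; the converted keypoints are then all zeros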
+ if len(mapping):
+ source_index, target_index = zip(*mapping)
+ else:
+ source_index, target_index = [], []
src1, src2 = [], []
interpolation = False
@@ -89,6 +92,9 @@ def __init__(self, num_keypoints: int,
def transform(self, results: dict) -> dict:
"""Transforms the keypoint results to match the target keypoints."""
num_instances = results['keypoints'].shape[0]
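+        # keypoints_visible may carry an extra trailing dimension; keep only
+        # the per-keypoint visibility flag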
+        keypoints_visible = results['keypoints_visible']
+        if keypoints_visible.ndim > 2:
+            results['keypoints_visible'] = keypoints_visible[:, :, 0]
# Initialize output arrays
keypoints = np.zeros((num_instances, self.num_keypoints, 3))
@@ -186,3 +192,83 @@ def __repr__(self) -> str:
repr_str += f'(num_keypoints={self.num_keypoints}, '\
f'mapping={self.mapping})'
return repr_str
+
+
+@TRANSFORMS.register_module()
+class SingleHandConverter(BaseTransform):
+ """Mapping a single hand keypoints into double hands according to the given
+ mapping and hand type.
+
+ Required Keys:
+
+ - keypoints
+ - keypoints_visible
+ - hand_type
+
+ Modified Keys:
+
+ - keypoints
+ - keypoints_visible
+
+ Args:
+ num_keypoints (int): The number of keypoints in target dataset.
+ left_hand_mapping (list): A list containing mapping indexes. Each
+ element has format (source_index, target_index)
+ right_hand_mapping (list): A list containing mapping indexes. Each
+ element has format (source_index, target_index)
+
+ Example:
+ >>> import numpy as np
+ >>> self = SingleHandConverter(
+ >>> num_keypoints=42,
+ >>> left_hand_mapping=[
+ >>> (0, 0), (1, 1), (2, 2), (3, 3)
+ >>> ],
+ >>> right_hand_mapping=[
+ >>> (0, 21), (1, 22), (2, 23), (3, 24)
+ >>> ])
+ >>> results = dict(
+ >>> keypoints=np.arange(84).reshape(2, 21, 2),
+ >>> keypoints_visible=np.arange(84).reshape(2, 21, 2) % 2,
+ >>> hand_type=np.array([[0, 1], [1, 0]]))
+ >>> results = self(results)
+ """
+
+ def __init__(self, num_keypoints: int,
+ left_hand_mapping: Union[List[Tuple[int, int]],
+ List[Tuple[Tuple, int]]],
+ right_hand_mapping: Union[List[Tuple[int, int]],
+ List[Tuple[Tuple, int]]]):
+ self.num_keypoints = num_keypoints
+ self.left_hand_converter = KeypointConverter(num_keypoints,
+ left_hand_mapping)
+ self.right_hand_converter = KeypointConverter(num_keypoints,
+ right_hand_mapping)
+
+ def transform(self, results: dict) -> dict:
+ """Transforms the keypoint results to match the target keypoints."""
+ assert 'hand_type' in results, (
+ 'hand_type should be provided in results')
+ hand_type = results['hand_type']
+
+        if np.sum(np.abs(hand_type - [[0, 1]])) <= 1e-6:
+            # left hand
+            results = self.left_hand_converter(results)
+        elif np.sum(np.abs(hand_type - [[1, 0]])) <= 1e-6:
+            # right hand
+            results = self.right_hand_converter(results)
+ else:
+ raise ValueError('hand_type should be left or right')
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_keypoints={self.num_keypoints}, '\
+ f'left_hand_converter={self.left_hand_converter}, '\
+ f'right_hand_converter={self.right_hand_converter})'
+ return repr_str
diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py
index b1258f53cc..e4b499ad2b 100644
--- a/mmpose/models/heads/__init__.py
+++ b/mmpose/models/heads/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_head import BaseHead
-from .coord_cls_heads import RTMCCHead, SimCCHead
+from .coord_cls_heads import RTMCCHead, RTMWHead, SimCCHead
from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead,
HeatmapHead, InternetHead, MSPNHead, ViPNASHead)
from .hybrid_heads import DEKRHead, VisPredictHead
@@ -16,5 +16,5 @@
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead',
'CIDHead', 'RTMCCHead', 'TemporalRegressionHead',
'TrajectoryRegressionHead', 'MotionRegressionHead', 'EDPoseHead',
- 'InternetHead'
+ 'InternetHead', 'RTMWHead'
]
diff --git a/mmpose/models/heads/coord_cls_heads/__init__.py b/mmpose/models/heads/coord_cls_heads/__init__.py
index 104ff91308..6a4e51c4d7 100644
--- a/mmpose/models/heads/coord_cls_heads/__init__.py
+++ b/mmpose/models/heads/coord_cls_heads/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .rtmcc_head import RTMCCHead
+from .rtmw_head import RTMWHead
from .simcc_head import SimCCHead
-__all__ = ['SimCCHead', 'RTMCCHead']
+__all__ = ['SimCCHead', 'RTMCCHead', 'RTMWHead']
diff --git a/mmpose/models/heads/coord_cls_heads/rtmw_head.py b/mmpose/models/heads/coord_cls_heads/rtmw_head.py
new file mode 100644
index 0000000000..7111f90446
--- /dev/null
+++ b/mmpose/models/heads/coord_cls_heads/rtmw_head.py
@@ -0,0 +1,337 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+from mmcv.cnn import ConvModule
+from mmengine.dist import get_dist_info
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.codecs.utils import get_simcc_normalized
+from mmpose.evaluation.functional import simcc_pck_accuracy
+from mmpose.models.utils.rtmcc_block import RTMCCBlock, ScaleNorm
+from mmpose.models.utils.tta import flip_vectors
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
+ OptSampleList)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class RTMWHead(BaseHead):
+ """Top-down head introduced in RTMPose-Wholebody (2023).
+
+ Args:
+ in_channels (int | sequence[int]): Number of channels in the input
+ feature map.
+ out_channels (int): Number of channels in the output heatmap.
+ input_size (tuple): Size of input image in shape [w, h].
+ in_featuremap_size (int | sequence[int]): Size of input feature map.
+ simcc_split_ratio (float): Split ratio of pixels.
+ Default: 2.0.
+ final_layer_kernel_size (int): Kernel size of the convolutional layer.
+ Default: 1.
+ gau_cfg (Config): Config dict for the Gated Attention Unit.
+ Default: dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn='ReLU',
+ use_rel_bias=False,
+ pos_enc=False).
+ loss (Config): Config of the keypoint loss. Defaults to use
+ :class:`KLDiscretLoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+ """
+
+ def __init__(
+ self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ input_size: Tuple[int, int],
+ in_featuremap_size: Tuple[int, int],
+ simcc_split_ratio: float = 2.0,
+ final_layer_kernel_size: int = 1,
+ gau_cfg: ConfigType = dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn='ReLU',
+ use_rel_bias=False,
+ pos_enc=False),
+ loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None,
+ ):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.input_size = input_size
+ self.in_featuremap_size = in_featuremap_size
+ self.simcc_split_ratio = simcc_split_ratio
+
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ if isinstance(in_channels, (tuple, list)):
+ raise ValueError(
+ f'{self.__class__.__name__} does not support selecting '
+ 'multiple input features.')
+
+ # Define SimCC layers
+ flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1]
+
+ ps = 2
+ self.ps = nn.PixelShuffle(ps)
+ self.conv_dec = ConvModule(
+ in_channels // ps**2,
+ in_channels // 4,
+ kernel_size=final_layer_kernel_size,
+ stride=1,
+ padding=final_layer_kernel_size // 2,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='ReLU'))
+
+ self.final_layer = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=final_layer_kernel_size,
+ stride=1,
+ padding=final_layer_kernel_size // 2,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='ReLU'))
+ self.final_layer2 = ConvModule(
+ in_channels // ps + in_channels // 4,
+ out_channels,
+ kernel_size=final_layer_kernel_size,
+ stride=1,
+ padding=final_layer_kernel_size // 2,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='ReLU'))
+
+ self.mlp = nn.Sequential(
+ ScaleNorm(flatten_dims),
+ nn.Linear(flatten_dims, gau_cfg['hidden_dims'] // 2, bias=False))
+
+ self.mlp2 = nn.Sequential(
+ ScaleNorm(flatten_dims * ps**2),
+ nn.Linear(
+ flatten_dims * ps**2, gau_cfg['hidden_dims'] // 2, bias=False))
+
+ W = int(self.input_size[0] * self.simcc_split_ratio)
+ H = int(self.input_size[1] * self.simcc_split_ratio)
+
+ self.gau = RTMCCBlock(
+ self.out_channels,
+ gau_cfg['hidden_dims'],
+ gau_cfg['hidden_dims'],
+ s=gau_cfg['s'],
+ expansion_factor=gau_cfg['expansion_factor'],
+ dropout_rate=gau_cfg['dropout_rate'],
+ drop_path=gau_cfg['drop_path'],
+ attn_type='self-attn',
+ act_fn=gau_cfg['act_fn'],
+ use_rel_bias=gau_cfg['use_rel_bias'],
+ pos_enc=gau_cfg['pos_enc'])
+
+ self.cls_x = nn.Linear(gau_cfg['hidden_dims'], W, bias=False)
+ self.cls_y = nn.Linear(gau_cfg['hidden_dims'], H, bias=False)
+
+ def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
+ """Forward the network.
+
+ The input is the featuremap extracted by backbone and the
+ output is the simcc representation.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ pred_x (Tensor): 1d representation of x.
+ pred_y (Tensor): 1d representation of y.
+ """
+        # enc_b: lower pyramid level, shape (B, in_channels // 2, 2h, 2w)
+        # enc_t: top pyramid level, shape (B, in_channels, h, w)
+ enc_b, enc_t = feats
+
+ feats_t = self.final_layer(enc_t)
+ feats_t = torch.flatten(feats_t, 2)
+ feats_t = self.mlp(feats_t)
+
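+        # upsample the top-level feature with PixelShuffle and fuse it with
+        # the lower-level branch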
+ dec_t = self.ps(enc_t)
+ dec_t = self.conv_dec(dec_t)
+ enc_b = torch.cat([dec_t, enc_b], dim=1)
+
+ feats_b = self.final_layer2(enc_b)
+ feats_b = torch.flatten(feats_b, 2)
+ feats_b = self.mlp2(feats_b)
+
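+        # concatenate the two branches along the channel axis to form
+        # (B, K, hidden_dims) tokens for the GAU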
+ feats = torch.cat([feats_t, feats_b], dim=2)
+
+ feats = self.gau(feats)
+
+ pred_x = self.cls_x(feats)
+ pred_y = self.cls_y(feats)
+
+ return pred_x, pred_y
+
+ def predict(
+ self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {},
+ ) -> InstanceList:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ List[InstanceData]: The pose predictions, each contains
+ the following fields:
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+ - keypoint_x_labels (np.ndarray, optional): The predicted 1-D
+ intensity distribution in the x direction
+ - keypoint_y_labels (np.ndarray, optional): The predicted 1-D
+ intensity distribution in the y direction
+ """
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ _feats, _feats_flip = feats
+
+ _batch_pred_x, _batch_pred_y = self.forward(_feats)
+
+ _batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip)
+ _batch_pred_x_flip, _batch_pred_y_flip = flip_vectors(
+ _batch_pred_x_flip,
+ _batch_pred_y_flip,
+ flip_indices=flip_indices)
+
+ batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5
+ batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5
+ else:
+ batch_pred_x, batch_pred_y = self.forward(feats)
+
+ preds = self.decode((batch_pred_x, batch_pred_y))
+
+ if test_cfg.get('output_heatmaps', False):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ warnings.warn('The predicted simcc values are normalized for '
+ 'visualization. This may cause discrepancy '
+ 'between the keypoint scores and the 1D heatmaps'
+ '.')
+
+ # normalize the predicted 1d distribution
+ batch_pred_x = get_simcc_normalized(batch_pred_x)
+ batch_pred_y = get_simcc_normalized(batch_pred_y)
+
+ B, K, _ = batch_pred_x.shape
+            # B, K, Wx -> B, K, 1, Wx
+            x = batch_pred_x.reshape(B, K, 1, -1)
+            # B, K, Wy -> B, K, Wy, 1
+            y = batch_pred_y.reshape(B, K, -1, 1)
+            # B, K, Wy, Wx
+            batch_heatmaps = torch.matmul(y, x)
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+
+ for pred_instances, pred_x, pred_y in zip(preds,
+ to_numpy(batch_pred_x),
+ to_numpy(batch_pred_y)):
+
+ pred_instances.keypoint_x_labels = pred_x[None]
+ pred_instances.keypoint_y_labels = pred_y[None]
+
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(
+ self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {},
+ ) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_x, pred_y = self.forward(feats)
+
+ gt_x = torch.cat([
+ d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples
+ ],
+ dim=0)
+ gt_y = torch.cat([
+ d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples
+ ],
+ dim=0)
+ keypoint_weights = torch.cat(
+ [
+ d.gt_instance_labels.keypoint_weights
+ for d in batch_data_samples
+ ],
+ dim=0,
+ )
+
+ pred_simcc = (pred_x, pred_y)
+ gt_simcc = (gt_x, gt_y)
+
+ # calculate losses
+ losses = dict()
+ loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights)
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = simcc_pck_accuracy(
+ output=to_numpy(pred_simcc),
+ target=to_numpy(gt_simcc),
+ simcc_split_ratio=self.simcc_split_ratio,
+ mask=to_numpy(keypoint_weights) > 0,
+ )
+
+ acc_pose = torch.tensor(avg_acc, device=gt_x.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(type='Normal', layer=['Conv2d'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1),
+ dict(type='Normal', layer=['Linear'], std=0.01, bias=0),
+ ]
+ return init_cfg
diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py
index c9d14cefc8..d4b4f51308 100644
--- a/mmpose/models/necks/__init__.py
+++ b/mmpose/models/necks/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .channel_mapper import ChannelMapper
+from .cspnext_pafpn import CSPNeXtPAFPN
from .fmap_proc_neck import FeatureMapProcessor
from .fpn import FPN
from .gap_neck import GlobalAveragePooling
@@ -8,5 +9,5 @@
__all__ = [
'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor',
- 'ChannelMapper', 'YOLOXPAFPN'
+ 'ChannelMapper', 'YOLOXPAFPN', 'CSPNeXtPAFPN'
]
diff --git a/mmpose/models/necks/cspnext_pafpn.py b/mmpose/models/necks/cspnext_pafpn.py
new file mode 100644
index 0000000000..35f4dc2f10
--- /dev/null
+++ b/mmpose/models/necks/cspnext_pafpn.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmengine.model import BaseModule
+from torch import Tensor
+
+from mmpose.registry import MODELS
+from mmpose.utils.typing import ConfigType, OptConfigType, OptMultiConfig
+from ..utils import CSPLayer
+
+
+@MODELS.register_module()
+class CSPNeXtPAFPN(BaseModule):
+ """Path Aggregation Network with CSPNeXt blocks. Modified from RTMDet.
+
+ Args:
+ in_channels (Sequence[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ out_indices (Sequence[int]): Output from which stages.
+ num_csp_blocks (int): Number of bottlenecks in CSPLayer.
+ Defaults to 3.
+ use_depthwise (bool): Whether to use depthwise separable convolution in
+ blocks. Defaults to False.
+ expand_ratio (float): Ratio to adjust the number of channels of the
+ hidden layer. Default: 0.5
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(scale_factor=2, mode='nearest')`
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='Swish')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ in_channels: Sequence[int],
+ out_channels: int,
+ out_indices=(
+ 0,
+ 1,
+ 2,
+ ),
+ num_csp_blocks: int = 3,
+ use_depthwise: bool = False,
+ expand_ratio: float = 0.5,
+ upsample_cfg: ConfigType = dict(scale_factor=2, mode='nearest'),
+        conv_cfg: OptConfigType = None,
+ norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg: ConfigType = dict(type='Swish'),
+ init_cfg: OptMultiConfig = dict(
+ type='Kaiming',
+ layer='Conv2d',
+ a=math.sqrt(5),
+ distribution='uniform',
+ mode='fan_in',
+ nonlinearity='leaky_relu')
+ ) -> None:
+ super().__init__(init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.out_indices = out_indices
+
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+
+ # build top-down blocks
+ self.upsample = nn.Upsample(**upsample_cfg)
+ self.reduce_layers = nn.ModuleList()
+ self.top_down_blocks = nn.ModuleList()
+ for idx in range(len(in_channels) - 1, 0, -1):
+ self.reduce_layers.append(
+ ConvModule(
+ in_channels[idx],
+ in_channels[idx - 1],
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.top_down_blocks.append(
+ CSPLayer(
+ in_channels[idx - 1] * 2,
+ in_channels[idx - 1],
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ use_depthwise=use_depthwise,
+ use_cspnext_block=True,
+ expand_ratio=expand_ratio,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ # build bottom-up blocks
+ self.downsamples = nn.ModuleList()
+ self.bottom_up_blocks = nn.ModuleList()
+ for idx in range(len(in_channels) - 1):
+ self.downsamples.append(
+ conv(
+ in_channels[idx],
+ in_channels[idx],
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottom_up_blocks.append(
+ CSPLayer(
+ in_channels[idx] * 2,
+ in_channels[idx + 1],
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ use_depthwise=use_depthwise,
+ use_cspnext_block=True,
+ expand_ratio=expand_ratio,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ if self.out_channels is not None:
+ self.out_convs = nn.ModuleList()
+ for i in range(len(in_channels)):
+ self.out_convs.append(
+ conv(
+ in_channels[i],
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
+ """
+ Args:
+ inputs (tuple[Tensor]): input features.
+
+ Returns:
+            tuple[Tensor]: CSPNeXtPAFPN features.
+ """
+ assert len(inputs) == len(self.in_channels)
+
+ # top-down path
+ inner_outs = [inputs[-1]]
+ for idx in range(len(self.in_channels) - 1, 0, -1):
+ feat_high = inner_outs[0]
+ feat_low = inputs[idx - 1]
+ feat_high = self.reduce_layers[len(self.in_channels) - 1 - idx](
+ feat_high)
+ inner_outs[0] = feat_high
+
+ upsample_feat = self.upsample(feat_high)
+
+ inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
+ torch.cat([upsample_feat, feat_low], 1))
+ inner_outs.insert(0, inner_out)
+
+ # bottom-up path
+ outs = [inner_outs[0]]
+ for idx in range(len(self.in_channels) - 1):
+ feat_low = outs[-1]
+ feat_high = inner_outs[idx + 1]
+ downsample_feat = self.downsamples[idx](feat_low)
+ out = self.bottom_up_blocks[idx](
+ torch.cat([downsample_feat, feat_high], 1))
+ outs.append(out)
+
+ if self.out_channels is not None:
+ # out convs
+ for idx, conv in enumerate(self.out_convs):
+ outs[idx] = conv(outs[idx])
+
+ return tuple([outs[i] for i in self.out_indices])
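
As a quick sanity check of the data flow above (top-down fusion, then bottom-up fusion, then optional output convs), here is a minimal sketch, assuming an environment where mmpose and its `CSPNeXtPAFPN` neck are importable as in the config imports earlier, that pushes dummy CSPNeXt P3/P4/P5 features for a 384x288 crop through the neck with the RTMW-x settings. With `out_channels=None` the output convs are skipped and `out_indices=(1, 2)` selects the stride-16 and stride-32 levels of the bottom-up path.

import torch
from mmpose.models import CSPNeXtPAFPN

# RTMW-x neck settings from the configs in this PR; out_channels=None skips
# the output convs, so each returned level keeps its own channel count.
neck = CSPNeXtPAFPN(
    in_channels=[320, 640, 1280],
    out_channels=None,
    out_indices=(1, 2),
    num_csp_blocks=2,
    act_cfg=dict(type='SiLU'))
neck.eval()

# dummy P3/P4/P5 features of a 384 (h) x 288 (w) crop at strides 8/16/32
feats = (
    torch.randn(1, 320, 48, 36),
    torch.randn(1, 640, 24, 18),
    torch.randn(1, 1280, 12, 9),
)
with torch.no_grad():
    outs = neck(feats)
print([tuple(o.shape) for o in outs])
# -> [(1, 640, 24, 18), (1, 1280, 12, 9)]
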
diff --git a/projects/rtmpose/README.md b/projects/rtmpose/README.md
index 27a8a90144..e2412c0014 100644
--- a/projects/rtmpose/README.md
+++ b/projects/rtmpose/README.md
@@ -44,6 +44,9 @@ ______________________________________________________________________
## 🥳 🚀 What's New [🔝](#-table-of-contents)
+- Sep. 2023:
+ - Add RTMW models trained on combined datasets. The alpha version of the RTMW-x model achieves 70.2 mAP on the COCO-Wholebody val set. The technical report will be released soon.
+ - Add YOLOX and RTMDet models trained on the HumanArt dataset.
- Aug. 2023:
- Support distilled 133-keypoint WholeBody models powered by [DWPose](https://github.com/IDEA-Research/DWPose/tree/main).
- You can try DWPose/RTMPose with [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) now! Just update your sd-webui-controlnet >= v1.1237, then choose `dw_openpose_full` as preprocessor.
@@ -219,15 +222,6 @@ Feel free to join our community group for more help:
- RTMPose for Human-Centric Artificial Scenes is supported by [Human-Art](https://github.com/IDEA-Research/HumanArt)
-
-Pose Estimators:
-
-| Config | Input Size | AP<br>(Human-Art GT) | Params<br>(M) | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | ncnn-FP16-Latency<br>(ms)<br>(Snapdragon 865) | Download |
-| :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.zip) |
-| [RTMPose-s\*](./rtmpose/body_2d_keypoint/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 69.8 | 5.47 | 0.68 | 4.48 | 1.39 | 13.89 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.zip) |
-| [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.zip) |
-| [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.zip) |
-
Detectors:
| Detection Config | Input Size | Model AP<br>(OneHand10K) | Flops<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | Download |
@@ -241,6 +235,15 @@ Detetors:
| [YOLOX-l](./yolox/humanart/yolox_l_8xb8-300e_humanart.py) | 640x640 | 60.2 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_l_8xb8-300e_humanart-ce1d7a62.pth) |
| [YOLOX-x](./yolox/humanart/yolox_x_8xb8-300e_humanart.py) | 640x640 | 61.3 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_x_8xb8-300e_humanart-a39d44ed.pth) |
+Pose Estimators:
+
+| Config | Input Size | AP<br>(Human-Art GT) | Params<br>(M) | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | ncnn-FP16-Latency<br>(ms)<br>(Snapdragon 865) | Download |
+| :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.zip) |
+| [RTMPose-s\*](./rtmpose/body_2d_keypoint/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 69.8 | 5.47 | 0.68 | 4.48 | 1.39 | 13.89 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.zip) |
+| [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.zip) |
+| [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.zip) |
+
#### 26 Keypoints
@@ -276,7 +279,7 @@ For more details, please refer to [GroupFisher Pruning for RTMPose](./rtmpose/pr
- Keypoints are defined as [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody/). For details please refer to the [meta info](/configs/_base_/datasets/coco_wholebody.py).
-
-
+
COCO-WholeBody
| Config | Input Size | Whole AP | Whole AR | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | Download |
@@ -289,7 +292,32 @@ For more details, please refer to [GroupFisher Pruning for RTMPose](./rtmpose/pr
-DWPose
+Cocktail13
+
+- `Cocktail13` denotes that the model is trained on 13 public datasets:
+ - [AI Challenger](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#aic)
+ - [CrowdPose](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#crowdpose)
+ - [MPII](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#mpii)
+ - [sub-JHMDB](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#sub-jhmdb-dataset)
+ - [Halpe](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_wholebody_keypoint.html#halpe)
+ - [PoseTrack18](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#posetrack18)
+ - [COCO-Wholebody](https://github.com/jin-s13/COCO-WholeBody/)
+ - [UBody](https://github.com/IDEA-Research/OSX)
+ - [Human-Art](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#human-art-dataset)
+ - [WFLW](https://wywu.github.io/projects/LAB/WFLW.html)
+ - [300W](https://ibug.doc.ic.ac.uk/resources/300-W/)
+ - [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/)
+ - [LaPa](https://github.com/JDAI-CV/lapa-dataset)
+
+| Config | Input Size | Whole AP | Whole AR | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | Download |
+| :------------------------------ | :--------: | :------: | :------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-------------------------------: |
+| [RTMW-x<br>(alpha version)](./rtmpose/wholebody_2d_keypoint/rtmw-x_8xb704-270e_cocktail13-256x192.py) | 256x192 | 67.2 | 75.4 | 13.1 | - | - | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.pth) |
+| [RTMW-x<br>(alpha version)](./rtmpose/wholebody_2d_keypoint/rtmw-x_8xb320-270e_cocktail13-384x288.py) | 384x288 | 70.2 | 77.9 | 29.3 | - | - | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-384x288-0949e3a9_20230925.pth) |
+
+
+
+
+COCO+UBody
- DWPose Models are supported by [DWPose](https://github.com/IDEA-Research/DWPose)
- Models are trained and distilled on:
diff --git a/projects/rtmpose/README_CN.md b/projects/rtmpose/README_CN.md
index a08c74283f..7930d34797 100644
--- a/projects/rtmpose/README_CN.md
+++ b/projects/rtmpose/README_CN.md
@@ -40,6 +40,9 @@ ______________________________________________________________________
## 🥳 最新进展 [🔝](#-table-of-contents)
+- 2023 年 9 月:
+ - 发布混合数据集上训练的 RTMW 模型。Alpha 版本的 RTMW-x 在 COCO-Wholebody 验证集上取得了 70.2 mAP。技术报告正在撰写中。
+ - 增加 HumanArt 上训练的 YOLOX 和 RTMDet 模型。
- 2023 年 8 月:
- 支持基于 RTMPose 模型蒸馏的 133 点 WholeBody 模型(由 [DWPose](https://github.com/IDEA-Research/DWPose/tree/main) 提供)。
- 你可以在 [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) 中使用 DWPose/RTMPose 作为姿态估计后端进行人物图像生成。升级 sd-webui-controlnet >= v1.1237 并选择 `dw_openpose_full` 即可使用。
@@ -210,15 +213,6 @@ RTMPose 是一个长期优化迭代的项目,致力于业务场景下的高性
- 面向艺术图片的人体姿态估计 RTMPose 模型由 [Human-Art](https://github.com/IDEA-Research/HumanArt) 提供。
-
-人体姿态估计模型:
-
-| Config | Input Size | AP<br>(Human-Art GT) | Params<br>(M) | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | ncnn-FP16-Latency<br>(ms)<br>(Snapdragon 865) | Download |
-| :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.zip) |
-| [RTMPose-s\*](./rtmpose/body_2d_keypoint/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 69.8 | 5.47 | 0.68 | 4.48 | 1.39 | 13.89 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.zip) |
-| [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.zip) |
-| [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.zip) |
-
人体检测模型:
| Detection Config | Input Size | Model AP<br>(OneHand10K) | Flops<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | Download |
@@ -232,6 +226,15 @@ RTMPose 是一个长期优化迭代的项目,致力于业务场景下的高性
| [YOLOX-l](./yolox/humanart/yolox_l_8xb8-300e_humanart.py) | 640x640 | 60.2 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_l_8xb8-300e_humanart-ce1d7a62.pth) |
| [YOLOX-x](./yolox/humanart/yolox_x_8xb8-300e_humanart.py) | 640x640 | 61.3 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_x_8xb8-300e_humanart-a39d44ed.pth) |
+人体姿态估计模型:
+
+| Config | Input Size | AP<br>(Human-Art GT) | Params<br>(M) | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | ncnn-FP16-Latency<br>(ms)<br>(Snapdragon 865) | Download |
+| :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.zip) |
+| [RTMPose-s\*](./rtmpose/body_2d_keypoint/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 69.8 | 5.47 | 0.68 | 4.48 | 1.39 | 13.89 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.zip) |
+| [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.zip) |
+| [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth)<br>[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.zip) |
+
#### 26 Keypoints
@@ -267,7 +270,7 @@ RTMPose 是一个长期优化迭代的项目,致力于业务场景下的高性
- 关键点骨架定义遵循 [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody/),详情见 [meta info](/configs/_base_/datasets/coco_wholebody.py)。
-
-
+
COCO-WholeBody
| Config | Input Size | Whole AP | Whole AR | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | Download |
@@ -280,7 +283,32 @@ RTMPose 是一个长期优化迭代的项目,致力于业务场景下的高性
-DWPose
+Cocktail13
+
+- `Cocktail13` 代表模型在 13 个开源数据集上训练得到:
+ - [AI Challenger](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#aic)
+ - [CrowdPose](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#crowdpose)
+ - [MPII](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#mpii)
+ - [sub-JHMDB](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#sub-jhmdb-dataset)
+ - [Halpe](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_wholebody_keypoint.html#halpe)
+ - [PoseTrack18](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#posetrack18)
+ - [COCO-Wholebody](https://github.com/jin-s13/COCO-WholeBody/)
+ - [UBody](https://github.com/IDEA-Research/OSX)
+ - [Human-Art](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#human-art-dataset)
+ - [WFLW](https://wywu.github.io/projects/LAB/WFLW.html)
+ - [300W](https://ibug.doc.ic.ac.uk/resources/300-W/)
+ - [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/)
+ - [LaPa](https://github.com/JDAI-CV/lapa-dataset)
+
+| Config | Input Size | Whole AP | Whole AR | FLOPS<br>(G) | ORT-Latency<br>(ms)<br>(i7-11700) | TRT-FP16-Latency<br>(ms)<br>(GTX 1660Ti) | Download |
+| :------------------------------ | :--------: | :------: | :------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-------------------------------: |
+| [RTMW-x<br>(alpha version)](./rtmpose/wholebody_2d_keypoint/rtmw-x_8xb704-270e_cocktail13-256x192.py) | 256x192 | 67.2 | 75.4 | 13.1 | - | - | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-256x192-fbef0d61_20230925.pth) |
+| [RTMW-x<br>(alpha version)](./rtmpose/wholebody_2d_keypoint/rtmw-x_8xb320-270e_cocktail13-384x288.py) | 384x288 | 70.2 | 77.9 | 29.3 | - | - | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmw/rtmw-x_simcc-cocktail13_pt-ucoco_270e-384x288-0949e3a9_20230925.pth) |
+
+
+
+
+COCO+UBody
- DWPose 模型由 [DWPose](https://github.com/IDEA-Research/DWPose) 项目提供
- 模型在以下数据集上训练并蒸馏:
diff --git a/projects/rtmpose/rtmpose/wholebody_2d_keypoint/rtmw-x_8xb320-270e_cocktail13-384x288.py b/projects/rtmpose/rtmpose/wholebody_2d_keypoint/rtmw-x_8xb320-270e_cocktail13-384x288.py
new file mode 100644
index 0000000000..03f5fd3a04
--- /dev/null
+++ b/projects/rtmpose/rtmpose/wholebody_2d_keypoint/rtmw-x_8xb320-270e_cocktail13-384x288.py
@@ -0,0 +1,586 @@
+_base_ = ['mmpose::_base_/default_runtime.py']
+
+# common setting
+num_keypoints = 133
+input_size = (288, 384)
+
+# runtime
+max_epochs = 270
+stage2_num_epochs = 10
+base_lr = 5e-4
+train_batch_size = 320
+val_batch_size = 32
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=21)
+
+# optimizer
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
+ clip_grad=dict(max_norm=35, norm_type=2),
+ paramwise_cfg=dict(
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1.0e-5,
+ by_epoch=False,
+ begin=0,
+ end=1000),
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=base_lr * 0.05,
+ begin=max_epochs // 2,
+ end=max_epochs,
+ T_max=max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=2560)
+
+# codec settings
+codec = dict(
+ type='SimCCLabel',
+ input_size=input_size,
+ sigma=(6., 6.93),
+ simcc_split_ratio=2.0,
+ normalize=False,
+ use_dark=False)
+
+# model settings
+model = dict(
+ type='TopdownPoseEstimator',
+ data_preprocessor=dict(
+ type='PoseDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ _scope_='mmdet',
+ type='CSPNeXt',
+ arch='P5',
+ expand_ratio=0.5,
+ deepen_factor=1.33,
+ widen_factor=1.25,
+ channel_attention=True,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='SiLU'),
+ init_cfg=dict(
+ type='Pretrained',
+ prefix='backbone.',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/'
+ 'wholebody_2d_keypoint/rtmpose/ubody/rtmpose-x_simcc-ucoco_pt-aic-coco_270e-384x288-f5b50679_20230822.pth' # noqa
+ )),
+ neck=dict(
+ type='CSPNeXtPAFPN',
+ in_channels=[320, 640, 1280],
+ out_channels=None,
+ out_indices=(
+ 1,
+ 2,
+ ),
+ num_csp_blocks=2,
+ expand_ratio=0.5,
+ norm_cfg=dict(type='SyncBN'),
+ act_cfg=dict(type='SiLU', inplace=True)),
+ head=dict(
+ type='RTMWHead',
+ in_channels=1280,
+ out_channels=num_keypoints,
+ input_size=input_size,
+ in_featuremap_size=tuple([s // 32 for s in input_size]),
+ simcc_split_ratio=codec['simcc_split_ratio'],
+ final_layer_kernel_size=7,
+ gau_cfg=dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn='SiLU',
+ use_rel_bias=False,
+ pos_enc=False),
+ loss=dict(
+ type='KLDiscretLoss',
+ use_target_weight=True,
+ beta=10.,
+ label_softmax=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = 'CocoWholeBodyDataset'
+data_mode = 'topdown'
+data_root = 'data/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(
+ type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PhotometricDistortion'),
+ dict(
+ type='Albumentation',
+ transforms=[
+ dict(type='Blur', p=0.1),
+ dict(type='MedianBlur', p=0.1),
+ dict(
+ type='CoarseDropout',
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=1.0),
+ ]),
+ dict(
+ type='GenerateTarget',
+ encoder=codec,
+ use_dataset_keypoint_weights=True),
+ dict(type='PackPoseInputs')
+]
+val_pipeline = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PackPoseInputs')
+]
+
+train_pipeline_stage2 = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(
+ type='RandomBBoxTransform',
+ shift_factor=0.,
+ scale_factor=[0.5, 1.5],
+ rotate_factor=90),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(
+ type='Albumentation',
+ transforms=[
+ dict(type='Blur', p=0.1),
+ dict(type='MedianBlur', p=0.1),
+ dict(
+ type='CoarseDropout',
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=0.5),
+ ]),
+ dict(
+ type='GenerateTarget',
+ encoder=codec,
+ use_dataset_keypoint_weights=True),
+ dict(type='PackPoseInputs')
+]
+
+# mapping
+
+aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
+ (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)]
+
+crowdpose_coco133 = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (6, 11),
+ (7, 12), (8, 13), (9, 14), (10, 15), (11, 16)]
+
+mpii_coco133 = [
+ (0, 16),
+ (1, 14),
+ (2, 12),
+ (3, 11),
+ (4, 13),
+ (5, 15),
+ (8, 18),
+ (9, 17),
+ (10, 10),
+ (11, 8),
+ (12, 6),
+ (13, 5),
+ (14, 7),
+ (15, 9),
+]
+
+jhmdb_coco133 = [
+ (0, 18),
+ (2, 17),
+ (3, 6),
+ (4, 5),
+ (5, 12),
+ (6, 11),
+ (7, 8),
+ (8, 7),
+ (9, 14),
+ (10, 13),
+ (11, 10),
+ (12, 9),
+ (13, 16),
+ (14, 15),
+]
+
+halpe_coco133 = [(i, i)
+ for i in range(17)] + [(20, 17), (21, 20), (22, 18), (23, 21),
+ (24, 19),
+ (25, 22)] + [(i, i - 3)
+ for i in range(26, 136)]
+
+posetrack_coco133 = [
+ (0, 0),
+ (2, 17),
+ (3, 3),
+ (4, 4),
+ (5, 5),
+ (6, 6),
+ (7, 7),
+ (8, 8),
+ (9, 9),
+ (10, 10),
+ (11, 11),
+ (12, 12),
+ (13, 13),
+ (14, 14),
+ (15, 15),
+ (16, 16),
+]
+
+humanart_coco133 = [(i, i) for i in range(17)] + [(17, 99), (18, 120),
+ (19, 17), (20, 20)]
+
+# train datasets
+dataset_coco = dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
+ data_prefix=dict(img='detection/coco/train2017/'),
+ pipeline=[],
+)
+
+dataset_aic = dict(
+ type='AicDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='aic/annotations/aic_train.json',
+ data_prefix=dict(img='pose/ai_challenge/ai_challenger_keypoint'
+ '_train_20170902/keypoint_train_images_20170902/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=aic_coco133)
+ ],
+)
+
+dataset_crowdpose = dict(
+ type='CrowdPoseDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='crowdpose/annotations/mmpose_crowdpose_trainval.json',
+ data_prefix=dict(img='pose/CrowdPose/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=crowdpose_coco133)
+ ],
+)
+
+dataset_mpii = dict(
+ type='MpiiDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='mpii/annotations/mpii_train.json',
+ data_prefix=dict(img='pose/MPI/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=mpii_coco133)
+ ],
+)
+
+dataset_jhmdb = dict(
+ type='JhmdbDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='jhmdb/annotations/Sub1_train.json',
+ data_prefix=dict(img='pose/JHMDB/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=jhmdb_coco133)
+ ],
+)
+
+dataset_halpe = dict(
+ type='HalpeDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='halpe/annotations/halpe_train_v1.json',
+ data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=halpe_coco133)
+ ],
+)
+
+dataset_posetrack = dict(
+ type='PoseTrack18Dataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='posetrack18/annotations/posetrack18_train.json',
+ data_prefix=dict(img='pose/PoseChallenge2018/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=posetrack_coco133)
+ ],
+)
+
+dataset_humanart = dict(
+ type='HumanArt21Dataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='HumanArt/annotations/training_humanart.json',
+ filter_cfg=dict(scenes=['real_human']),
+ data_prefix=dict(img='pose/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=humanart_coco133)
+ ])
+
+ubody_scenes = [
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
+]
+
+ubody_datasets = []
+for scene in ubody_scenes:
+ each = dict(
+ type='UBody2dDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file=f'Ubody/annotations/{scene}/train_annotations.json',
+ data_prefix=dict(img='pose/UBody/images/'),
+ pipeline=[],
+ sample_interval=10)
+ ubody_datasets.append(each)
+
+dataset_ubody = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/ubody2d.py'),
+ datasets=ubody_datasets,
+ pipeline=[],
+ test_mode=False,
+)
+
+face_pipeline = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(
+ type='RandomBBoxTransform',
+ shift_factor=0.,
+ scale_factor=[0.3, 0.5],
+ rotate_factor=0),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+]
+
+wflw_coco133 = [(i * 2, 20 + i)
+ for i in range(17)] + [(33 + i, 41 + i) for i in range(5)] + [
+ (42 + i, 46 + i) for i in range(5)
+ ] + [(51 + i, 50 + i)
+ for i in range(9)] + [(60, 59), (61, 60), (63, 61),
+ (64, 62), (65, 63), (67, 64),
+ (68, 65), (69, 66), (71, 67),
+ (72, 68), (73, 69),
+ (75, 70)] + [(76 + i, 71 + i)
+ for i in range(20)]
+dataset_wflw = dict(
+ type='WFLWDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='wflw/annotations/face_landmarks_wflw_train.json',
+ data_prefix=dict(img='pose/WFLW/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=wflw_coco133), *face_pipeline
+ ],
+)
+
+mapping_300w_coco133 = [(i, 20 + i) for i in range(68)]
+dataset_300w = dict(
+ type='Face300WDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='300w/annotations/face_landmarks_300w_train.json',
+ data_prefix=dict(img='pose/300w/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=mapping_300w_coco133), *face_pipeline
+ ],
+)
+
+cofw_coco133 = [(0, 41), (2, 45), (4, 43), (1, 47), (3, 49), (6, 45), (8, 59),
+ (10, 62), (9, 68), (11, 65), (18, 54), (19, 58), (20, 53),
+ (21, 56), (22, 71), (23, 77), (24, 74), (25, 85), (26, 89),
+ (27, 80), (28, 28)]
+dataset_cofw = dict(
+ type='COFWDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='cofw/annotations/cofw_train.json',
+ data_prefix=dict(img='pose/COFW/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=cofw_coco133), *face_pipeline
+ ],
+)
+
+lapa_coco133 = [(i * 2, 20 + i) for i in range(17)] + [
+ (33 + i, 41 + i) for i in range(5)
+] + [(42 + i, 46 + i) for i in range(5)] + [
+ (51 + i, 50 + i) for i in range(4)
+] + [(58 + i, 54 + i) for i in range(5)] + [(66, 59), (67, 60), (69, 61),
+ (70, 62), (71, 63), (73, 64),
+ (75, 65), (76, 66), (78, 67),
+ (79, 68), (80, 69),
+ (82, 70)] + [(84 + i, 71 + i)
+ for i in range(20)]
+dataset_lapa = dict(
+ type='LapaDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='LaPa/annotations/lapa_trainval.json',
+ data_prefix=dict(img='pose/LaPa/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=lapa_coco133), *face_pipeline
+ ],
+)
+
+dataset_wb = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[dataset_coco, dataset_halpe, dataset_ubody],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_body = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_aic,
+ dataset_crowdpose,
+ dataset_mpii,
+ dataset_jhmdb,
+ dataset_posetrack,
+ dataset_humanart,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_face = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_wflw,
+ dataset_300w,
+ dataset_cofw,
+ dataset_lapa,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+train_datasets = [
+ dataset_wb,
+ dataset_body,
+ dataset_face,
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=train_batch_size,
+ num_workers=10,
+ pin_memory=True,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=train_datasets,
+ pipeline=train_pipeline,
+ test_mode=False,
+ ))
+
+val_dataloader = dict(
+ batch_size=val_batch_size,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type='CocoWholeBodyDataset',
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
+ data_prefix=dict(img='data/detection/coco/val2017/'),
+ pipeline=val_pipeline,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ test_mode=True))
+
+test_dataloader = val_dataloader
+
+# hooks
+default_hooks = dict(
+ checkpoint=dict(
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='mmdet.PipelineSwitchHook',
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline=train_pipeline_stage2)
+]
+
+# evaluators
+val_evaluator = dict(
+ type='CocoWholeBodyMetric',
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
+test_evaluator = val_evaluator
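
The `*_coco133` lists defined in the config above (and repeated in the 256x192 config that follows) are `(source_index, target_index)` pairs consumed by `KeypointConverter`: keypoints from each source dataset are scattered into the 133-keypoint COCO-WholeBody layout, and target joints with no source counterpart stay invisible. The snippet below only illustrates that bookkeeping; it is not the `KeypointConverter` implementation, and `remap_keypoints` is a hypothetical helper name.

import numpy as np

def remap_keypoints(src_kpts, src_vis, mapping, num_keypoints=133):
    """Scatter (K_src, 2) source keypoints into a (num_keypoints, 2) layout."""
    dst_kpts = np.zeros((num_keypoints, 2), dtype=np.float32)
    dst_vis = np.zeros(num_keypoints, dtype=np.float32)
    for src_idx, dst_idx in mapping:
        dst_kpts[dst_idx] = src_kpts[src_idx]
        dst_vis[dst_idx] = src_vis[src_idx]
    return dst_kpts, dst_vis

# AI Challenger annotates 14 joints; only the 12 limb joints have a
# COCO-WholeBody counterpart, so the remaining two are simply dropped.
aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
               (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)]
kpts, vis = remap_keypoints(np.random.rand(14, 2), np.ones(14), aic_coco133)
print(kpts.shape, int(vis.sum()))  # (133, 2) 12
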
diff --git a/projects/rtmpose/rtmpose/wholebody_2d_keypoint/rtmw-x_8xb704-270e_cocktail13-256x192.py b/projects/rtmpose/rtmpose/wholebody_2d_keypoint/rtmw-x_8xb704-270e_cocktail13-256x192.py
new file mode 100644
index 0000000000..369b85b1ce
--- /dev/null
+++ b/projects/rtmpose/rtmpose/wholebody_2d_keypoint/rtmw-x_8xb704-270e_cocktail13-256x192.py
@@ -0,0 +1,586 @@
+_base_ = ['mmpose::_base_/default_runtime.py']
+
+# common setting
+num_keypoints = 133
+input_size = (192, 256)
+
+# runtime
+max_epochs = 270
+stage2_num_epochs = 10
+base_lr = 5e-4
+train_batch_size = 704
+val_batch_size = 32
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=21)
+
+# optimizer
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
+ clip_grad=dict(max_norm=35, norm_type=2),
+ paramwise_cfg=dict(
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1.0e-5,
+ by_epoch=False,
+ begin=0,
+ end=1000),
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=base_lr * 0.05,
+ begin=max_epochs // 2,
+ end=max_epochs,
+ T_max=max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=5632)
+
+# codec settings
+codec = dict(
+ type='SimCCLabel',
+ input_size=input_size,
+ sigma=(4.9, 5.66),
+ simcc_split_ratio=2.0,
+ normalize=False,
+ use_dark=False)
+
+# model settings
+model = dict(
+ type='TopdownPoseEstimator',
+ data_preprocessor=dict(
+ type='PoseDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ _scope_='mmdet',
+ type='CSPNeXt',
+ arch='P5',
+ expand_ratio=0.5,
+ deepen_factor=1.33,
+ widen_factor=1.25,
+ channel_attention=True,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='SiLU'),
+ init_cfg=dict(
+ type='Pretrained',
+ prefix='backbone.',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/'
+ 'wholebody_2d_keypoint/rtmpose/ubody/rtmpose-x_simcc-ucoco_pt-aic-coco_270e-256x192-05f5bcb7_20230822.pth' # noqa
+ )),
+ neck=dict(
+ type='CSPNeXtPAFPN',
+ in_channels=[320, 640, 1280],
+ out_channels=None,
+ out_indices=(
+ 1,
+ 2,
+ ),
+ num_csp_blocks=2,
+ expand_ratio=0.5,
+ norm_cfg=dict(type='SyncBN'),
+ act_cfg=dict(type='SiLU', inplace=True)),
+ head=dict(
+ type='RTMWHead',
+ in_channels=1280,
+ out_channels=num_keypoints,
+ input_size=input_size,
+ in_featuremap_size=tuple([s // 32 for s in input_size]),
+ simcc_split_ratio=codec['simcc_split_ratio'],
+ final_layer_kernel_size=7,
+ gau_cfg=dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn='SiLU',
+ use_rel_bias=False,
+ pos_enc=False),
+ loss=dict(
+ type='KLDiscretLoss',
+ use_target_weight=True,
+ beta=10.,
+ label_softmax=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = 'CocoWholeBodyDataset'
+data_mode = 'topdown'
+data_root = 'data/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(
+ type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PhotometricDistortion'),
+ dict(
+ type='Albumentation',
+ transforms=[
+ dict(type='Blur', p=0.1),
+ dict(type='MedianBlur', p=0.1),
+ dict(
+ type='CoarseDropout',
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=1.0),
+ ]),
+ dict(
+ type='GenerateTarget',
+ encoder=codec,
+ use_dataset_keypoint_weights=True),
+ dict(type='PackPoseInputs')
+]
+val_pipeline = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PackPoseInputs')
+]
+
+train_pipeline_stage2 = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(
+ type='RandomBBoxTransform',
+ shift_factor=0.,
+ scale_factor=[0.5, 1.5],
+ rotate_factor=90),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(
+ type='Albumentation',
+ transforms=[
+ dict(type='Blur', p=0.1),
+ dict(type='MedianBlur', p=0.1),
+ dict(
+ type='CoarseDropout',
+ max_holes=1,
+ max_height=0.4,
+ max_width=0.4,
+ min_holes=1,
+ min_height=0.2,
+ min_width=0.2,
+ p=0.5),
+ ]),
+ dict(
+ type='GenerateTarget',
+ encoder=codec,
+ use_dataset_keypoint_weights=True),
+ dict(type='PackPoseInputs')
+]
+
+# mapping
+
+aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
+ (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)]
+
+crowdpose_coco133 = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (6, 11),
+ (7, 12), (8, 13), (9, 14), (10, 15), (11, 16)]
+
+mpii_coco133 = [
+ (0, 16),
+ (1, 14),
+ (2, 12),
+ (3, 11),
+ (4, 13),
+ (5, 15),
+ (8, 18),
+ (9, 17),
+ (10, 10),
+ (11, 8),
+ (12, 6),
+ (13, 5),
+ (14, 7),
+ (15, 9),
+]
+
+jhmdb_coco133 = [
+ (0, 18),
+ (2, 17),
+ (3, 6),
+ (4, 5),
+ (5, 12),
+ (6, 11),
+ (7, 8),
+ (8, 7),
+ (9, 14),
+ (10, 13),
+ (11, 10),
+ (12, 9),
+ (13, 16),
+ (14, 15),
+]
+
+halpe_coco133 = [(i, i)
+ for i in range(17)] + [(20, 17), (21, 20), (22, 18), (23, 21),
+ (24, 19),
+ (25, 22)] + [(i, i - 3)
+ for i in range(26, 136)]
+
+posetrack_coco133 = [
+ (0, 0),
+ (2, 17),
+ (3, 3),
+ (4, 4),
+ (5, 5),
+ (6, 6),
+ (7, 7),
+ (8, 8),
+ (9, 9),
+ (10, 10),
+ (11, 11),
+ (12, 12),
+ (13, 13),
+ (14, 14),
+ (15, 15),
+ (16, 16),
+]
+
+humanart_coco133 = [(i, i) for i in range(17)] + [(17, 99), (18, 120),
+ (19, 17), (20, 20)]
+
+# train datasets
+dataset_coco = dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
+ data_prefix=dict(img='detection/coco/train2017/'),
+ pipeline=[],
+)
+
+dataset_aic = dict(
+ type='AicDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='aic/annotations/aic_train.json',
+ data_prefix=dict(img='pose/ai_challenge/ai_challenger_keypoint'
+ '_train_20170902/keypoint_train_images_20170902/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=aic_coco133)
+ ],
+)
+
+dataset_crowdpose = dict(
+ type='CrowdPoseDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='crowdpose/annotations/mmpose_crowdpose_trainval.json',
+ data_prefix=dict(img='pose/CrowdPose/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=crowdpose_coco133)
+ ],
+)
+
+dataset_mpii = dict(
+ type='MpiiDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='mpii/annotations/mpii_train.json',
+ data_prefix=dict(img='pose/MPI/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=mpii_coco133)
+ ],
+)
+
+dataset_jhmdb = dict(
+ type='JhmdbDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='jhmdb/annotations/Sub1_train.json',
+ data_prefix=dict(img='pose/JHMDB/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=jhmdb_coco133)
+ ],
+)
+
+dataset_halpe = dict(
+ type='HalpeDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='halpe/annotations/halpe_train_v1.json',
+ data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=halpe_coco133)
+ ],
+)
+
+dataset_posetrack = dict(
+ type='PoseTrack18Dataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='posetrack18/annotations/posetrack18_train.json',
+ data_prefix=dict(img='pose/PoseChallenge2018/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=posetrack_coco133)
+ ],
+)
+
+dataset_humanart = dict(
+ type='HumanArt21Dataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='HumanArt/annotations/training_humanart.json',
+ filter_cfg=dict(scenes=['real_human']),
+ data_prefix=dict(img='pose/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=humanart_coco133)
+ ])
+
+ubody_scenes = [
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
+]
+
+ubody_datasets = []
+for scene in ubody_scenes:
+ each = dict(
+ type='UBody2dDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file=f'Ubody/annotations/{scene}/train_annotations.json',
+ data_prefix=dict(img='pose/UBody/images/'),
+ pipeline=[],
+ sample_interval=10)
+ ubody_datasets.append(each)
+
+dataset_ubody = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/ubody2d.py'),
+ datasets=ubody_datasets,
+ pipeline=[],
+ test_mode=False,
+)
+
+face_pipeline = [
+ dict(type='LoadImage', backend_args=backend_args),
+ dict(type='GetBBoxCenterScale'),
+ dict(
+ type='RandomBBoxTransform',
+ shift_factor=0.,
+ scale_factor=[0.3, 0.5],
+ rotate_factor=0),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+]
+
+wflw_coco133 = [(i * 2, 20 + i)
+ for i in range(17)] + [(33 + i, 41 + i) for i in range(5)] + [
+ (42 + i, 46 + i) for i in range(5)
+ ] + [(51 + i, 50 + i)
+ for i in range(9)] + [(60, 59), (61, 60), (63, 61),
+ (64, 62), (65, 63), (67, 64),
+ (68, 65), (69, 66), (71, 67),
+ (72, 68), (73, 69),
+ (75, 70)] + [(76 + i, 71 + i)
+ for i in range(20)]
+dataset_wflw = dict(
+ type='WFLWDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='wflw/annotations/face_landmarks_wflw_train.json',
+ data_prefix=dict(img='pose/WFLW/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=wflw_coco133), *face_pipeline
+ ],
+)
+
+mapping_300w_coco133 = [(i, 20 + i) for i in range(68)]
+dataset_300w = dict(
+ type='Face300WDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='300w/annotations/face_landmarks_300w_train.json',
+ data_prefix=dict(img='pose/300w/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=mapping_300w_coco133), *face_pipeline
+ ],
+)
+
+cofw_coco133 = [(0, 41), (2, 45), (4, 43), (1, 47), (3, 49), (6, 45), (8, 59),
+ (10, 62), (9, 68), (11, 65), (18, 54), (19, 58), (20, 53),
+ (21, 56), (22, 71), (23, 77), (24, 74), (25, 85), (26, 89),
+ (27, 80), (28, 28)]
+dataset_cofw = dict(
+ type='COFWDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='cofw/annotations/cofw_train.json',
+ data_prefix=dict(img='pose/COFW/images/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=cofw_coco133), *face_pipeline
+ ],
+)
+
+lapa_coco133 = [(i * 2, 20 + i) for i in range(17)] + [
+ (33 + i, 41 + i) for i in range(5)
+] + [(42 + i, 46 + i) for i in range(5)] + [
+ (51 + i, 50 + i) for i in range(4)
+] + [(58 + i, 54 + i) for i in range(5)] + [(66, 59), (67, 60), (69, 61),
+ (70, 62), (71, 63), (73, 64),
+ (75, 65), (76, 66), (78, 67),
+ (79, 68), (80, 69),
+ (82, 70)] + [(84 + i, 71 + i)
+ for i in range(20)]
+dataset_lapa = dict(
+ type='LapaDataset',
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='LaPa/annotations/lapa_trainval.json',
+ data_prefix=dict(img='pose/LaPa/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=num_keypoints,
+ mapping=lapa_coco133), *face_pipeline
+ ],
+)
+
+dataset_wb = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[dataset_coco, dataset_halpe, dataset_ubody],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_body = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_aic,
+ dataset_crowdpose,
+ dataset_mpii,
+ dataset_jhmdb,
+ dataset_posetrack,
+ dataset_humanart,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+dataset_face = dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=[
+ dataset_wflw,
+ dataset_300w,
+ dataset_cofw,
+ dataset_lapa,
+ ],
+ pipeline=[],
+ test_mode=False,
+)
+
+train_datasets = [
+ dataset_wb,
+ dataset_body,
+ dataset_face,
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=train_batch_size,
+ num_workers=10,
+ pin_memory=True,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+ datasets=train_datasets,
+ pipeline=train_pipeline,
+ test_mode=False,
+ ))
+
+val_dataloader = dict(
+ batch_size=val_batch_size,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type='CocoWholeBodyDataset',
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
+ data_prefix=dict(img='data/detection/coco/val2017/'),
+ pipeline=val_pipeline,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ test_mode=True))
+
+test_dataloader = val_dataloader
+
+# hooks
+default_hooks = dict(
+ checkpoint=dict(
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='mmdet.PipelineSwitchHook',
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline=train_pipeline_stage2)
+]
+
+# evaluators
+val_evaluator = dict(
+ type='CocoWholeBodyMetric',
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
+test_evaluator = val_evaluator
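
Several numbers in these two configs are derived from one another rather than independent. Roughly: SimCC discretizes each axis of the input crop into `input_size * simcc_split_ratio` classification bins, `in_featuremap_size` is the stride-32 resolution of the crop, and `auto_scale_lr` rescales the learning rate linearly when the actual total batch size differs from `base_batch_size` (8 GPUs x 704 samples per GPU = 5632 here) and LR auto-scaling is enabled at launch. A small, purely illustrative sketch of that arithmetic for the 256x192 config:

# illustrative arithmetic only; values taken from the 256x192 config above
input_size = (192, 256)            # (w, h) of the cropped person
simcc_split_ratio = 2.0
base_lr = 5e-4
base_batch_size = 5632             # 8 GPUs x train_batch_size (704)

simcc_bins = tuple(int(s * simcc_split_ratio) for s in input_size)
in_featuremap_size = tuple(s // 32 for s in input_size)
print(simcc_bins)                  # (384, 512) bins along x and y
print(in_featuremap_size)          # (6, 8) backbone feature map fed to the head

# linear scaling rule applied by auto_scale_lr when it is enabled
actual_batch = 4 * 704             # e.g. the same config launched on 4 GPUs
print(base_lr * actual_batch / base_batch_size)   # 0.00025
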
diff --git a/tools/misc/pth_transfer.py b/tools/misc/pth_transfer.py
index ee08fee748..7433c6771e 100644
--- a/tools/misc/pth_transfer.py
+++ b/tools/misc/pth_transfer.py
@@ -6,7 +6,7 @@
def change_model(args):
- dis_model = torch.load(args.dis_path)
+ dis_model = torch.load(args.dis_path, map_location='cpu')
all_name = []
if args.two_dis:
for name, v in dis_model['state_dict'].items():
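
The only functional change in `pth_transfer.py` is the added `map_location='cpu'`, which lets the conversion script run on machines that do not have the GPU the distilled checkpoint was saved from. A minimal sketch of the difference (the checkpoint path below is hypothetical):

import torch

ckpt_path = 'work_dirs/dwpose_distill/epoch_120.pth'  # hypothetical path

# Without map_location, torch.load restores each tensor onto the device it was
# saved from (e.g. 'cuda:0'), which fails on a CPU-only machine or one with
# fewer GPUs than the training node.
# state = torch.load(ckpt_path)

# Remapping everything to CPU first makes the script hardware-independent; the
# extracted weights can still be moved to GPU when the model is later loaded.
state = torch.load(ckpt_path, map_location='cpu')
print(list(state['state_dict'])[:5])
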