
[Bug] ValueError: cannot reshape array of size 21 into shape (51) #3150

Open
2 tasks done
342600694 opened this issue Nov 13, 2024 · 0 comments

Prerequisite

Environment

Package Version Editable project location


actionlib 1.14.0
addict 2.4.0
aenum 3.1.15
aliyun-python-sdk-core 2.15.1
aliyun-python-sdk-kms 2.16.3
angles 1.9.13
attrs 24.1.0
bondpy 1.8.6
camera-calibration 1.17.0
camera-calibration-parsers 1.12.0
catkin 0.8.10
certifi 2024.7.4
cffi 1.16.0
charset-normalizer 3.3.2
chumpy 0.70
click 8.1.7
colorama 0.4.6
coloredlogs 15.0.1
contourpy 1.1.1
controller-manager 0.20.0
controller-manager-msgs 0.20.0
coverage 7.6.1
crcmod 1.7
crowdposetools 2.0
cryptography 43.0.0
cv-bridge 1.16.2
cycler 0.12.1
Cython 3.0.11
diagnostic-analysis 1.11.0
diagnostic-common-diagnostics 1.11.0
diagnostic-updater 1.11.0
dill 0.3.8
dynamic-reconfigure 1.7.3
exceptiongroup 1.2.2
filelock 3.14.0
flake8 7.1.1
flatbuffers 24.3.25
fonttools 4.53.1
gazebo_plugins 2.9.2
gazebo_ros 2.9.2
gencpp 0.7.0
geneus 3.0.0
genlisp 0.4.18
genmsg 0.6.0
gennodejs 2.0.2
genpy 0.6.15
grpcio 1.65.4
humanfriendly 10.0
idna 3.7
image-geometry 1.16.2
importlib_metadata 8.2.0
importlib_resources 6.4.0
iniconfig 2.0.0
interactive-markers 1.12.0
interrogate 1.7.0
isort 4.3.21
jmespath 0.10.0
joint-state-publisher 1.15.1
joint-state-publisher-gui 1.15.1
json-tricks 3.17.3
kiwisolver 1.4.5
laser_geometry 1.6.7
Markdown 3.6
markdown-it-py 3.0.0
matplotlib 3.7.5
mccabe 0.7.0
mdurl 0.1.2
message-filters 1.16.0
mmcv 2.1.0
mmdeploy 1.3.1 /home/huochewang/桌面/mmdeploy
mmdeploy-runtime-gpu 1.3.1
mmdet 3.3.0
mmengine 0.10.4
mmpose 1.3.2 /home/huochewang/桌面/mmpose
model-index 0.1.11
mpmath 1.3.0
multiprocess 0.70.16
munkres 1.1.4
numpy 1.24.4
onnx 1.16.2
onnx-simplifier 0.4.36
onnxruntime-gpu 1.14.0
opencv-python 4.10.0.84
opendatalab 0.0.10
openmim 0.3.9
openxlab 0.1.1
ordered-set 4.1.0
oss2 2.17.0
packaging 24.1
pandas 2.0.3
parameterized 0.9.0
pillow 10.4.0
pip 24.0
platformdirs 4.2.2
pluggy 1.5.0
prettytable 3.10.2
protobuf 3.20.2
py 1.11.0
pycocotools 2.0.7
pycodestyle 2.12.1
pycparser 2.22
pycryptodome 3.20.0
pyflakes 3.2.0
Pygments 2.18.0
pyparsing 3.1.2
pytest 8.3.2
pytest-runner 6.0.1
python-dateutil 2.9.0.post0
python-qt-binding 0.4.4
pytz 2023.4
PyYAML 6.0.1
qt-dotgraph 0.4.2
qt-gui 0.4.2
qt-gui-cpp 0.4.2
qt-gui-py-common 0.4.2
requests 2.28.2
resource_retriever 1.12.7
rich 13.4.2
rosbag 1.16.0
rosboost-cfg 1.15.8
rosclean 1.15.8
roscreate 1.15.8
rosgraph 1.16.0
roslaunch 1.16.0
roslib 1.15.8
roslint 0.12.0
roslz4 1.16.0
rosmake 1.15.8
rosmaster 1.16.0
rosmsg 1.16.0
rosnode 1.16.0
rosparam 1.16.0
rospy 1.16.0
rosservice 1.16.0
rostest 1.16.0
rostopic 1.16.0
rosunit 1.15.8
roswtf 1.16.0
rqt_action 0.4.9
rqt_bag 0.5.1
rqt_bag_plugins 0.5.1
rqt-console 0.4.12
rqt_dep 0.4.12
rqt_graph 0.4.14
rqt_gui 0.5.3
rqt_gui_py 0.5.3
rqt-image-view 0.4.17
rqt_launch 0.4.9
rqt-logger-level 0.4.12
rqt-moveit 0.5.11
rqt_msg 0.4.10
rqt_nav_view 0.5.7
rqt_plot 0.4.13
rqt_pose_view 0.5.11
rqt_publisher 0.4.10
rqt_py_common 0.5.3
rqt_py_console 0.4.10
rqt-reconfigure 0.5.5
rqt-robot-dashboard 0.5.8
rqt-robot-monitor 0.5.15
rqt_robot_steering 0.5.12
rqt-runtime-monitor 0.5.10
rqt-rviz 0.7.0
rqt_service_caller 0.4.10
rqt_shell 0.4.11
rqt_srv 0.4.9
rqt-tf-tree 0.6.4
rqt_top 0.4.10
rqt_topic 0.4.13
rqt_web 0.4.10
rviz 1.14.25
scipy 1.10.1
sensor-msgs 1.13.1
setuptools 60.2.0
shapely 2.0.5
six 1.16.0
smach 2.5.2
smach-ros 2.5.2
smclib 1.8.6
sympy 1.13.1
tabulate 0.9.0
termcolor 2.4.0
terminaltables 3.1.10
tf 1.13.2
tf-conversions 1.13.2
tf2-geometry-msgs 0.7.7
tf2-kdl 0.7.7
tf2-py 0.7.7
tf2-ros 0.7.7
tomli 2.0.1
topic-tools 1.16.0
torch 1.12.1+cu116
torchaudio 0.12.1+cu116
torchvision 0.13.1+cu116
tqdm 4.65.2
typing_extensions 4.12.2
tzdata 2024.1
urllib3 1.26.19
wcwidth 0.2.13
wheel 0.43.0
xacro 1.14.18
xdoctest 1.1.6
xtcocotools 1.14.3
yapf 0.40.2
zipp 3.19.2

Reproduces the problem - code sample

# Copyright (c) OpenMMLab. All rights reserved.

import argparse
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.runner import Runner

def parse_args():
    parser = argparse.ArgumentParser(description='Train a pose model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If specify checkpint path, resume from it, while if not '
        'specify, try to auto resume from the latest checkpoint '
        'in the work directory.')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
        help='whether to auto scale the learning rate according to the '
        'actual batch size and the original batch size.')
    parser.add_argument(
        '--show-dir',
        help='directory where the visualization images will be saved.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to display the prediction results in a window.')
    parser.add_argument(
        '--interval',
        type=int,
        default=1,
        help='visualize per interval samples.')
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='display time of every window. (second)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch version >= 2.0.0, the torch.distributed.launch
    # will pass the --local-rank parameter to tools/train.py instead
    # of --local_rank.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args

def merge_args(cfg, args):
    """Merge CLI arguments to config."""
    if args.no_validate:
        cfg.val_cfg = None
        cfg.val_dataloader = None
        cfg.val_evaluator = None

    cfg.launcher = args.launcher

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training
    if args.amp is True:
        from mmengine.optim import AmpOptimWrapper, OptimWrapper
        optim_wrapper = cfg.optim_wrapper.get('type', OptimWrapper)
        assert optim_wrapper in (OptimWrapper, AmpOptimWrapper,
                                 'OptimWrapper', 'AmpOptimWrapper'), \
            '`--amp` is not supported custom optimizer wrapper type ' \
            f'`{optim_wrapper}.'
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')

    # resume training
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume

    # enable auto scale learning rate
    if args.auto_scale_lr:
        cfg.auto_scale_lr.enable = True

    # visualization
    if args.show or (args.show_dir is not None):
        assert 'visualization' in cfg.default_hooks, \
            'PoseVisualizationHook is not set in the ' \
            '`default_hooks` field of config. Please set ' \
            '`visualization=dict(type="PoseVisualizationHook")`'

        cfg.default_hooks.visualization.enable = True
        cfg.default_hooks.visualization.show = args.show
        if args.show:
            cfg.default_hooks.visualization.wait_time = args.wait_time
        cfg.default_hooks.visualization.out_dir = args.show_dir
        cfg.default_hooks.visualization.interval = args.interval

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg

def main():
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)

    # merge CLI arguments to config
    cfg = merge_args(cfg, args)

    # set preprocess configs to model
    # If cfg.model has no data_preprocessor, use preprocess_cfg as its value.
    # If preprocess_cfg exists in the config, pass it to the model's data_preprocessor.
    if 'preprocess_cfg' in cfg:
        cfg.model.setdefault('data_preprocessor',
                             cfg.get('preprocess_cfg', {}))

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start training
    runner.train()


if __name__ == '__main__':
    main()
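The CLI handling above hinges on two details: `--resume` is declared with `nargs='?'` and `const='auto'`, so a bare `--resume` yields `'auto'` while `--resume PATH` yields the path; and `--amp` (used in the failing command) swaps `OptimWrapper` for `AmpOptimWrapper` with a dynamic loss scale. The sketch below reproduces only that logic with plain `argparse` and dictionaries instead of an mmengine `Config`; `build_parser` and `apply_cli` are hypothetical helpers, not part of mmpose.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the --resume and --amp options of tools/train.py."""
    parser = argparse.ArgumentParser(description='Train a pose model')
    parser.add_argument('config')
    # nargs='?' + const='auto': a bare "--resume" gives 'auto'; omitting it gives None.
    parser.add_argument('--resume', nargs='?', type=str, const='auto')
    parser.add_argument('--amp', action='store_true', default=False)
    return parser


def apply_cli(cfg: dict, args: argparse.Namespace) -> dict:
    """Apply the --amp and --resume rules of merge_args() to a plain dict config."""
    if args.amp:
        wrapper = cfg['optim_wrapper'].get('type', 'OptimWrapper')
        if wrapper not in ('OptimWrapper', 'AmpOptimWrapper'):
            raise ValueError(f'--amp does not support optimizer wrapper {wrapper!r}')
        cfg['optim_wrapper']['type'] = 'AmpOptimWrapper'
        cfg['optim_wrapper'].setdefault('loss_scale', 'dynamic')
    if args.resume == 'auto':
        # resume from the latest checkpoint found in the work dir
        cfg['resume'], cfg['load_from'] = True, None
    elif args.resume is not None:
        # resume from an explicit checkpoint path
        cfg['resume'], cfg['load_from'] = True, args.resume
    return cfg


if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args(['rtmo.py', '--amp'])
    cfg = apply_cli({'optim_wrapper': {'type': 'OptimWrapper'}}, args)
    print(cfg['optim_wrapper'])  # {'type': 'AmpOptimWrapper', 'loss_scale': 'dynamic'}
```

Because the config in this report sets `type='OptimWrapper'`, the `--amp` flag passes the assertion in `merge_args` and mixed precision is enabled.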

Reproduces the problem - command or script

python tools/train.py configs/body_2d_keypoint/rtmo/coco/rtmo-l_16xb16-600e_coco-640x640arm.py --amp

Reproduces the problem - error message

11/13 20:30:16 - mmengine - INFO - Epoch(train) [5][1800/1812] base_lr: 3.951751e-04 lr: 3.951751e-04 eta: 1:38:27 time: 0.229064 data_time: 0.001469 memory: 4821 grad_norm: 107.484255 loss: 7.052426 loss_bbox: 0.826151 loss_vis: 0.000011 loss_mle: -0.047193 loss_oks: 5.877753 loss_cls: 0.395704 num_samples: 17.000000 overlaps: 0.750766
11/13 20:30:19 - mmengine - INFO - Exp name: rtmo-l_16xb16-600e_coco-640x640arm_20241113_195726
11/13 20:30:20 - mmengine - INFO - Epoch(val) [5][ 50/959] eta: 0:00:17 time: 0.018845 data_time: 0.000969 memory: 4557
11/13 20:30:21 - mmengine - INFO - Epoch(val) [5][100/959] eta: 0:00:15 time: 0.017900 data_time: 0.000073 memory: 932
11/13 20:30:22 - mmengine - INFO - Epoch(val) [5][150/959] eta: 0:00:14 time: 0.018003 data_time: 0.000085 memory: 932
11/13 20:30:22 - mmengine - INFO - Epoch(val) [5][200/959] eta: 0:00:13 time: 0.017905 data_time: 0.000072 memory: 932
11/13 20:30:23 - mmengine - INFO - Epoch(val) [5][250/959] eta: 0:00:12 time: 0.017939 data_time: 0.000071 memory: 932
11/13 20:30:24 - mmengine - INFO - Epoch(val) [5][300/959] eta: 0:00:11 time: 0.017909 data_time: 0.000073 memory: 932
11/13 20:30:25 - mmengine - INFO - Epoch(val) [5][350/959] eta: 0:00:10 time: 0.017907 data_time: 0.000073 memory: 932
11/13 20:30:26 - mmengine - INFO - Epoch(val) [5][400/959] eta: 0:00:10 time: 0.017903 data_time: 0.000073 memory: 932
11/13 20:30:27 - mmengine - INFO - Epoch(val) [5][450/959] eta: 0:00:09 time: 0.017901 data_time: 0.000074 memory: 932
11/13 20:30:28 - mmengine - INFO - Epoch(val) [5][500/959] eta: 0:00:08 time: 0.017903 data_time: 0.000073 memory: 932
11/13 20:30:29 - mmengine - INFO - Epoch(val) [5][550/959] eta: 0:00:07 time: 0.017891 data_time: 0.000069 memory: 932
11/13 20:30:30 - mmengine - INFO - Epoch(val) [5][600/959] eta: 0:00:06 time: 0.017929 data_time: 0.000070 memory: 932
11/13 20:30:31 - mmengine - INFO - Epoch(val) [5][650/959] eta: 0:00:05 time: 0.017915 data_time: 0.000072 memory: 932
11/13 20:30:31 - mmengine - INFO - Epoch(val) [5][700/959] eta: 0:00:04 time: 0.017923 data_time: 0.000073 memory: 932
11/13 20:30:32 - mmengine - INFO - Epoch(val) [5][750/959] eta: 0:00:03 time: 0.017919 data_time: 0.000073 memory: 932
11/13 20:30:33 - mmengine - INFO - Epoch(val) [5][800/959] eta: 0:00:02 time: 0.017943 data_time: 0.000071 memory: 932
11/13 20:30:34 - mmengine - INFO - Epoch(val) [5][850/959] eta: 0:00:01 time: 0.017994 data_time: 0.000077 memory: 932
11/13 20:30:35 - mmengine - INFO - Epoch(val) [5][900/959] eta: 0:00:01 time: 0.017964 data_time: 0.000083 memory: 932
11/13 20:30:36 - mmengine - INFO - Epoch(val) [5][950/959] eta: 0:00:00 time: 0.017942 data_time: 0.000088 memory: 932
Traceback (most recent call last):
  File "tools/train.py", line 164, in <module>
    main()
  File "tools/train.py", line 160, in main
    runner.train()
  File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
    model = self.train_loop.run() # type: ignore
  File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 103, in run
    self.runner.val_loop.run()
  File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 376, in run
    metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
  File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/evaluator/evaluator.py", line 79, in evaluate
    _results = metric.evaluate(size)
  File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/evaluator/metric.py", line 133, in evaluate
    _metrics = self.compute_metrics(results) # type: ignore
  File "/home/huochewang/桌面/mmpose/mmpose/evaluation/metrics/coco_metric.py", line 488, in compute_metrics
    self.results2json(valid_kpts, outfile_prefix=outfile_prefix)
  File "/home/huochewang/桌面/mmpose/mmpose/evaluation/metrics/coco_metric.py", line 529, in results2json
    _keypoints = _keypoints.reshape(-1, num_keypoints * 3)
ValueError: cannot reshape array of size 21 into shape (51)
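The sizes in the error are consistent with a keypoint-count mismatch. `results2json` appears to store each predicted instance as (x, y, score) triples and reshapes them into rows of `num_keypoints * 3` values. 21 = 7 × 3 matches the 7 keypoints predicted by `RTMOHead` (`num_keypoints=7`), whereas 51 = 17 × 3 matches the 17-keypoint COCO layout, so the metric seems to be using a 17-keypoint metainfo. The pure-Python sketch below reproduces the failure under that assumption; `reshape_rows` is a hypothetical stand-in for NumPy's `reshape(-1, n)`.

```python
from typing import List, Sequence


def reshape_rows(flat: Sequence[float], row_len: int) -> List[List[float]]:
    """Emulate numpy's ``arr.reshape(-1, row_len)`` for a flat list."""
    if row_len <= 0 or len(flat) % row_len != 0:
        raise ValueError(
            f'cannot reshape array of size {len(flat)} into rows of {row_len}')
    return [list(flat[i:i + row_len]) for i in range(0, len(flat), row_len)]


# One predicted instance from a 7-keypoint head: 7 (x, y, score) triples.
predicted = [0.0] * (7 * 3)  # 21 values, as in the traceback

print(len(reshape_rows(predicted, 7 * 3)))  # 1 row when num_keypoints matches the head
try:
    reshape_rows(predicted, 17 * 3)  # 51 columns: COCO's 17 keypoints
except ValueError as err:
    print(err)
```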

Additional information

Hello, I want to use RTMO to estimate the pose of a robotic arm. I defined 7 keypoints, but training fails with the error above during the first validation pass.
The dataset config file is shown below:
dataset_info = dict(
    dataset_name='panda',
    paper_info=dict(
        author='DREAM',
        title='for 702 test',
        container='NO',
        year='2024',
        homepage='no',
    ),
    keypoint_info={
        0:
        dict(name='panda_link0', id=0, color=[51, 153, 255], type='', swap=''),
        1:
        dict(
            name='panda_link2',
            id=1,
            color=[51, 153, 255],
            type='',
            swap=''),
        2:
        dict(
            name='panda_link3',
            id=2,
            color=[51, 153, 255],
            type='',
            swap=''),
        3:
        dict(
            name='panda_link4',
            id=3,
            color=[51, 153, 255],
            type='',
            swap=''),
        4:
        dict(
            name='panda_link6',
            id=4,
            color=[51, 153, 255],
            type='',
            swap=''),
        5:
        dict(
            name='panda_link7',
            id=5,
            color=[0, 255, 0],
            type='',
            swap=''),
        6:
        dict(
            name='panda_hand',
            id=6,
            color=[255, 128, 0],
            type='',
            swap=''),
    },
    skeleton_info={
        0:
        dict(link=('panda_link0', 'panda_link2'), id=0, color=[0, 255, 0]),
        1:
        dict(link=('panda_link2', 'panda_link3'), id=1, color=[0, 255, 0]),
        2:
        dict(link=('panda_link3', 'panda_link4'), id=2, color=[255, 128, 0]),
        3:
        dict(link=('panda_link4', 'panda_link6'), id=3, color=[255, 128, 0]),
        4:
        dict(link=('panda_link6', 'panda_link7'), id=4, color=[51, 153, 255]),
        5:
        dict(link=('panda_link7', 'panda_hand'), id=5, color=[51, 153, 255])
    },
    joint_weights=[
        1., 1., 1., 1., 1., 1., 1.
    ],
    sigmas=[
        0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079
    ])
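The keypoint count implied by this metainfo has to agree with `num_keypoints=7` in the head config, and the per-keypoint lists have to match it. The sketch below is a hypothetical sanity check (`check_dataset_info` is not an mmpose function) run on a trimmed copy of the dict above that keeps only the fields being checked.

```python
from typing import Any, Dict


def check_dataset_info(info: Dict[str, Any]) -> int:
    """Hypothetical consistency check for an mmpose-style dataset_info dict.

    Returns the number of keypoints, or raises ValueError on an inconsistency.
    """
    kpts = info['keypoint_info']
    num_keypoints = len(kpts)
    names = {k['name'] for k in kpts.values()}
    # Every keypoint's `id` should match its dict key.
    for key, kpt in kpts.items():
        if kpt['id'] != key:
            raise ValueError(f'keypoint {kpt["name"]!r} has id {kpt["id"]}, key {key}')
    # Per-keypoint lists must have exactly one entry per keypoint.
    for field in ('joint_weights', 'sigmas'):
        if len(info[field]) != num_keypoints:
            raise ValueError(f'{field} has {len(info[field])} entries, '
                             f'expected {num_keypoints}')
    # Skeleton links may only reference defined keypoint names.
    for link in info['skeleton_info'].values():
        for name in link['link']:
            if name not in names:
                raise ValueError(f'skeleton link uses unknown keypoint {name!r}')
    return num_keypoints


names = ['panda_link0', 'panda_link2', 'panda_link3', 'panda_link4',
         'panda_link6', 'panda_link7', 'panda_hand']
panda_info = dict(
    keypoint_info={i: dict(name=n, id=i) for i, n in enumerate(names)},
    # consecutive pairs reproduce the 6 skeleton links above
    skeleton_info={i: dict(link=(a, b), id=i)
                   for i, (a, b) in enumerate(zip(names, names[1:]))},
    joint_weights=[1.] * 7,
    sigmas=[0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079],
)

print(check_dataset_info(panda_info))  # 7, matching num_keypoints=7 in RTMOHead
```

This file is internally consistent, which suggests the mismatch comes from whichever metainfo the evaluator actually loads rather than from the file itself.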
The model config file is shown below:
_base_ = ['../../../_base_/default_runtime.py']

# runtime

train_cfg = dict(max_epochs=20, val_interval=5, dynamic_intervals=[(18, 1)])

auto_scale_lr = dict(base_batch_size=3)

default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))

optim_wrapper = dict(
    type='OptimWrapper',
    constructor='ForceDefaultOptimWrapperConstructor',
    optimizer=dict(type='AdamW', lr=0.0004, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0,
        bias_decay_mult=0,
        bypass_duplicate=True,
        force_default_settings=True,
        custom_keys=dict({'neck.encoder': dict(lr_mult=0.05)})),
    clip_grad=dict(max_norm=0.1, norm_type=2))

param_scheduler = [
    dict(
        type='QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=4,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        eta_min=0.0002,
        begin=4,
        T_max=10,
        end=10,
        by_epoch=True,
        convert_to_iter_based=True),
    # this scheduler is used to increase the lr from 2e-4 to 5e-4
    dict(type='ConstantLR', by_epoch=True, factor=2.5, begin=10, end=11),
    dict(
        type='CosineAnnealingLR',
        eta_min=0.0002,
        begin=11,
        T_max=7,
        end=17,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(type='ConstantLR', by_epoch=True, factor=1, begin=17, end=20),
]

# data

input_size = (640, 480)
metafile = 'configs/_base_/datasets/panda.py'
codec = dict(type='YOLOXPoseAnnotationProcessor', input_size=input_size)

train_pipeline_stage1 = [
    dict(type='LoadImage', backend_args=None),
    dict(type='FilterAnnotations', by_kpt=True, by_box=True, keep_empty=False),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]
train_pipeline_stage2 = [
    dict(type='LoadImage'),
    dict(type='BottomupGetHeatmapMask', get_invalid=True),
    dict(type='FilterAnnotations', by_kpt=True, by_box=True, keep_empty=False),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]

data_mode = 'bottomup'
data_root = 'data/'

# train datasets

dataset_coco = dict(
    type='CocoDataset',
    data_root=data_root,
    data_mode=data_mode,
    ann_file='panda/ann/trainval.json',
    data_prefix=dict(img='panda/images/'),
    pipeline=train_pipeline_stage1,
)

train_dataloader = dict(
    batch_size=3,
    num_workers=6,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dataset_coco)

val_pipeline = [
    dict(type='LoadImage'),
    dict(
        type='BottomupResize', input_size=input_size, pad_val=(114, 114, 114)),
    dict(
        type='PackPoseInputs',
        meta_keys=('id', 'img_id', 'img_path', 'ori_shape', 'img_shape',
                   'input_size', 'input_center', 'input_scale'))
]

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        data_mode=data_mode,
        ann_file='panda/ann/test.json',
        data_prefix=dict(img='panda/images/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# evaluators

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'panda/ann/test.json',
    score_mode='bbox',
    nms_mode='none',
)
test_evaluator = val_evaluator

# hooks

custom_hooks = [
    dict(
        type='YOLOXPoseModeSwitchHook',
        num_last_epochs=5,
        new_train_pipeline=train_pipeline_stage2,
        priority=48),
    dict(
        type='RTMOModeSwitchHook',
        epoch_attributes={
            7: {
                'proxy_target_cc': True,
                'overlaps_power': 1.0,
                'loss_cls.loss_weight': 2.0,
                'loss_mle.loss_weight': 5.0,
                'loss_oks.loss_weight': 10.0
            },
        },
        priority=48),
    dict(type='SyncNormHook', priority=48),
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
]

# model

widen_factor = 1.0
deepen_factor = 1.0

model = dict(
    type='BottomupPoseEstimator',
    init_cfg=dict(
        type='Kaiming',  # weight initialization method
        layer='Conv2d',
        a=2.23606797749979,
        distribution='uniform',
        mode='fan_in',
        nonlinearity='leaky_relu'),
    data_preprocessor=dict(
        type='PoseDataPreprocessor',  # data preprocessing
        pad_size_divisor=32,
        mean=[0, 0, 0],  # inputs may already have been normalized earlier
        std=[1, 1, 1],
        batch_augments=[
            dict(
                type='BatchSyncRandomResize',  # random resize
                random_size_range=(480, 800),  # size range
                size_divisor=32,  # sizes are multiples of 32
                interval=1),
        ]),
    backbone=dict(
        type='CSPDarknet',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        out_indices=(2, 3, 4),  # output the feature maps of stages 2, 3 and 4
        spp_kernal_sizes=(5, 9, 13),  # kernel sizes for spatial pyramid pooling
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),  # normalization layer config
        act_cfg=dict(type='Swish'),  # Swish activation
    ),
    neck=dict(
        type='HybridEncoder',  # neck
        in_channels=[256, 512, 1024],  # channels of the three input levels
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        hidden_dim=256,
        output_indices=[1, 2],  # output levels 1 and 2
        encoder_cfg=dict(
            # self-attention: embedding dims, number of heads, dropout rate
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            # FFN: embedding dims, feed-forward channels, dropout rate, GELU activation
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,
                ffn_drop=0.0,
                act_cfg=dict(type='GELU'))),
        # 1x1 convs map the input channels to out_channels, followed by normalization
        projector=dict(
            type='ChannelMapper',
            in_channels=[256, 256],
            kernel_size=1,
            out_channels=512,
            act_cfg=None,
            norm_cfg=dict(type='BN'),
            num_outs=2)),
    head=dict(
        type='RTMOHead',
        num_keypoints=7,
        featmap_strides=(16, 32),  # feature map strides
        head_module_cfg=dict(
            num_classes=1,
            in_channels=256,  # input channels
            cls_feat_channels=256,  # classification feature channels
            channels_per_group=36,  # channels per group (grouped convolution)
            pose_vec_channels=512,  # pose vector channels
            widen_factor=widen_factor,
            stacked_convs=2,  # number of stacked conv layers
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='Swish')),
        assigner=dict(
            # the assigner matches predicted boxes to ground-truth labels
            type='SimOTAAssigner',
            dynamic_k_indicator='oks',
            oks_calculator=dict(type='PoseOKS', metainfo=metafile)),
        prior_generator=dict(
            type='MlvlPointGenerator',
            centralize_points=True,
            strides=[16, 32]),
        dcc_cfg=dict(
            in_channels=512,
            feat_channels=128,
            num_bins=(192, 256),
            spe_channels=128,
            gau_cfg=dict(
                s=128,
                expansion_factor=2,
                dropout_rate=0.0,
                drop_path=0.0,
                act_fn='SiLU',
                pos_enc='add')),
        overlaps_power=0.5,
        loss_cls=dict(
            type='VariFocalLoss',
            reduction='sum',
            use_target_weight=True,
            loss_weight=1.0),
        loss_bbox=dict(
            type='IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_oks=dict(
            type='OKSLoss',
            reduction='none',
            metainfo=metafile,
            loss_weight=30.0),
        loss_vis=dict(
            type='BCELoss',
            use_target_weight=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mle=dict(
            type='MLECCLoss',
            use_target_weight=True,
            loss_weight=1e-2,
        ),
        loss_bbox_aux=dict(type='L1Loss', reduction='sum', loss_weight=1.0),
    ),
    test_cfg=dict(
        input_size=input_size,
        score_thr=0.1,
        nms_thr=0.65,
    ))
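A likely but unconfirmed cause: neither `dataset_coco` nor the validation dataset passes a `metainfo` argument, so `CocoDataset` may fall back to its built-in 17-keypoint COCO metainfo, and `CocoMetric`, which receives the dataset's metainfo, would then reshape 7-keypoint predictions into rows of 17 × 3 values. The mmpose custom-dataset guide attaches a metafile with `metainfo=dict(from_file=...)`; the fragment below sketches that change on both dataset dicts. It assumes the metafile lives at `configs/_base_/datasets/panda.py`, and the empty pipelines are placeholders for the ones defined earlier in the config. `val_dataset` is a hypothetical name for the dict currently inlined in `val_dataloader`.

```python
# Config fragment (mmengine-style Python config). Only the `metainfo` lines are new;
# the placeholder pipelines stand in for the ones defined earlier in the config.
metafile = 'configs/_base_/datasets/panda.py'
data_root = 'data/'
data_mode = 'bottomup'
train_pipeline_stage1 = []  # placeholder: use the real stage-1 pipeline
val_pipeline = []           # placeholder: use the real validation pipeline

dataset_coco = dict(
    type='CocoDataset',
    data_root=data_root,
    data_mode=data_mode,
    ann_file='panda/ann/trainval.json',
    data_prefix=dict(img='panda/images/'),
    metainfo=dict(from_file=metafile),  # load the 7-keypoint panda metainfo
    pipeline=train_pipeline_stage1,
)

val_dataset = dict(
    type='CocoDataset',
    data_root=data_root,
    data_mode=data_mode,
    ann_file='panda/ann/test.json',
    data_prefix=dict(img='panda/images/'),
    metainfo=dict(from_file=metafile),  # same metainfo for validation/evaluation
    test_mode=True,
    pipeline=val_pipeline,
)
```

In the real config, `val_dataloader` would then use `dataset=val_dataset`, and `test_dataloader = val_dataloader` picks up the same metainfo.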
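The `sigmas` list in the metafile matters beyond bookkeeping: COCO-style keypoint evaluation scores each prediction with Object Keypoint Similarity (OKS), weighting every keypoint's error by its sigma, and this config also passes the metafile to the `PoseOKS` assigner and `OKSLoss`. The sketch below is a plain-Python OKS following the pycocotools formula (variance `(2σ)²`, `e = d² / var / (area + ε) / 2`, averaged over labelled ground-truth keypoints). `oks` is a hypothetical helper; it rejects a sigma list whose length differs from the keypoint count and does not handle ground truths with no labelled keypoints.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def oks(pred: Sequence[Point], gt: Sequence[Point], visible: Sequence[int],
        sigmas: Sequence[float], area: float) -> float:
    """Object Keypoint Similarity between one prediction and one ground truth.

    Follows the pycocotools formula; a ground truth without labelled keypoints
    (which pycocotools handles separately) raises ValueError here.
    """
    n = len(sigmas)
    if not (len(pred) == len(gt) == len(visible) == n):
        raise ValueError(f'expected {n} keypoints, got {len(pred)} / {len(gt)}')
    eps = 2.220446049250313e-16  # plays the role of np.spacing(1) in pycocotools
    terms = []
    for (px, py), (gx, gy), v, s in zip(pred, gt, visible, sigmas):
        if v <= 0:
            continue  # only labelled ground-truth keypoints contribute
        var = (2 * s) ** 2
        e = ((px - gx) ** 2 + (py - gy) ** 2) / var / (area + eps) / 2
        terms.append(math.exp(-e))
    if not terms:
        raise ValueError('ground truth has no labelled keypoints')
    return sum(terms) / len(terms)


sigmas = [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079]  # from the panda metafile
gt = [(10.0 * i, 20.0) for i in range(7)]
print(oks(gt, gt, [2] * 7, sigmas, area=100.0 * 100.0))  # 1.0 for a perfect prediction
```

Small sigmas (0.025 to 0.035 here) make the score drop quickly for pixel errors on those joints, while the 0.079 joints are more tolerant.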
