import argparse
import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
def parse_args():
parser = argparse.ArgumentParser(description='Train a pose model')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume',
nargs='?',
type=str,
const='auto',
help='If a checkpoint path is specified, resume from it; if not '
'specified, try to auto-resume from the latest checkpoint '
'in the work directory.')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
parser.add_argument(
'--auto-scale-lr',
action='store_true',
help='whether to auto scale the learning rate according to the '
'actual batch size and the original batch size.')
parser.add_argument(
'--show-dir',
help='directory where the visualization images will be saved.')
parser.add_argument(
'--show',
action='store_true',
help='whether to display the prediction results in a window.')
parser.add_argument(
'--interval',
type=int,
default=1,
help='visualize per interval samples.')
parser.add_argument(
'--wait-time',
type=float,
default=1,
help='display time of every window. (second)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
# When using PyTorch version >= 2.0.0, the torch.distributed.launch
# will pass the --local-rank parameter to tools/train.py instead
# of --local_rank.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def merge_args(cfg, args):
"""Merge CLI arguments to config."""
if args.no_validate:
cfg.val_cfg = None
cfg.val_dataloader = None
cfg.val_evaluator = None
cfg.launcher = args.launcher
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# enable automatic-mixed-precision training
if args.amp is True:
from mmengine.optim import AmpOptimWrapper, OptimWrapper
optim_wrapper = cfg.optim_wrapper.get('type', OptimWrapper)
assert optim_wrapper in (OptimWrapper, AmpOptimWrapper,
'OptimWrapper', 'AmpOptimWrapper'), \
'`--amp` is not supported for custom optimizer wrapper type ' \
f'`{optim_wrapper}`.'
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
# resume training
if args.resume == 'auto':
cfg.resume = True
cfg.load_from = None
elif args.resume is not None:
cfg.resume = True
cfg.load_from = args.resume
# enable auto scale learning rate
if args.auto_scale_lr:
cfg.auto_scale_lr.enable = True
# visualization
if args.show or (args.show_dir is not None):
assert 'visualization' in cfg.default_hooks, \
'PoseVisualizationHook is not set in the ' \
'`default_hooks` field of config. Please set ' \
'`visualization=dict(type="PoseVisualizationHook")`'
cfg.default_hooks.visualization.enable = True
cfg.default_hooks.visualization.show = args.show
if args.show:
cfg.default_hooks.visualization.wait_time = args.wait_time
cfg.default_hooks.visualization.out_dir = args.show_dir
cfg.default_hooks.visualization.interval = args.interval
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
return cfg
def main():
args = parse_args()
# load config
cfg = Config.fromfile(args.config)
# merge CLI arguments to config
cfg = merge_args(cfg, args)
# set preprocess configs to model
# If cfg.model does not define data_preprocessor, use the preprocess_cfg settings as its value.
# If preprocess_cfg exists in the config, it is passed to the model's data_preprocessor.
if 'preprocess_cfg' in cfg:
cfg.model.setdefault('data_preprocessor',
cfg.get('preprocess_cfg', {}))
# build the runner from config
runner = Runner.from_cfg(cfg)
# start training
runner.train()
if __name__ == '__main__':
    main()
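For context, here is a minimal sketch (my own, not part of tools/train.py) of what merge_args does to the config when --amp is passed. The dummy_cfg and dummy_args names are hypothetical, and the snippet assumes merge_args from the script above is in scope:
from argparse import Namespace
from mmengine.config import Config

# hypothetical minimal config with only an optimizer wrapper defined
dummy_cfg = Config(dict(optim_wrapper=dict(type='OptimWrapper',
                                           optimizer=dict(type='AdamW', lr=4e-4))))
# hypothetical CLI arguments equivalent to: python tools/train.py demo.py --amp
dummy_args = Namespace(
    config='demo.py', work_dir=None, resume=None, amp=True, no_validate=False,
    auto_scale_lr=False, show_dir=None, show=False, interval=1, wait_time=1,
    cfg_options=None, launcher='none', local_rank=0)
dummy_cfg = merge_args(dummy_cfg, dummy_args)
print(dummy_cfg.optim_wrapper.type)        # 'AmpOptimWrapper'
print(dummy_cfg.optim_wrapper.loss_scale)  # 'dynamic'
print(dummy_cfg.work_dir)                  # './work_dirs/demo'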
Prerequisite
Environment
Package Version Editable project location
actionlib 1.14.0
addict 2.4.0
aenum 3.1.15
aliyun-python-sdk-core 2.15.1
aliyun-python-sdk-kms 2.16.3
angles 1.9.13
attrs 24.1.0
bondpy 1.8.6
camera-calibration 1.17.0
camera-calibration-parsers 1.12.0
catkin 0.8.10
certifi 2024.7.4
cffi 1.16.0
charset-normalizer 3.3.2
chumpy 0.70
click 8.1.7
colorama 0.4.6
coloredlogs 15.0.1
contourpy 1.1.1
controller-manager 0.20.0
controller-manager-msgs 0.20.0
coverage 7.6.1
crcmod 1.7
crowdposetools 2.0
cryptography 43.0.0
cv-bridge 1.16.2
cycler 0.12.1
Cython 3.0.11
diagnostic-analysis 1.11.0
diagnostic-common-diagnostics 1.11.0
diagnostic-updater 1.11.0
dill 0.3.8
dynamic-reconfigure 1.7.3
exceptiongroup 1.2.2
filelock 3.14.0
flake8 7.1.1
flatbuffers 24.3.25
fonttools 4.53.1
gazebo_plugins 2.9.2
gazebo_ros 2.9.2
gencpp 0.7.0
geneus 3.0.0
genlisp 0.4.18
genmsg 0.6.0
gennodejs 2.0.2
genpy 0.6.15
grpcio 1.65.4
humanfriendly 10.0
idna 3.7
image-geometry 1.16.2
importlib_metadata 8.2.0
importlib_resources 6.4.0
iniconfig 2.0.0
interactive-markers 1.12.0
interrogate 1.7.0
isort 4.3.21
jmespath 0.10.0
joint-state-publisher 1.15.1
joint-state-publisher-gui 1.15.1
json-tricks 3.17.3
kiwisolver 1.4.5
laser_geometry 1.6.7
Markdown 3.6
markdown-it-py 3.0.0
matplotlib 3.7.5
mccabe 0.7.0
mdurl 0.1.2
message-filters 1.16.0
mmcv 2.1.0
mmdeploy 1.3.1 /home/huochewang/桌面/mmdeploy
mmdeploy-runtime-gpu 1.3.1
mmdet 3.3.0
mmengine 0.10.4
mmpose 1.3.2 /home/huochewang/桌面/mmpose
model-index 0.1.11
mpmath 1.3.0
multiprocess 0.70.16
munkres 1.1.4
numpy 1.24.4
onnx 1.16.2
onnx-simplifier 0.4.36
onnxruntime-gpu 1.14.0
opencv-python 4.10.0.84
opendatalab 0.0.10
openmim 0.3.9
openxlab 0.1.1
ordered-set 4.1.0
oss2 2.17.0
packaging 24.1
pandas 2.0.3
parameterized 0.9.0
pillow 10.4.0
pip 24.0
platformdirs 4.2.2
pluggy 1.5.0
prettytable 3.10.2
protobuf 3.20.2
py 1.11.0
pycocotools 2.0.7
pycodestyle 2.12.1
pycparser 2.22
pycryptodome 3.20.0
pyflakes 3.2.0
Pygments 2.18.0
pyparsing 3.1.2
pytest 8.3.2
pytest-runner 6.0.1
python-dateutil 2.9.0.post0
python-qt-binding 0.4.4
pytz 2023.4
PyYAML 6.0.1
qt-dotgraph 0.4.2
qt-gui 0.4.2
qt-gui-cpp 0.4.2
qt-gui-py-common 0.4.2
requests 2.28.2
resource_retriever 1.12.7
rich 13.4.2
rosbag 1.16.0
rosboost-cfg 1.15.8
rosclean 1.15.8
roscreate 1.15.8
rosgraph 1.16.0
roslaunch 1.16.0
roslib 1.15.8
roslint 0.12.0
roslz4 1.16.0
rosmake 1.15.8
rosmaster 1.16.0
rosmsg 1.16.0
rosnode 1.16.0
rosparam 1.16.0
rospy 1.16.0
rosservice 1.16.0
rostest 1.16.0
rostopic 1.16.0
rosunit 1.15.8
roswtf 1.16.0
rqt_action 0.4.9
rqt_bag 0.5.1
rqt_bag_plugins 0.5.1
rqt-console 0.4.12
rqt_dep 0.4.12
rqt_graph 0.4.14
rqt_gui 0.5.3
rqt_gui_py 0.5.3
rqt-image-view 0.4.17
rqt_launch 0.4.9
rqt-logger-level 0.4.12
rqt-moveit 0.5.11
rqt_msg 0.4.10
rqt_nav_view 0.5.7
rqt_plot 0.4.13
rqt_pose_view 0.5.11
rqt_publisher 0.4.10
rqt_py_common 0.5.3
rqt_py_console 0.4.10
rqt-reconfigure 0.5.5
rqt-robot-dashboard 0.5.8
rqt-robot-monitor 0.5.15
rqt_robot_steering 0.5.12
rqt-runtime-monitor 0.5.10
rqt-rviz 0.7.0
rqt_service_caller 0.4.10
rqt_shell 0.4.11
rqt_srv 0.4.9
rqt-tf-tree 0.6.4
rqt_top 0.4.10
rqt_topic 0.4.13
rqt_web 0.4.10
rviz 1.14.25
scipy 1.10.1
sensor-msgs 1.13.1
setuptools 60.2.0
shapely 2.0.5
six 1.16.0
smach 2.5.2
smach-ros 2.5.2
smclib 1.8.6
sympy 1.13.1
tabulate 0.9.0
termcolor 2.4.0
terminaltables 3.1.10
tf 1.13.2
tf-conversions 1.13.2
tf2-geometry-msgs 0.7.7
tf2-kdl 0.7.7
tf2-py 0.7.7
tf2-ros 0.7.7
tomli 2.0.1
topic-tools 1.16.0
torch 1.12.1+cu116
torchaudio 0.12.1+cu116
torchvision 0.13.1+cu116
tqdm 4.65.2
typing_extensions 4.12.2
tzdata 2024.1
urllib3 1.26.19
wcwidth 0.2.13
wheel 0.43.0
xacro 1.14.18
xdoctest 1.1.6
xtcocotools 1.14.3
yapf 0.40.2
zipp 3.19.2
Reproduces the problem - code sample
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
def parse_args():
parser = argparse.ArgumentParser(description='Train a pose model')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume',
nargs='?',
type=str,
const='auto',
help='If a checkpoint path is specified, resume from it; if not '
'specified, try to auto-resume from the latest checkpoint '
'in the work directory.')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
parser.add_argument(
'--auto-scale-lr',
action='store_true',
help='whether to auto scale the learning rate according to the '
'actual batch size and the original batch size.')
parser.add_argument(
'--show-dir',
help='directory where the visualization images will be saved.')
parser.add_argument(
'--show',
action='store_true',
help='whether to display the prediction results in a window.')
parser.add_argument(
'--interval',
type=int,
default=1,
help='visualize per interval samples.')
parser.add_argument(
'--wait-time',
type=float,
default=1,
help='display time of every window. (second)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
# When using PyTorch version >= 2.0.0, the torch.distributed.launch
# will pass the --local-rank parameter to tools/train.py instead
# of --local_rank.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def merge_args(cfg, args):
"""Merge CLI arguments to config."""
if args.no_validate:
cfg.val_cfg = None
cfg.val_dataloader = None
cfg.val_evaluator = None
def main():
args = parse_args()
if __name__ == '__main__':
    main()
Reproduces the problem - command or script
python tools/train.py configs/body_2d_keypoint/rtmo/coco/rtmo-l_16xb16-600e_coco-640x640arm.py --amp
Reproduces the problem - error message
11/13 20:30:16 - mmengine - INFO - Epoch(train) [5][1800/1812] base_lr: 3.951751e-04 lr: 3.951751e-04 eta: 1:38:27 time: 0.229064 data_time: 0.001469 memory: 4821 grad_norm: 107.484255 loss: 7.052426 loss_bbox: 0.826151 loss_vis: 0.000011 loss_mle: -0.047193 loss_oks: 5.877753 loss_cls: 0.395704 num_samples: 17.000000 overlaps: 0.750766
11/13 20:30:19 - mmengine - INFO - Exp name: rtmo-l_16xb16-600e_coco-640x640arm_20241113_195726
11/13 20:30:20 - mmengine - INFO - Epoch(val) [5][ 50/959] eta: 0:00:17 time: 0.018845 data_time: 0.000969 memory: 4557
11/13 20:30:21 - mmengine - INFO - Epoch(val) [5][100/959] eta: 0:00:15 time: 0.017900 data_time: 0.000073 memory: 932
11/13 20:30:22 - mmengine - INFO - Epoch(val) [5][150/959] eta: 0:00:14 time: 0.018003 data_time: 0.000085 memory: 932
11/13 20:30:22 - mmengine - INFO - Epoch(val) [5][200/959] eta: 0:00:13 time: 0.017905 data_time: 0.000072 memory: 932
11/13 20:30:23 - mmengine - INFO - Epoch(val) [5][250/959] eta: 0:00:12 time: 0.017939 data_time: 0.000071 memory: 932
11/13 20:30:24 - mmengine - INFO - Epoch(val) [5][300/959] eta: 0:00:11 time: 0.017909 data_time: 0.000073 memory: 932
11/13 20:30:25 - mmengine - INFO - Epoch(val) [5][350/959] eta: 0:00:10 time: 0.017907 data_time: 0.000073 memory: 932
11/13 20:30:26 - mmengine - INFO - Epoch(val) [5][400/959] eta: 0:00:10 time: 0.017903 data_time: 0.000073 memory: 932
11/13 20:30:27 - mmengine - INFO - Epoch(val) [5][450/959] eta: 0:00:09 time: 0.017901 data_time: 0.000074 memory: 932
11/13 20:30:28 - mmengine - INFO - Epoch(val) [5][500/959] eta: 0:00:08 time: 0.017903 data_time: 0.000073 memory: 932
11/13 20:30:29 - mmengine - INFO - Epoch(val) [5][550/959] eta: 0:00:07 time: 0.017891 data_time: 0.000069 memory: 932
11/13 20:30:30 - mmengine - INFO - Epoch(val) [5][600/959] eta: 0:00:06 time: 0.017929 data_time: 0.000070 memory: 932
11/13 20:30:31 - mmengine - INFO - Epoch(val) [5][650/959] eta: 0:00:05 time: 0.017915 data_time: 0.000072 memory: 932
11/13 20:30:31 - mmengine - INFO - Epoch(val) [5][700/959] eta: 0:00:04 time: 0.017923 data_time: 0.000073 memory: 932
11/13 20:30:32 - mmengine - INFO - Epoch(val) [5][750/959] eta: 0:00:03 time: 0.017919 data_time: 0.000073 memory: 932
11/13 20:30:33 - mmengine - INFO - Epoch(val) [5][800/959] eta: 0:00:02 time: 0.017943 data_time: 0.000071 memory: 932
11/13 20:30:34 - mmengine - INFO - Epoch(val) [5][850/959] eta: 0:00:01 time: 0.017994 data_time: 0.000077 memory: 932
11/13 20:30:35 - mmengine - INFO - Epoch(val) [5][900/959] eta: 0:00:01 time: 0.017964 data_time: 0.000083 memory: 932
11/13 20:30:36 - mmengine - INFO - Epoch(val) [5][950/959] eta: 0:00:00 time: 0.017942 data_time: 0.000088 memory: 932
Traceback (most recent call last):
File "tools/train.py", line 164, in
main()
File "tools/train.py", line 160, in main
runner.train()
File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 103, in run
self.runner.val_loop.run()
File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 376, in run
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/evaluator/evaluator.py", line 79, in evaluate
_results = metric.evaluate(size)
File "/home/huochewang/.conda/envs/mmpose/lib/python3.8/site-packages/mmengine/evaluator/metric.py", line 133, in evaluate
_metrics = self.compute_metrics(results) # type: ignore
File "/home/huochewang/桌面/mmpose/mmpose/evaluation/metrics/coco_metric.py", line 488, in compute_metrics
self.results2json(valid_kpts, outfile_prefix=outfile_prefix)
File "/home/huochewang/桌面/mmpose/mmpose/evaluation/metrics/coco_metric.py", line 529, in results2json
_keypoints = _keypoints.reshape(-1, num_keypoints * 3)
ValueError: cannot reshape array of size 21 into shape (51)
Additional information
Hello, I want to use RTMO to detect the pose of a robotic arm. I defined 7 keypoints, but I got the error above.
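The failing reshape seems to come down to the keypoint count: each predicted instance carries 7 × 3 = 21 values (x, y, score), while coco_metric.py reshapes with num_keypoints * 3 = 51, i.e. the default 17-keypoint COCO skeleton. A standalone check (my own snippet, only assuming NumPy) that reproduces the same ValueError:
import numpy as np

# one instance with the 7 custom keypoints, flattened as (x, y, score) triples
kpts = np.zeros(7 * 3)
print(kpts.reshape(-1, 7 * 3).shape)  # (1, 21) -- works when num_keypoints = 7
try:
    kpts.reshape(-1, 17 * 3)          # what happens when num_keypoints is still 17
except ValueError as err:
    print(err)                        # cannot reshape array of size 21 into shape (51)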
The config file for the dataset is shown below:
dataset_info = dict(
dataset_name='panda',
paper_info=dict(
author='DREAM',
title='for 702 test',
container='NO',
year='2024',
homepage='no',
),
keypoint_info={
0:
dict(name='panda_link0', id=0, color=[51, 153, 255], type='', swap=''),
1:
dict(
name='panda_link2',
id=1,
color=[51, 153, 255],
type='',
swap=''),
2:
dict(
name='panda_link3',
id=2,
color=[51, 153, 255],
type='',
swap=''),
3:
dict(
name='panda_link4',
id=3,
color=[51, 153, 255],
type='',
swap=''),
4:
dict(
name='panda_link6',
id=4,
color=[51, 153, 255],
type='',
swap=''),
5:
dict(
name='panda_link7',
id=5,
color=[0, 255, 0],
type='',
swap=''),
6:
dict(
name='panda_hand',
id=6,
color=[255, 128, 0],
type='',
swap=''),
},
skeleton_info={
0:
dict(link=('panda_link0', 'panda_link2'), id=0, color=[0, 255, 0]),
1:
dict(link=('panda_link2', 'panda_link3'), id=1, color=[0, 255, 0]),
2:
dict(link=('panda_link3', 'panda_link4'), id=2, color=[255, 128, 0]),
3:
dict(link=('panda_link4', 'panda_link6'), id=3, color=[255, 128, 0]),
4:
dict(link=('panda_link6', 'panda_link7'), id=4, color=[51, 153, 255]),
5:
dict(link=('panda_link7', 'panda_hand'), id=5, color=[51, 153, 255])
},
joint_weights=[
1., 1., 1., 1., 1., 1., 1.
],
sigmas=[
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079
])
The config file of the model is shown below:
_base_ = ['../../../_base_/default_runtime.py']
# runtime
train_cfg = dict(max_epochs=20, val_interval=5, dynamic_intervals=[(18, 1)])
auto_scale_lr = dict(base_batch_size=3)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
optim_wrapper = dict(
type='OptimWrapper',
constructor='ForceDefaultOptimWrapperConstructor',
optimizer=dict(type='AdamW', lr=0.0004, weight_decay=0.05),
paramwise_cfg=dict(
norm_decay_mult=0,
bias_decay_mult=0,
bypass_duplicate=True,
force_default_settings=True,
custom_keys=dict({'neck.encoder': dict(lr_mult=0.05)})),
clip_grad=dict(max_norm=0.1, norm_type=2))
param_scheduler = [
dict(
type='QuadraticWarmupLR',
by_epoch=True,
begin=0,
end=4,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
eta_min=0.0002,
begin=4,
T_max=10,
end=10,
by_epoch=True,
convert_to_iter_based=True),
# this scheduler is used to increase the lr from 2e-4 to 5e-4
dict(type='ConstantLR', by_epoch=True, factor=2.5, begin=10, end=11),
dict(
type='CosineAnnealingLR',
eta_min=0.0002,
begin=11,
T_max=7,
end=17,
by_epoch=True,
convert_to_iter_based=True),
dict(type='ConstantLR', by_epoch=True, factor=1, begin=17, end=20),
]
# data
input_size = (640, 480)
metafile = 'configs/_base_/datasets/panda.py'
codec = dict(type='YOLOXPoseAnnotationProcessor', input_size=input_size)
train_pipeline_stage1 = [
dict(type='LoadImage', backend_args=None),
dict(type='FilterAnnotations', by_kpt=True, by_box=True, keep_empty=False),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs'),
]
train_pipeline_stage2 = [
dict(type='LoadImage'),
dict(type='BottomupGetHeatmapMask', get_invalid=True),
dict(type='FilterAnnotations', by_kpt=True, by_box=True, keep_empty=False),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs'),
]
data_mode = 'bottomup'
data_root = 'data/'
# train datasets
dataset_coco = dict(
type='CocoDataset',
data_root=data_root,
data_mode=data_mode,
ann_file='panda/ann/trainval.json',
data_prefix=dict(img='panda/images/'),
pipeline=train_pipeline_stage1,
)
train_dataloader = dict(
batch_size=3,
num_workers=6,
persistent_workers=True,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dataset_coco)
val_pipeline = [
dict(type='LoadImage'),
dict(
type='BottomupResize', input_size=input_size, pad_val=(114, 114, 114)),
dict(
type='PackPoseInputs',
meta_keys=('id', 'img_id', 'img_path', 'ori_shape', 'img_shape',
'input_size', 'input_center', 'input_scale'))
]
val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type='CocoDataset',
data_root=data_root,
data_mode=data_mode,
ann_file='panda/ann/test.json',
data_prefix=dict(img='panda/images/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader
# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'panda/ann/test.json',
score_mode='bbox',
nms_mode='none',
)
test_evaluator = val_evaluator
# hooks
custom_hooks = [
dict(
type='YOLOXPoseModeSwitchHook',
num_last_epochs=5,
new_train_pipeline=train_pipeline_stage2,
priority=48),
dict(
type='RTMOModeSwitchHook',
epoch_attributes={
7: {
'proxy_target_cc': True,
'overlaps_power': 1.0,
'loss_cls.loss_weight': 2.0,
'loss_mle.loss_weight': 5.0,
'loss_oks.loss_weight': 10.0
},
},
priority=48),
dict(type='SyncNormHook', priority=48),
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
strict_load=False,
priority=49),
]
# model
widen_factor = 1.0
deepen_factor = 1.0
model = dict(
type='BottomupPoseEstimator',
init_cfg=dict(
type='Kaiming',  # weight initialization method
layer='Conv2d',
a=2.23606797749979,
distribution='uniform',
mode='fan_in',
nonlinearity='leaky_relu'),
data_preprocessor=dict(
type='PoseDataPreprocessor',  # data preprocessing
pad_size_divisor=32,
mean=[0, 0, 0],  # probably already normalized upstream
std=[1, 1, 1],
batch_augments=[
dict(
type='BatchSyncRandomResize',  # random resize
random_size_range=(480, 800),  # size range
size_divisor=32,  # sizes are multiples of 32
interval=1),
]),
backbone=dict(
type='CSPDarknet',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
out_indices=(2, 3, 4),  # output the feature maps of stages 2, 3 and 4
spp_kernal_sizes=(5, 9, 13),  # kernel sizes of the spatial pyramid pooling
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),  # normalization layer config
act_cfg=dict(type='Swish'),  # Swish activation
),
neck=dict(
type='HybridEncoder',  # neck
in_channels=[256, 512, 1024],  # channel numbers of the three input levels
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_dim=256,
output_indices=[1, 2],  # output levels 1 and 2
encoder_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),  # self-attention: embedding dims, number of heads, dropout rate
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=1024,
ffn_drop=0.0,
act_cfg=dict(type='GELU'))),  # embedding dims, feed-forward channels, dropout rate and GELU activation
projector=dict(
type='ChannelMapper',
in_channels=[256, 256],
kernel_size=1,
out_channels=512,
act_cfg=None,
norm_cfg=dict(type='BN'),
num_outs=2)),  # 1x1 convolutions map the input channels to the specified output channels, followed by normalization
head=dict(
type='RTMOHead',
num_keypoints=7,
featmap_strides=(16, 32),  # feature map strides
head_module_cfg=dict(
num_classes=1,
in_channels=256,  # input channels
cls_feat_channels=256,  # classification feature channels
channels_per_group=36,  # channels per group (grouped convolution)
pose_vec_channels=512,  # pose vector channels
widen_factor=widen_factor,
stacked_convs=2,  # number of stacked conv layers
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='Swish')),
assigner=dict(
# the assigner matches the predicted boxes with the ground-truth labels
type='SimOTAAssigner',
dynamic_k_indicator='oks',
oks_calculator=dict(type='PoseOKS', metainfo=metafile)),
prior_generator=dict(
type='MlvlPointGenerator',
centralize_points=True,
strides=[16, 32]),
dcc_cfg=dict(
in_channels=512,
feat_channels=128,
num_bins=(192, 256),
spe_channels=128,
gau_cfg=dict(
s=128,
expansion_factor=2,
dropout_rate=0.0,
drop_path=0.0,
act_fn='SiLU',
pos_enc='add')),
overlaps_power=0.5,
loss_cls=dict(
type='VariFocalLoss',
reduction='sum',
use_target_weight=True,
loss_weight=1.0),
loss_bbox=dict(
type='IoULoss',
mode='square',
eps=1e-16,
reduction='sum',
loss_weight=5.0),
loss_oks=dict(
type='OKSLoss',
reduction='none',
metainfo=metafile,
loss_weight=30.0),
loss_vis=dict(
type='BCELoss',
use_target_weight=True,
reduction='mean',
loss_weight=1.0),
loss_mle=dict(
type='MLECCLoss',
use_target_weight=True,
loss_weight=1e-2,
),
loss_bbox_aux=dict(type='L1Loss', reduction='sum', loss_weight=1.0),
),
test_cfg=dict(
input_size=input_size,
score_thr=0.1,
nms_thr=0.65,
))
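For reference, my current guess (unverified) is that the CocoDataset entries never load the custom metainfo, so the evaluator still assumes the 17-keypoint COCO skeleton. Below is a sketch of the change I am considering for the train dataset (and likewise for the val/test dataset). The metainfo=dict(from_file=...) pattern is taken from the MMPose custom-dataset documentation, so please treat it as an assumption rather than a confirmed fix:
# assumed fix sketch: load the 7-keypoint metainfo into every dataset config
metafile = 'configs/_base_/datasets/panda.py'
dataset_coco = dict(
    type='CocoDataset',
    data_root=data_root,
    data_mode=data_mode,
    ann_file='panda/ann/trainval.json',
    data_prefix=dict(img='panda/images/'),
    metainfo=dict(from_file=metafile),  # added line (assumption, not verified)
    pipeline=train_pipeline_stage1,
)
# the same metainfo line would also go into the val/test dataset dict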