Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AssertionError: target_weights and target have mismatched shapes torch.Size([128, 6]) v.s. torch.Size([32, 6, 64, 48]) #2991

Open
2 tasks done
wang1528186571 opened this issue Mar 20, 2024 · 4 comments
Assignees

Comments

@wang1528186571
Copy link

Prerequisite

Environment

mmcv 2.1.0
mmdet 3.3.0
mmengine 0.10.3
mmpose 1.3.1 /home/meng/Desktop/wjl-project/mmpose

Reproduces the problem - code sample

_base_ = ['mmpose::_base_/default_runtime.py']

runtime

train_cfg = dict(max_epochs=210, val_interval=10)

optimizer

optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=2e-2,
))

learning policy

param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 190, 200],
gamma=0.1,
by_epoch=True)
]

automatically scaling LR based on the actual training batch size

auto_scale_lr = dict(base_batch_size=256)

hooks

default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

codec settings

multiple kernel_sizes of heatmap gaussian for 'Megvii' approach.

kernel_sizes = [11, 9, 7, 5]
codec = [
dict(
type='MegviiHeatmap',
input_size=(192, 256),
heatmap_size=(48, 64),
kernel_size=kernel_size) for kernel_size in kernel_sizes
]

model settings

model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='RSN',
unit_channels=256,
num_stages=1,
num_units=4,
num_blocks=[2, 2, 2, 2],
num_steps=4,
norm_cfg=dict(type='BN'),
),
head=dict(
type='MSPNHead',
out_shape=(64, 48),
unit_channels=256,
out_channels=6,
num_stages=1,
num_units=4,
norm_cfg=dict(type='BN'),
# each sub list is for a stage
# and each element in each list is for a unit
level_indices=[0, 1, 2, 3],
loss=[
dict(
type='KeypointMSELoss',
use_target_weight=True,
loss_weight=0.25)
] * 3 + [
dict(
type='KeypointOHKMMSELoss',
use_target_weight=True,
loss_weight=1.)
],
decoder=codec[-1]),
test_cfg=dict(
flip_test=True,
flip_mode='heatmap',
shift_heatmap=False,
))

base dataset settings

dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = '/home/meng/Desktop/wjl-project/mmpose/data/Plane_coco/'

pipelines

train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(type='RandomBBoxTransform'),
dict(type='TopdownAffine', input_size=codec[0]['input_size']),
dict(type='GenerateTarget', multilevel=True, encoder=codec),
dict(type='PackPoseInputs')
]

val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec[0]['input_size']),
dict(type='PackPoseInputs')
]

data loaders

train_dataloader = dict(
batch_size=32,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='train_coco.json',
data_prefix=dict(img='images/'),
pipeline=train_pipeline,
metainfo=dict(from_file='configs/base/datasets/coco_Plane.py'),
))
val_dataloader = dict(
batch_size=32,
num_workers=4,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='val_coco.json',
data_prefix=dict(img='images/'),
test_mode=True,
bbox_file=None,
pipeline=val_pipeline,
metainfo=dict(from_file='configs/base/datasets/coco_Plane.py'),
))
test_dataloader = val_dataloader

evaluators

val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'val_coco.json',
nms_mode='none')
test_evaluator = val_evaluator

fp16 settings

fp16 = dict(loss_scale='dynamic')

Reproduces the problem - command or script

python tools/train.py /home/meng/Desktop/wjl-project/mmpose/data/td-hm_rsn18_8xb32-210e_coco-256x192.py

Reproduces the problem - error message

Traceback (most recent call last):
File "tools/train.py", line 162, in <module>
main()
File "tools/train.py", line 158, in main
runner.train()
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 96, in run
self.run_epoch()
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 112, in run_epoch
self.run_iter(idx, data_batch)
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/mmengine/runner/loops.py", line 128, in run_iter
outputs = self.runner.model.train_step(
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
losses = self._run_forward(data, mode='loss') # type: ignore
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 361, in _run_forward
results = self(**data, mode=mode)
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/meng/Desktop/wjl-project/mmpose/mmpose/models/pose_estimators/base.py", line 155, in forward
return self.loss(inputs, data_samples)
File "/home/meng/Desktop/wjl-project/mmpose/mmpose/models/pose_estimators/topdown.py", line 74, in loss
self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
File "/home/meng/Desktop/wjl-project/mmpose/mmpose/models/heads/heatmap_heads/mspn_head.py", line 415, in loss
loss_i = loss_func(msmu_pred_heatmaps[i], gt_heatmaps,
File "/home/meng/anaconda3/envs/mmpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/meng/Desktop/wjl-project/mmpose/mmpose/models/losses/heatmap_loss.py", line 63, in forward
_mask = self._get_mask(target, target_weights, mask)
File "/home/meng/Desktop/wjl-project/mmpose/mmpose/models/losses/heatmap_loss.py", line 93, in _get_mask
assert (target_weights.ndim in (2, 4) and target_weights.shape
AssertionError: target_weights and target have mismatched shapes torch.Size([128, 6]) v.s. torch.Size([32, 6, 64, 48])

Additional information

No response

@wang1528186571
Copy link
Author

please help me!

@Ben-Louis
Copy link
Collaborator

Thank you for bringing this issue to our attention! There seems to be a bug in the MSPNHead, and we will address it promptly. In the meantime, you could try using another model.

@wang1528186571
Copy link
Author

Thank you for bringing this issue to our attention! There seems to be a bug in the MSPNHead, and we will address it promptly. In the meantime, you could try using another model.

thank you! if you address please tell me ! thank you!

@Ben-Louis
Copy link
Collaborator

If you wish to use RSN, you can modify the code manually by following #2993.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants