From 7686e2260f1752761b8516a67d0d512554ae9d02 Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 20 Mar 2024 16:14:13 +0800 Subject: [PATCH 1/2] fix loss computation in mspn head --- mmpose/models/heads/heatmap_heads/mspn_head.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mmpose/models/heads/heatmap_heads/mspn_head.py b/mmpose/models/heads/heatmap_heads/mspn_head.py index 8b7cddf798..4d2c0bfcef 100644 --- a/mmpose/models/heads/heatmap_heads/mspn_head.py +++ b/mmpose/models/heads/heatmap_heads/mspn_head.py @@ -394,7 +394,9 @@ def loss(self, keypoint_weights = torch.cat([ d.gt_instance_labels.keypoint_weights for d in batch_data_samples - ]) # shape: [B*N, L, K] + ], + dim=1) + keypoint_weights = keypoint_weights.transpose(0, 1) # [B*N, L, K] # calculate losses over multiple stages and multiple units losses = dict() From f19f2f2445be1e03163ecb4314e034e6d8b32f13 Mon Sep 17 00:00:00 2001 From: lupeng Date: Thu, 21 Mar 2024 14:15:17 +0800 Subject: [PATCH 2/2] fix ut --- .../test_heads/test_heatmap_heads/test_mspn_head.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py b/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py index ce3d19b688..5643ff00ba 100644 --- a/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py +++ b/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py @@ -44,6 +44,7 @@ def _get_data_samples(self, with_heatmap=True, with_reg_label=False, num_levels=num_levels)['data_samples'] + return batch_data_samples def test_init(self): @@ -153,6 +154,10 @@ def test_loss(self): (unit_channels, 32, 24), (unit_channels, 64, 48)]) batch_data_samples = self._get_data_samples( batch_size=2, heatmap_size=(48, 64), num_levels=4) + for ds in batch_data_samples: + ds.gt_instance_labels = InstanceData( + keypoint_weights=ds.gt_instance_labels.keypoint_weights. + transpose(0, 1)) losses = head.loss(feats, batch_data_samples) self.assertIsInstance(losses['loss_kpt'], torch.Tensor) @@ -189,6 +194,10 @@ def test_loss(self): (unit_channels, 32, 24), (unit_channels, 64, 48)]) batch_data_samples = self._get_data_samples( batch_size=2, heatmap_size=(48, 64), num_levels=16) + for ds in batch_data_samples: + ds.gt_instance_labels = InstanceData( + keypoint_weights=ds.gt_instance_labels.keypoint_weights. + transpose(0, 1)) losses = head.loss(feats, batch_data_samples) self.assertIsInstance(losses['loss_kpt'], torch.Tensor)