diff --git a/mmdet3d/models/dense_heads/centerpoint_head.py b/mmdet3d/models/dense_heads/centerpoint_head.py index b924d48dd..ad58e5876 100644 --- a/mmdet3d/models/dense_heads/centerpoint_head.py +++ b/mmdet3d/models/dense_heads/centerpoint_head.py @@ -471,7 +471,7 @@ def get_targets_single(self, (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), dim=1).to(device) max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg'] - grid_size = torch.tensor(self.train_cfg['grid_size']) + grid_size = torch.tensor(self.train_cfg['grid_size']).to(device) pc_range = torch.tensor(self.train_cfg['point_cloud_range']) voxel_size = torch.tensor(self.train_cfg['voxel_size'])