Skip to content

Commit

Permalink
[Fix]: temporarily fix out_size issue in RoIAlignRotated (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
DangChuong-DC authored Mar 2, 2022
1 parent 583ab82 commit eeb2720
Showing 1 changed file with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,21 @@ def forward(self, feats, rois, roi_scale_factor=None):
Returns:
torch.Tensor: Scaled RoI features.
"""
out_size = self.roi_layers[0].out_size
if isinstance(self.roi_layers[0], ops.RiRoIAlignRotated):
out_size = nn.modules.utils._pair(self.roi_layers[0].out_size)
else:
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
expand_dims = (-1, self.out_channels * out_size * out_size)
expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
if torch.onnx.is_in_onnx_export():
# Work around to export mask-rcnn to onnx
roi_feats = rois[:, :1].clone().detach()
roi_feats = roi_feats.expand(*expand_dims)
roi_feats = roi_feats.reshape(-1, self.out_channels, out_size,
out_size)
roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
roi_feats = roi_feats * 0
else:
roi_feats = feats[0].new_zeros(
rois.size(0), self.out_channels, out_size, out_size)
rois.size(0), self.out_channels, *out_size)
# TODO: remove this when parrots supports
if torch.__version__ == 'parrots':
roi_feats.requires_grad = True
Expand Down

0 comments on commit eeb2720

Please # to comment.