From b9518bc0ff71f515a26145b346735ea8da263f08 Mon Sep 17 00:00:00 2001 From: DC Date: Wed, 2 Mar 2022 15:42:42 +0900 Subject: [PATCH] temporary fixes for issue #40 --- .../rotate_single_level_roi_extractor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mmrotate/models/roi_heads/roi_extractors/rotate_single_level_roi_extractor.py b/mmrotate/models/roi_heads/roi_extractors/rotate_single_level_roi_extractor.py index ba6fdf47c..d0b1604d2 100644 --- a/mmrotate/models/roi_heads/roi_extractors/rotate_single_level_roi_extractor.py +++ b/mmrotate/models/roi_heads/roi_extractors/rotate_single_level_roi_extractor.py @@ -98,19 +98,21 @@ def forward(self, feats, rois, roi_scale_factor=None): Returns: torch.Tensor: Scaled RoI features. """ - out_size = self.roi_layers[0].out_size + if isinstance(self.roi_layers[0], ops.RiRoIAlignRotated): + out_size = nn.modules.utils._pair(self.roi_layers[0].out_size) + else: + out_size = self.roi_layers[0].output_size num_levels = len(feats) - expand_dims = (-1, self.out_channels * out_size * out_size) + expand_dims = (-1, self.out_channels * out_size[0] * out_size[1]) if torch.onnx.is_in_onnx_export(): # Work around to export mask-rcnn to onnx roi_feats = rois[:, :1].clone().detach() roi_feats = roi_feats.expand(*expand_dims) - roi_feats = roi_feats.reshape(-1, self.out_channels, out_size, - out_size) + roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size) roi_feats = roi_feats * 0 else: roi_feats = feats[0].new_zeros( - rois.size(0), self.out_channels, out_size, out_size) + rois.size(0), self.out_channels, *out_size) # TODO: remove this when parrots supports if torch.__version__ == 'parrots': roi_feats.requires_grad = True