diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 1d730d863b..3776857229 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
             and its variants only. Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output format of the feature map.
+            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         with_cls_token (bool): If concatenating class token into image tokens
@@ -261,6 +263,7 @@ def __init__(self,
                  act_cfg=dict(type='GELU'),
                  norm_eval=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  with_cls_token=True,
                  interpolate_mode='bicubic',
                  with_cp=False):
@@ -303,6 +306,11 @@ def __init__(self,
                 with_cp=with_cp) for i in range(depth)
         ])
 
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        self.out_shape = out_shape
+
         self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
@@ -443,10 +451,11 @@ def forward(self, inputs):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
diff --git a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py
index c36894ec92..1ec42d34ea 100644
--- a/tests/test_models/test_backbones/test_vit.py
+++ b/tests/test_models/test_backbones/test_vit.py
@@ -30,6 +30,10 @@ def test_vit_backbone():
         model = VisionTransformer()
         model(x)
 
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW'
+        VisionTransformer(out_shape='NCL')
+
     # Test img_size isinstance int
     imgs = torch.randn(1, 3, 224, 224)
     model = VisionTransformer(img_size=224)
@@ -72,3 +76,9 @@ def test_vit_backbone():
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test out_shape argument
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(out_shape='NLC')
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 196, 768)
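
A minimal usage sketch of the new out_shape argument, assuming VisionTransformer is importable from mmseg.models.backbones.vit as in the file above and using the default ViT-Base configuration (patch size 16, embedding dim 768); the expected shapes mirror the assertions in the updated test:

import torch
from mmseg.models.backbones.vit import VisionTransformer

# Default behaviour: patch tokens are reshaped into a spatial NCHW feature map.
model = VisionTransformer(img_size=224)
feat = model(torch.randn(1, 3, 224, 224))
assert feat[-1].shape == (1, 768, 14, 14)  # 224 / 16 = 14 patches per side

# New option: keep the flat token sequence in NLC layout (no reshape).
model = VisionTransformer(img_size=224, out_shape='NLC')
feat = model(torch.randn(1, 3, 224, 224))
assert feat[-1].shape == (1, 196, 768)  # 14 * 14 = 196 patch tokens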