diff --git a/mmdet3d/ops/norm.py b/mmdet3d/ops/norm.py index 52a1363d2..98ec7f117 100644 --- a/mmdet3d/ops/norm.py +++ b/mmdet3d/ops/norm.py @@ -53,11 +53,27 @@ def __init__(self, *args, **kwargs): # TODO: make mmcv fp16 utils handle customized norm layers @force_fp32(out_fp16=True) def forward(self, input): + """ + Args: + input (tensor): Has shape (N, C) or (N, C, L), where N is + the batch size, C is the number of features or + channels, and L is the sequence length. + + Returns: + tensor: Has shape (N, C) or (N, C, L), the same shape + as input. + """ assert input.dtype == torch.float32, \ f'input should be in float32 type, got {input.dtype}' - if dist.get_world_size() == 1 or not self.training: + using_dist = dist.is_available() and dist.is_initialized() + if (not using_dist) or dist.get_world_size() == 1 \ + or not self.training: return super().forward(input) assert input.shape[0] > 0, 'SyncBN does not support empty inputs' + is_two_dim = input.dim() == 2 + if is_two_dim: + input = input.unsqueeze(2) + C = input.shape[1] mean = torch.mean(input, dim=[0, 2]) meansqr = torch.mean(input * input, dim=[0, 2]) @@ -76,7 +92,10 @@ def forward(self, input): bias = self.bias - mean * scale scale = scale.reshape(1, -1, 1) bias = bias.reshape(1, -1, 1) - return input * scale + bias + output = input * scale + bias + if is_two_dim: + output = output.squeeze(2) + return output @NORM_LAYERS.register_module('naiveSyncBN2d') @@ -107,9 +126,19 @@ def __init__(self, *args, **kwargs): # TODO: make mmcv fp16 utils handle customized norm layers @force_fp32(out_fp16=True) def forward(self, input): + """ + Args: + input (tensor): Feature has shape (N, C, H, W). + + Returns: + tensor: Has shape (N, C, H, W), same shape as input.
+ """ assert input.dtype == torch.float32, \ f'input should be in float32 type, got {input.dtype}' - if dist.get_world_size() == 1 or not self.training: + using_dist = dist.is_available() and dist.is_initialized() + if (not using_dist) or \ + dist.get_world_size() == 1 or \ + not self.training: return super().forward(input) assert input.shape[0] > 0, 'SyncBN does not support empty inputs'