Fix NaiveSyncBatchNorm1d and NaiveSyncBatchNorm2d (#1435)
* add quick install command

* fix SyncBatchNorm

* fix SyncBatchNorm
jshilong authored Apr 28, 2022
1 parent d842546 commit 16e1715
Showing 1 changed file with 32 additions and 3 deletions.

mmdet3d/ops/norm.py
@@ -53,11 +53,27 @@ def __init__(self, *args, **kwargs):
     # TODO: make mmcv fp16 utils handle customized norm layers
     @force_fp32(out_fp16=True)
     def forward(self, input):
+        """
+        Args:
+            input (tensor): Has shape (N, C) or (N, C, L), where N is
+                the batch size, C is the number of features or
+                channels, and L is the sequence length.
+
+        Returns:
+            tensor: Has shape (N, C) or (N, C, L), the same shape as
+                the input.
+        """
         assert input.dtype == torch.float32, \
             f'input should be in float32 type, got {input.dtype}'
-        if dist.get_world_size() == 1 or not self.training:
+        using_dist = dist.is_available() and dist.is_initialized()
+        if (not using_dist) or dist.get_world_size() == 1 \
+                or not self.training:
             return super().forward(input)
         assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
+        is_two_dim = input.dim() == 2
+        if is_two_dim:
+            input = input.unsqueeze(2)
+
         C = input.shape[1]
         mean = torch.mean(input, dim=[0, 2])
         meansqr = torch.mean(input * input, dim=[0, 2])
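The guard matters because torch.distributed.get_world_size() raises an error when no default process group has been initialized, so the old condition crashed ordinary single-process runs. A minimal sketch of the pattern this hunk introduces (the helper name is hypothetical, not part of the patch):

import torch.distributed as dist

def can_sync_across_processes():
    # get_world_size() errors out if the default process group is not
    # initialized; a single process has nothing to synchronize anyway.
    return dist.is_available() and dist.is_initialized() \
        and dist.get_world_size() > 1

When the equivalent condition is false, forward() falls back to the parent class's plain BatchNorm behavior.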
@@ -76,7 +92,10 @@ def forward(self, input):
         bias = self.bias - mean * scale
         scale = scale.reshape(1, -1, 1)
         bias = bias.reshape(1, -1, 1)
-        return input * scale + bias
+        output = input * scale + bias
+        if is_two_dim:
+            output = output.squeeze(2)
+        return output
 
 
 @NORM_LAYERS.register_module('naiveSyncBN2d')
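For context, the scale and bias assembled above come from batch statistics averaged across all processes rather than computed per GPU. A standalone sketch of that arithmetic, assuming an initialized process group (the helper name is hypothetical):

import torch
import torch.distributed as dist

def synced_scale_bias(input, weight, bias, eps=1e-5):
    # Per-channel mean and mean-of-squares of an (N, C, L) tensor,
    # summed across ranks and divided by the world size.
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2])
    meansqr = torch.mean(input * input, dim=[0, 2])
    vec = torch.cat([mean, meansqr], dim=0)
    dist.all_reduce(vec)  # default reduce op is SUM
    vec = vec / dist.get_world_size()
    mean, meansqr = torch.split(vec, C)
    # var = E[x^2] - E[x]^2, then fold in the affine parameters.
    var = meansqr - mean * mean
    invstd = torch.rsqrt(var + eps)
    scale = weight * invstd
    shift = bias - mean * scale
    return scale.reshape(1, -1, 1), shift.reshape(1, -1, 1)

The real layer additionally needs the all-reduce to be differentiable; the plain dist.all_reduce here only illustrates the statistics.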
@@ -107,9 +126,19 @@ def __init__(self, *args, **kwargs):
     # TODO: make mmcv fp16 utils handle customized norm layers
     @force_fp32(out_fp16=True)
     def forward(self, input):
+        """
+        Args:
+            input (tensor): Feature has shape (N, C, H, W).
+
+        Returns:
+            tensor: Has shape (N, C, H, W), the same shape as the input.
+        """
         assert input.dtype == torch.float32, \
             f'input should be in float32 type, got {input.dtype}'
-        if dist.get_world_size() == 1 or not self.training:
+        using_dist = dist.is_available() and dist.is_initialized()
+        if (not using_dist) or \
+                dist.get_world_size() == 1 or \
+                not self.training:
             return super().forward(input)
 
         assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
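With both hunks applied, the layers survive single-process use and 2-D inputs. A quick smoke test, assuming NaiveSyncBatchNorm1d is importable from this module:

import torch
from mmdet3d.ops.norm import NaiveSyncBatchNorm1d

# No process group is initialized here, so the layer must fall back
# to plain BatchNorm1d instead of calling dist.get_world_size().
bn = NaiveSyncBatchNorm1d(8)
out = bn(torch.randn(4, 8))  # (N, C) input
assert out.shape == (4, 8)

Before this commit, the same call raised an error from dist.get_world_size() whenever torch.distributed was uninitialized.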
