Fix NaiveSyncBatchNorm1d and NaiveSyncBatchNorm2d #1435

Merged (5 commits) on Apr 28, 2022

Changes from 4 commits
35 changes: 32 additions & 3 deletions mmdet3d/ops/norm.py
@@ -53,11 +53,27 @@ def __init__(self, *args, **kwargs):
     # TODO: make mmcv fp16 utils handle customized norm layers
     @force_fp32(out_fp16=True)
     def forward(self, input):
+        """
+        Args:
+            input (Tensor): Has shape (N, C) or (N, C, L), where N is
+                the batch size, C is the number of features or
+                channels, and L is the sequence length.
+
+        Returns:
+            Tensor: Has shape (N, C) or (N, C, L), the same shape as
+                the input.
+        """
         assert input.dtype == torch.float32, \
             f'input should be in float32 type, got {input.dtype}'
-        if dist.get_world_size() == 1 or not self.training:
+        using_dist = dist.is_available() and dist.is_initialized()
+        if (not using_dist) or dist.get_world_size() == 1 \
+                or not self.training:
             return super().forward(input)
         assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
+        is_two_dim = input.dim() == 2
+        if is_two_dim:
+            input.unsqueeze_(2)
+
         C = input.shape[1]
         mean = torch.mean(input, dim=[0, 2])
         meansqr = torch.mean(input * input, dim=[0, 2])
@@ -76,7 +92,10 @@ def forward(self, input):
         bias = self.bias - mean * scale
         scale = scale.reshape(1, -1, 1)
         bias = bias.reshape(1, -1, 1)
-        return input * scale + bias
+        output = input * scale + bias
+        if is_two_dim:
+            output = output.squeeze(2)
+        return output
 
 
 @NORM_LAYERS.register_module('naiveSyncBN2d')
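
For orientation, here is a minimal, self-contained sketch of what the patched 1d forward does end to end. The cross-worker aggregation sits in a part of the file the diff collapses, so the vec/all_reduce lines below follow the common NaiveSyncBatchNorm recipe (average the per-channel mean and mean-of-squares over workers) and are an assumption, as are the SketchSyncBN1d name and the omission of the fp16 handling and running-stat updates.

    import torch
    import torch.distributed as dist
    from torch import nn


    class SketchSyncBN1d(nn.BatchNorm1d):
        """Sketch only; not the upstream class."""

        def forward(self, input):
            # Fall back to plain BatchNorm1d when there is nothing to sync.
            using_dist = dist.is_available() and dist.is_initialized()
            if (not using_dist) or dist.get_world_size() == 1 \
                    or not self.training:
                return super().forward(input)
            assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
            # The fix: lift (N, C) to (N, C, 1) so the reductions over
            # dims [0, 2] below are valid for both input shapes.
            is_two_dim = input.dim() == 2
            if is_two_dim:
                input = input.unsqueeze(2)

            mean = torch.mean(input, dim=[0, 2])
            meansqr = torch.mean(input * input, dim=[0, 2])
            # Assumed aggregation step (elided in the diff): average the
            # statistics across workers with a single fused all_reduce.
            vec = torch.cat([mean, meansqr], dim=0)
            dist.all_reduce(vec)
            vec = vec / dist.get_world_size()
            mean, meansqr = torch.split(vec, input.shape[1])

            var = meansqr - mean * mean
            scale = self.weight * torch.rsqrt(var + self.eps)
            bias = self.bias - mean * scale
            scale = scale.reshape(1, -1, 1)
            bias = bias.reshape(1, -1, 1)
            output = input * scale + bias
            # The fix: restore the original (N, C) shape on the way out.
            if is_two_dim:
                output = output.squeeze(2)
            return output

The var = meansqr - mean * mean step recovers the batch variance from the two synchronized moments, which is why a single all_reduce over the concatenated [mean, meansqr] vector is enough.
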
@@ -107,9 +126,19 @@ def __init__(self, *args, **kwargs):
     # TODO: make mmcv fp16 utils handle customized norm layers
     @force_fp32(out_fp16=True)
     def forward(self, input):
+        """
+        Args:
+            input (Tensor): Feature map with shape (N, C, H, W).
+
+        Returns:
+            Tensor: Has shape (N, C, H, W), the same shape as the input.
+        """
         assert input.dtype == torch.float32, \
             f'input should be in float32 type, got {input.dtype}'
-        if dist.get_world_size() == 1 or not self.training:
+        using_dist = dist.is_available() and dist.is_initialized()
+        if (not using_dist) or \
+                dist.get_world_size() == 1 or \
+                not self.training:
             return super().forward(input)
 
         assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
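
And a quick usage check for the two behaviors the PR fixes, assuming NaiveSyncBatchNorm1d is importable from mmdet3d.ops.norm as the file path and PR title suggest: a (N, C) input to the 1d layer now goes through the unsqueeze/squeeze round trip instead of breaking the per-channel reductions, and the layer degrades to plain batch norm when torch.distributed has not been initialized.

    import torch
    from mmdet3d.ops.norm import NaiveSyncBatchNorm1d

    bn = NaiveSyncBatchNorm1d(8)
    x2d = torch.rand(4, 8)             # (N, C): the sync path previously
                                       # assumed a (N, C, L) layout
    assert bn(x2d).shape == (4, 8)     # shape is preserved

    x3d = torch.rand(4, 8, 16)         # (N, C, L) still works
    assert bn(x3d).shape == (4, 8, 16)

Without an initialized process group, both calls take the super().forward fallback, so the layer is safe to use in single-GPU runs and unit tests.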