-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathprecise_bn.py
33 lines (32 loc) · 1.53 KB
/
precise_bn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py
# Rethinking “Batch” in Batchnorm: https://arxiv.org/pdf/2105.07576.pdf
import torch
import itertools
@torch.no_grad()
def compute_precise_bn_stats(model, loader, num_samples):
    """Recompute BatchNorm2d running statistics as an average over batches.

    Each ``BatchNorm2d`` layer's ``running_mean`` / ``running_var`` is replaced
    by the plain average of the per-batch statistics observed while forwarding
    up to ``num_samples`` samples from ``loader`` through ``model``.

    Args:
        model: Module whose ``BatchNorm2d`` layers are updated in place.
        loader: Iterable of ``(inputs, labels)`` batches that also exposes
            ``batch_size`` and ``__len__`` (both are read to size the run).
        num_samples: Approximate number of samples to use; the number of
            batches is ``int(num_samples / loader.batch_size)``, capped at
            ``len(loader)``.

    Side effects:
        The model is switched to training mode and left there. Inputs are
        moved to the device of the model's first parameter/buffer. If no batch
        can be processed (or the model has no tracked ``BatchNorm2d`` layer),
        the existing statistics are left untouched.
    """
    # BN only updates its running statistics in training mode.
    model.train()
    # Compute the number of minibatches to use
    num_iter = min(int(num_samples / loader.batch_size), len(loader))
    # Retrieve the BN layers that actually track running statistics; layers
    # built with track_running_stats=False have running_mean set to None.
    bns = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None
    ]
    # Nothing to average: bail out instead of overwriting the stats with zeros.
    if num_iter <= 0 or not bns:
        return
    # Feed inputs on the device the model lives on rather than assuming CUDA.
    device = next(itertools.chain(model.parameters(), model.buffers())).device
    # Storage for sum(mean(batch)) and sum(var(batch)) per BN layer
    mean_sums = [torch.zeros_like(bn.running_mean) for bn in bns]
    var_sums = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    momentums = [bn.momentum for bn in bns]
    num_seen = 0
    try:
        # Momentum 1.0 makes the running stats reflect only the current batch
        for bn in bns:
            bn.momentum = 1.0
        # Accumulate the BN stats for each BN layer over the batches
        for inputs, _labels in itertools.islice(loader, num_iter):
            model(inputs.to(device))
            for i, bn in enumerate(bns):
                mean_sums[i] += bn.running_mean
                var_sums[i] += bn.running_var
            num_seen += 1
    finally:
        # Restore the original momentum values even if a forward pass failed
        for bn, momentum in zip(bns, momentums):
            bn.momentum = momentum
    # The loader may yield fewer batches than len(loader) promised.
    if num_seen == 0:
        return
    # Divide by the number of batches actually seen and write the result back
    # into the existing buffers.
    for i, bn in enumerate(bns):
        bn.running_mean.copy_(mean_sums[i] / num_seen)
        bn.running_var.copy_(var_sums[i] / num_seen)