-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgempooling.py
45 lines (37 loc) · 1.52 KB
/
gempooling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# -*- coding: utf-8 -*-
# @Time : 2021/6/14 - 12:26
# @File : gem_torch.py
# @Author : surui
import torch
import torch.nn as nn
import torch.nn.functional as F
class GeMPooling(nn.Module):
    """Generalized Mean (GeM) pooling over the spatial dimensions.

    Raises each activation to a learnable per-channel power ``p``, average-pools
    spatially, then takes the ``1/p`` power — interpolating between average
    pooling (p=1) and max pooling (p -> inf).

    Args:
        feature_size: channel count C; one learnable exponent per channel.
        pool_size: side of the square average-pooling window.
        init_norm: initial value for every per-channel exponent p.
        eps: clamp floor applied before pow, so fractional exponents on
            non-positive inputs stay numerically safe.
        normalize: if True, L2-normalize the output along the channel axis.

    Input/Output layout is channels-last: (B, H, W, C) -> (B, H', W', C).
    """

    def __init__(self, feature_size, pool_size=7, init_norm=3.0, eps=1e-6, normalize=False, **kwargs):
        super(GeMPooling, self).__init__(**kwargs)
        self.feature_size = feature_size  # final layer channel size; pow applied on the last axis
        self.pool_size = pool_size
        self.init_norm = init_norm
        # Learnable per-channel exponent, initialized to init_norm.
        # (The original also called self.p.data.fill_(init_norm) — redundant,
        # since the tensor above already holds init_norm everywhere.)
        self.p = torch.nn.Parameter(torch.ones(self.feature_size) * self.init_norm, requires_grad=True)
        self.normalize = normalize
        self.avg_pooling = nn.AvgPool2d((self.pool_size, self.pool_size))
        self.eps = eps

    def forward(self, features):
        """GeM-pool ``features``; expects channels-last (B, H, W, C)."""
        # Clamp to eps so pow with fractional/learned exponents never sees <= 0.
        features = features.clamp(min=self.eps).pow(self.p)
        features = features.permute((0, 3, 1, 2))  # -> (B, C, H, W) for AvgPool2d
        features = self.avg_pooling(features)
        # BUGFIX: the original called torch.squeeze(features) here, which
        # dropped any size-1 dim (pooled 1x1 output, or batch of 1) and made
        # the 4-axis permute below raise. The squeeze was a no-op in every
        # working case, so it is simply removed.
        features = features.permute((0, 2, 3, 1))  # -> (B, H', W', C)
        features = torch.pow(features, (1.0 / self.p))
        # Optionally project each pooled descriptor onto the unit sphere.
        if self.normalize:
            features = F.normalize(features, dim=-1, p=2)
        return features
if __name__ == '__main__':
    # Quick smoke demo: run a random channels-last batch through GeM pooling.
    sample_batch = torch.randn((8, 7, 7, 768)) * 0.02
    pooling_layer = GeMPooling(768, pool_size=3, init_norm=3.0)
    print("input : ", sample_batch)
    print("=========================")
    print(pooling_layer(sample_batch))