-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathaxial-attention.py
239 lines (198 loc) · 9.83 KB
/
axial-attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import *
# Public names exported by `from <module> import *`.
# NOTE(review): none of the four names listed here are defined in the visible
# part of this file — confirm the factory functions exist, otherwise a
# star-import of this module will raise AttributeError.
__all__ = ['axial26s', 'axial50s', 'axial50m', 'axial50l']
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 2-D convolution with a 1x1 kernel.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the convolution (defaults to 1).

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1`` and ``bias=False``.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
class AxialAttention(nn.Module):
    """Multi-head self-attention along one spatial axis with relative position terms.

    The input is permuted so that one spatial axis is folded into the batch
    dimension and attention is computed along the other one.  Query, key and
    value each get a learned relative position embedding indexed by the offset
    between key and query positions.

    Args:
        in_planes: number of input channels; must be divisible by ``groups``.
        out_planes: number of output channels; must be divisible by ``groups``
            and ``out_planes // groups`` must be even.
        groups: number of attention heads.
        kernel_size: length of the attended axis.  ``forward`` contracts the
            input against ``(kernel_size, kernel_size)`` embedding tables, so
            the attended axis of the input must have exactly this length.
        stride: when greater than 1 the output is average-pooled with this
            window and stride.
        bias: stored on the module; not otherwise used by this class.
        width: selects which axis is attended.  ``False`` attends along
            dimension 2 of the input, ``True`` along dimension 3.

    Raises:
        AssertionError: if ``in_planes`` or ``out_planes`` is not divisible by
            ``groups``.
        ValueError: if ``out_planes // groups`` is odd.
    """

    def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
                 stride=1, bias=False, width=False):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        # forward() splits the 2 * group_planes channels of every head into
        # [group_planes // 2, group_planes // 2, group_planes].  Those sizes only
        # add up when group_planes is even, so fail here with a clear message
        # instead of inside torch.split on the first forward pass.
        if (out_planes // groups) % 2 != 0:
            raise ValueError(
                "out_planes // groups must be even so each head can be split "
                "into query, key and value channels, got {}".format(out_planes // groups))
        super(AxialAttention, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention: one projection produces q, k and v together.
        self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                           padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        # One BatchNorm over the concatenated qk / qr / kr logits (3 * groups channels).
        self.bn_similarity = nn.BatchNorm2d(groups * 3)
        self.bn_output = nn.BatchNorm1d(out_planes * 2)

        # Position embedding: one column per relative offset in
        # [-(kernel_size - 1), kernel_size - 1].
        self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1),
                                     requires_grad=True)
        query_index = torch.arange(kernel_size).unsqueeze(0)
        key_index = torch.arange(kernel_size).unsqueeze(1)
        # Shift offsets by kernel_size - 1 so they are valid, non-negative column indices.
        relative_index = key_index - query_index + kernel_size - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        if stride > 1:
            self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    def forward(self, x):
        """Apply axial attention to a 4-D tensor.

        Args:
            x: 4-D input tensor (assumed to be laid out as (N, C, H, W) —
                confirm against callers).  The attended axis must have length
                ``kernel_size``.

        Returns:
            A 4-D tensor with ``out_planes`` channels in dimension 1,
            average-pooled when ``stride > 1``.
        """
        # Move the non-attended spatial axis next to the batch axis and the
        # attended axis last.  The names below follow the width=False case.
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(
            qkv.reshape(N * W, self.groups, self.group_planes * 2, H),
            [self.group_planes // 2, self.group_planes // 2, self.group_planes],
            dim=2)

        # Calculate position embedding: gather one vector per (key, query) pair.
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(
            self.group_planes * 2, self.kernel_size, self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(
            all_embeddings,
            [self.group_planes // 2, self.group_planes // 2, self.group_planes],
            dim=0)

        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        # Normalise the three logit terms jointly, then add them together.
        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        stacked_similarity = self.bn_similarity(stacked_similarity).view(
            N * W, 3, self.groups, H, H).sum(dim=1)
        # stacked_similarity: (N * W, groups, H, H); softmax runs over the last axis.
        similarity = F.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
        sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
        # Interleave content and positional outputs per channel, normalise them,
        # then sum each pair back into a single channel.
        stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
        output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

        # Undo the initial permutation so channels are back in dimension 1.
        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        if self.stride > 1:
            output = self.pooling(output)

        return output

    def reset_parameters(self):
        """Re-initialise the qkv projection weight and the relative embeddings."""
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
        nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
class AxialBlock(nn.Module):
    """Residual bottleneck built from two axial attention layers.

    The main branch is: 1x1 conv -> norm -> ReLU -> ``hight_block`` ->
    ``width_block`` -> ReLU -> 1x1 conv -> norm.  The result is added to the
    shortcut (optionally transformed by ``downsample``) and passed through a
    final ReLU.

    Args:
        inplanes: number of input channels.
        planes: base channel count; the block outputs ``planes * expansion``.
        stride: stride handed to the second attention layer (``width_block``).
        downsample: optional module applied to the input to form the shortcut.
        groups: number of attention heads for both attention layers.
        base_width: scales the inner channel count as ``planes * base_width / 64``.
        dilation: accepted for signature compatibility; not used by this block.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` when ``None``.
        kernel_size: attended axis length passed to both attention layers.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, kernel_size=56):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.))
        outer = planes * self.expansion

        # Submodules are created in the same order as they run in forward().
        self.conv_down = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.hight_block = AxialAttention(inner, inner, groups=groups,
                                          kernel_size=kernel_size)
        # Only the second attention layer receives the stride.
        self.width_block = AxialAttention(inner, inner, groups=groups,
                                          kernel_size=kernel_size, stride=stride,
                                          width=True)
        self.conv_up = conv1x1(inner, outer)
        self.bn2 = make_norm(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the bottleneck on ``x`` and return the activated residual sum."""
        branch = self.relu(self.bn1(self.conv_down(x)))
        branch = self.width_block(self.hight_block(branch))
        branch = self.bn2(self.conv_up(self.relu(branch)))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
class AxialAttentionNet(nn.Module):
    """ResNet-style classifier whose residual stages are built from ``block``.

    The stem is a 7x7 stride-2 convolution, a norm layer, ReLU and a 3x3
    stride-2 max-pool.  Four stages follow, then global average pooling and a
    linear classifier.

    Args:
        block: residual block class.  It must expose an ``expansion`` attribute
            and accept the arguments passed in ``_make_layer``.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: size of the classifier output.
        zero_init_residual: when true, the ``bn2`` weight of every
            ``AxialBlock`` is initialised to zero so each block starts out as
            an identity mapping.
        groups: number of attention heads passed to every block.
        width_per_group: passed to every block as ``base_width``.
        replace_stride_with_dilation: ``None`` or a 3-element sequence of
            booleans, one per stage 2-4.  A true entry makes that stage keep
            stride 1 and multiplies the recorded dilation instead.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` when ``None``.
        s: channel multiplier applied to the stem and to every stage width.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` is given and does not
            have exactly three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=8, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, s=0.5):
        super(AxialAttentionNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = int(64 * s)
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size=56)
        self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=56,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=28,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=14,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Size the classifier from the channel count the last stage really
        # produces, int(1024 * s) * expansion.  Truncating after multiplying by
        # the expansion can give a different number when 1024 * s is fractional.
        self.fc = nn.Linear(int(1024 * s) * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                # qkv_transform layers keep the initialisation they were built with.
                if not isinstance(m, qkv_transform):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, AxialBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Args:
            block: residual block class.
            planes: base channel count of the stage; the stage outputs
                ``planes * block.expansion`` channels.
            blocks: number of blocks in the stage.
            kernel_size: attended axis length given to the first block.  It is
                halved for the remaining blocks when the first block strides.
            stride: stride of the first block.
            dilate: when true, the stride is folded into ``self.dilation`` and
                the stage runs with stride 1.

        Side effects:
            Updates ``self.inplanes`` to the stage's output channel count and,
            when ``dilate`` is true, multiplies ``self.dilation`` by ``stride``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # The shortcut needs a projection whenever the shape changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
                            base_width=self.base_width, dilation=previous_dilation,
                            norm_layer=norm_layer, kernel_size=kernel_size))
        self.inplanes = planes * block.expansion
        # A strided first block halves the spatial size seen by the later blocks.
        if stride != 1:
            kernel_size = kernel_size // 2

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, kernel_size=kernel_size))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem, four stages, global pooling and the classifier on ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return the classifier output for input batch ``x``."""
        return self._forward_impl(x)