-
Notifications
You must be signed in to change notification settings - Fork 3
/
mobilevig.py
370 lines (302 loc) · 12.2 KB
/
mobilevig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# (ref) https://github.com/SLDGroup/MobileViG
# (ref) https://arxiv.org/abs/2307.00395
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath
from timm.models.registry import register_model
'''
@article{han2022vision,
title={Vision GNN: An Image is Worth Graph of Nodes},
author={Han, Kai and Wang, Yunhe and Guo, Jianyuan and Tang, Yehui and Wu, Enhua},
journal={arXiv preprint arXiv:2206.00272},
year={2022}
}
'''
def _cfg(url='', **kwargs):
    """Assemble a default configuration dictionary for a model.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Additional entries. A key that collides with one of the
            defaults below replaces that default's value.

    Returns:
        A new dict holding the defaults merged with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier='head',
    )
    # Caller-supplied entries win over the defaults.
    cfg.update(kwargs)
    return cfg
# Default configs keyed by model family name; the model factories in this
# file attach the 'mobilevig' entry to the model they return.
default_cfgs = dict(
    mobilevig=_cfg(
        crop_pct=0.9,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
    ),
)
class Stem(nn.Module):
    """Convolutional stem made of two conv -> BatchNorm -> activation stages.

    Both convolutions are 3x3 with stride 2 and padding 1, so each stage
    halves the spatial resolution (rounding up). The first stage maps
    ``input_dim`` to ``output_dim // 2`` channels, the second maps that to
    ``output_dim`` channels.

    Args:
        input_dim: Number of channels of the incoming tensor.
        output_dim: Number of channels produced by the stem.
        activation: Zero-argument callable (typically an ``nn.Module`` class)
            returning the activation module to place after each BatchNorm.
            Defaults to ``nn.GELU``. A fresh instance is created per stage.
    """

    def __init__(self, input_dim, output_dim, activation=nn.GELU):
        super(Stem, self).__init__()
        mid_dim = output_dim // 2
        self.stem = nn.Sequential(
            nn.Conv2d(input_dim, mid_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(mid_dim),
            # Honor the requested activation instead of a hard-coded GELU.
            activation(),
            nn.Conv2d(mid_dim, output_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(output_dim),
            activation(),
        )

    def forward(self, x):
        """Run ``x`` through the stem and return the result."""
        return self.stem(x)
class MLP(nn.Module):
    """
    Feed-forward block built from 1*1 convolutions with BatchNorm.
    When ``mid_conv`` is set, a depthwise 3*3 convolution (with its own
    BatchNorm and activation) is inserted between the two 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, drop=0., mid_conv=False):
        super().__init__()
        # Falsy widths (None / 0) fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.mid_conv = mid_conv
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

        if self.mid_conv:
            # groups == channels makes this a depthwise convolution.
            self.mid = nn.Conv2d(
                hidden_features,
                hidden_features,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=hidden_features,
            )
            self.mid_norm = nn.BatchNorm2d(hidden_features)

        self.norm1 = nn.BatchNorm2d(hidden_features)
        self.norm2 = nn.BatchNorm2d(out_features)

    def forward(self, x):
        # Expand: 1*1 conv -> BN -> activation.
        hidden = self.act(self.norm1(self.fc1(x)))

        # Optional depthwise 3*3 conv -> BN -> activation.
        if self.mid_conv:
            hidden = self.act(self.mid_norm(self.mid(hidden)))

        hidden = self.drop(hidden)

        # Project: 1*1 conv -> BN, then dropout again.
        projected = self.norm2(self.fc2(hidden))
        return self.drop(projected)
class InvertedResidual(nn.Module):
    """Residual block wrapping an ``MLP`` (built with ``mid_conv=True``).

    Computes ``x + drop_path(scale * mlp(x))`` where ``scale`` is a learnable
    per-channel factor when ``use_layer_scale`` is true, and is omitted
    otherwise.

    Args:
        dim: Channel count of the input (and output).
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop: Dropout probability forwarded to the MLP.
        drop_path: Probability passed to ``DropPath``; values <= 0 disable it.
        use_layer_scale: Whether to apply the learnable per-channel scale.
        layer_scale_init_value: Initial value of every scale entry.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_dim, drop=drop, mid_conv=True)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            # Shape (dim, 1, 1) so it broadcasts over the spatial axes.
            initial_scale = layer_scale_init_value * torch.ones(dim, 1, 1)
            self.layer_scale_2 = nn.Parameter(initial_scale, requires_grad=True)

    def forward(self, x):
        branch = self.mlp(x)
        if self.use_layer_scale:
            branch = self.layer_scale_2 * branch
        return x + self.drop_path(branch)
class MRConv4d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type

    For every shift ``i = K, 2K, ...`` smaller than the height (resp. width),
    the input is circularly shifted by ``i`` along that axis and subtracted
    from the unshifted input. The element-wise maximum over all those
    differences (and zero) is concatenated with the input along the channel
    axis and passed through a 1x1 conv -> BatchNorm -> GELU.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the 1x1 convolution.
        K: Step between successive shifts along each spatial axis.
    """

    def __init__(self, in_channels, out_channels, K=2):
        super(MRConv4d, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1),
            # BatchNorm must match the conv's output width. The previous
            # ``in_channels * 2`` only worked when out_channels happened to
            # equal 2 * in_channels.
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
        self.K = K

    def forward(self, x):
        """Apply the max-relative aggregation to ``x`` of shape [B, C, H, W]."""
        _, _, H, W = x.shape

        # Start from zero, i.e. the difference of every position with itself.
        x_j = torch.zeros_like(x)

        # Differences against rows shifted (with wrap-around) by multiples of K.
        for i in range(self.K, H, self.K):
            x_c = x - torch.cat([x[:, :, -i:, :], x[:, :, :-i, :]], dim=2)
            x_j = torch.max(x_j, x_c)

        # Same along the width axis.
        for i in range(self.K, W, self.K):
            x_r = x - torch.cat([x[:, :, :, -i:], x[:, :, :, :-i]], dim=3)
            x_j = torch.max(x_j, x_r)

        # Channel-wise concat doubles the channel count fed to the 1x1 conv.
        x = torch.cat([x, x_j], dim=1)
        return self.nn(x)
class Grapher(nn.Module):
    """
    Graph-convolution block with a residual connection.

    Pipeline: 1x1 conv + BN -> MRConv4d (widens to 2x channels) ->
    1x1 conv + BN (back to the input width) -> drop-path -> add input.
    """

    def __init__(self, in_channels, drop_path=0.0, K=2):
        super(Grapher, self).__init__()
        self.channels = in_channels
        self.K = K
        widened = in_channels * 2

        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.graph_conv = MRConv4d(in_channels, widened, K=self.K)
        # Project the widened features back to the input channel count.
        self.fc2 = nn.Sequential(
            nn.Conv2d(widened, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        residual = x
        out = self.fc2(self.graph_conv(self.fc1(x)))
        return self.drop_path(out) + residual
class Downsample(nn.Module):
    """ Stride-2 3x3 convolution followed by BatchNorm.

    Maps ``in_dim`` channels to ``out_dim`` channels while reducing the
    spatial resolution through the convolution's stride of 2.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        strided_conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_dim)
        self.conv = nn.Sequential(strided_conv, norm)

    def forward(self, x):
        return self.conv(x)
class FFN(nn.Module):
    """Residual feed-forward block of two 1x1 conv + BatchNorm stages.

    Computes ``drop_path(fc2(gelu(fc1(x)))) + x``.

    Args:
        in_features: Channel count of the input.
        hidden_features: Width between the two stages; falls back to
            ``in_features`` when falsy.
        out_features: Output channel count; falls back to ``in_features``
            when falsy.
        drop_path: Probability passed to ``DropPath``; values <= 0 disable it.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop_path=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(hidden_features),
        )
        self.act = nn.GELU()
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(out_features),
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        out = self.fc2(self.act(self.fc1(x)))
        return self.drop_path(out) + x
class MobileViG(torch.nn.Module):
    """Classifier with a local (InvertedResidual) and a global (Grapher) part.

    Layout: ``Stem`` -> local stages of ``InvertedResidual`` blocks separated
    by ``Downsample`` -> a ``Downsample`` into the global width -> global
    stages of ``Grapher`` + ``FFN`` pairs -> pooled prediction trunk ->
    classification head (plus a second head when ``distillation`` is true).

    Args:
        local_blocks: Number of InvertedResidual blocks per local stage.
        local_channels: Channel width per local stage.
        global_blocks: Number of Grapher+FFN pairs per global stage.
        global_channels: Channel width per global stage.
        dropout: Dropout probability in the prediction trunk.
        drop_path: Maximum drop-path rate; per-block rates are spaced
            linearly from 0 to this value over all blocks.
        emb_dims: Width of the prediction trunk.
        K: Shift step forwarded to every ``Grapher``.
        distillation: Whether to add a second (distillation) head.
        num_classes: Number of output classes.
    """

    def __init__(self, local_blocks, local_channels,
                 global_blocks, global_channels,
                 dropout=0., drop_path=0., emb_dims=512,
                 K=2, distillation=True, num_classes=1000):
        super(MobileViG, self).__init__()
        self.distillation = distillation

        # stochastic depth decay rule: one rate per block, consumed in order
        total_blocks = sum(global_blocks) + sum(local_blocks)
        rates = iter(torch.linspace(0, drop_path, total_blocks).tolist())

        self.stem = Stem(input_dim=3, output_dim=local_channels[0])

        # local processing with inverted residuals
        self.local_backbone = nn.ModuleList([])
        for stage, depth in enumerate(local_blocks):
            if stage > 0:
                self.local_backbone.append(
                    Downsample(local_channels[stage - 1], local_channels[stage]))
            for _ in range(depth):
                self.local_backbone.append(
                    InvertedResidual(dim=local_channels[stage], mlp_ratio=4, drop_path=next(rates)))
        # transition from local to global
        self.local_backbone.append(Downsample(local_channels[-1], global_channels[0]))

        # global processing with svga
        self.backbone = nn.ModuleList([])
        for stage, depth in enumerate(global_blocks):
            if stage > 0:
                self.backbone.append(
                    Downsample(global_channels[stage - 1], global_channels[stage]))
            for _ in range(depth):
                # Grapher and FFN of the same pair share one drop-path rate.
                rate = next(rates)
                pair = nn.Sequential(
                    Grapher(global_channels[stage], drop_path=rate, K=K),
                    FFN(global_channels[stage], global_channels[stage] * 4, drop_path=rate),
                )
                self.backbone.append(pair)

        self.prediction = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(global_channels[-1], emb_dims, 1, bias=True),
            nn.BatchNorm2d(emb_dims),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.head = nn.Conv2d(emb_dims, num_classes, 1, bias=True)
        if self.distillation:
            self.dist_head = nn.Conv2d(emb_dims, num_classes, 1, bias=True)

        self.model_init()

    def model_init(self):
        """Re-initialize every Conv2d: Kaiming-normal weights, zero biases."""
        for module in self.modules():
            if not isinstance(module, torch.nn.Conv2d):
                continue
            torch.nn.init.kaiming_normal_(module.weight)
            module.weight.requires_grad = True
            if module.bias is None:
                continue
            module.bias.data.zero_()
            module.bias.requires_grad = True

    def forward(self, inputs):
        """Return class logits for ``inputs``.

        With distillation enabled, training mode returns a
        ``(head, dist_head)`` tuple and eval mode returns their average;
        without distillation the single head's output is returned.
        """
        feats = self.stem(inputs)
        for layer in self.local_backbone:
            feats = layer(feats)
        for layer in self.backbone:
            feats = layer(feats)
        feats = self.prediction(feats)

        # The trunk output is pooled to 1x1, so drop the two spatial axes.
        logits = self.head(feats).squeeze(-1).squeeze(-1)
        if not self.distillation:
            return logits

        dist_logits = self.dist_head(feats).squeeze(-1).squeeze(-1)
        if self.training:
            return logits, dist_logits
        return (logits + dist_logits) / 2
@register_model
def mobilevig_ti(pretrained=False, **kwargs):
    """Build the ``mobilevig_ti`` configuration of ``MobileViG``.

    NOTE: ``pretrained`` and ``**kwargs`` are accepted but not used here;
    the architecture below is fixed and no weights are loaded.
    """
    arch = dict(
        local_blocks=[2, 2, 6],
        local_channels=[42, 84, 168],
        global_blocks=[2],
        global_channels=[256],
        dropout=0.,
        drop_path=0.1,
        emb_dims=512,
        K=2,
        distillation=True,
        num_classes=1000,
    )
    model = MobileViG(**arch)
    model.default_cfg = default_cfgs['mobilevig']
    return model
@register_model
def mobilevig_s(pretrained=False, **kwargs):
    """Build the ``mobilevig_s`` configuration of ``MobileViG``.

    NOTE: ``pretrained`` and ``**kwargs`` are accepted but not used here;
    the architecture below is fixed and no weights are loaded.
    """
    arch = dict(
        local_blocks=[3, 3, 9],
        local_channels=[42, 84, 176],
        global_blocks=[3],
        global_channels=[256],
        dropout=0.,
        drop_path=0.1,
        emb_dims=512,
        K=2,
        distillation=True,
        num_classes=1000,
    )
    model = MobileViG(**arch)
    model.default_cfg = default_cfgs['mobilevig']
    return model
@register_model
def mobilevig_m(pretrained=False, **kwargs):
    """Build the ``mobilevig_m`` configuration of ``MobileViG``.

    NOTE: ``pretrained`` and ``**kwargs`` are accepted but not used here;
    the architecture below is fixed and no weights are loaded.
    """
    arch = dict(
        local_blocks=[3, 3, 9],
        local_channels=[42, 84, 224],
        global_blocks=[3],
        global_channels=[400],
        dropout=0.,
        drop_path=0.1,
        emb_dims=768,
        K=2,
        distillation=True,
        num_classes=1000,
    )
    model = MobileViG(**arch)
    model.default_cfg = default_cfgs['mobilevig']
    return model
@register_model
def mobilevig_b(pretrained=False, **kwargs):
    """Build the ``mobilevig_b`` configuration of ``MobileViG``.

    NOTE: ``pretrained`` and ``**kwargs`` are accepted but not used here;
    the architecture below is fixed and no weights are loaded.
    """
    arch = dict(
        local_blocks=[5, 5, 15],
        local_channels=[42, 84, 240],
        global_blocks=[5],
        global_channels=[464],
        dropout=0.,
        drop_path=0.1,
        emb_dims=768,
        K=2,
        distillation=True,
        num_classes=1000,
    )
    model = MobileViG(**arch)
    model.default_cfg = default_cfgs['mobilevig']
    return model
if __name__ == "__main__":
    # Smoke test: build the small variant, run one random image through it
    # in eval mode, and print the shape of the result.
    net = mobilevig_s()
    net.eval()

    height, width = 224, 224
    sample = torch.rand(1, 3, height, width)
    logits = net(sample)
    print(logits.shape)