
Commit cef01ba

Refactoring, cleanup, improved test coverage.

* Add eca_nfnet_l2 weights, 84.7 @ 384x384
* All 'non-std' (i.e. transformer / mlp) models have classifier / default_cfg tests added
* Fix huggingface#694: reset_classifier / num_features / forward_features / num_classes=0 consistency for transformer / mlp models
* Add direct loading of npz to vision transformer (pure transformer so far, hybrid to come)
* Rename vit_deit* to deit_*
* Remove some deprecated vit hybrid model defs
* Clean up classifier flatten for conv classifiers and unusual cases (mobilenetv3/ghostnet)
* Remove explicit model fns for levit conv, just pass in arg

1 parent 6018cb9 commit cef01ba

24 files changed: +637, -455 lines
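Before the per-file diffs, a short illustrative sketch of the reset_classifier / num_features / num_classes=0 contract the non-std models now follow. This is not part of the commit; it assumes a timm build containing these changes and uses the renamed deit_tiny_patch16_224 entry point.

import torch
import timm

model = timm.create_model('deit_tiny_patch16_224', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)

# forward_features: the feature dim is expected to match model.num_features
feats = model.forward_features(x)
if isinstance(feats, tuple):  # e.g. distilled variants return a tuple
    feats = feats[0]
assert feats.shape[1] == model.num_features

# num_classes=0 via reset_classifier: the head becomes a no-op, forward yields features
model.reset_classifier(0)
pooled = model(x)
assert pooled.ndim == 2 and pooled.shape[1] == model.num_features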

tests/test_models.py  (+52, -3)

@@ -17,7 +17,7 @@
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*']
+    'convit_*', 'levit*', 'visformer*', 'deit*']
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
@@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size):
     state_dict = model.state_dict()
     cfg = model.default_cfg

-    classifier = cfg['classifier']
     pool_size = cfg['pool_size']
     input_size = model.default_cfg['input_size']

@@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size):
     assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]

     # check classifier name matches default_cfg
-    assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
+    classifier = cfg['classifier']
+    if not isinstance(classifier, (tuple, list)):
+        classifier = classifier,
+    for c in classifier:
+        assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
+
+    # check first conv(s) names match default_cfg
+    first_conv = cfg['first_conv']
+    if isinstance(first_conv, str):
+        first_conv = (first_conv,)
+    assert isinstance(first_conv, (tuple, list))
+    for fc in first_conv:
+        assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
+
+
+@pytest.mark.timeout(300)
+@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_default_cfgs_non_std(model_name, batch_size):
+    """Run a single forward pass with each model"""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+    state_dict = model.state_dict()
+    cfg = model.default_cfg
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    input_tensor = torch.randn((batch_size, *input_size))
+
+    # test forward_features (always unpooled)
+    outputs = model.forward_features(input_tensor)
+    if isinstance(outputs, tuple):
+        outputs = outputs[0]
+    assert outputs.shape[1] == model.num_features
+
+    # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
+    model.reset_classifier(0)
+    outputs = model.forward(input_tensor)
+    if isinstance(outputs, tuple):
+        outputs = outputs[0]
+    assert len(outputs.shape) == 2
+    assert outputs.shape[1] == model.num_features
+
+    # check classifier name matches default_cfg
+    classifier = cfg['classifier']
+    if not isinstance(classifier, (tuple, list)):
+        classifier = classifier,
+    for c in classifier:
+        assert c + ".weight" in state_dict.keys(), f'{c} not in model params'

     # check first conv(s) names match default_cfg
     first_conv = cfg['first_conv']
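The tuple-or-string handling above exists because default_cfg['classifier'] may name several heads (e.g. distilled deit variants); the same normalization can be reused outside the test suite. A hypothetical helper, for illustration only:

import timm


def classifier_params_present(model):
    """Return True if every default_cfg classifier name maps to a weight in the state dict."""
    state_dict = model.state_dict()
    classifier = model.default_cfg['classifier']
    if not isinstance(classifier, (tuple, list)):
        classifier = (classifier,)  # most models store a single name as a plain string
    return all(name + '.weight' in state_dict for name in classifier)


# distilled deit models are expected to report a ('head', 'head_dist') tuple here
model = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=False)
print(classifier_params_present(model))  # expected: True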

timm/models/cait.py  (+18, -12)

@@ -74,11 +74,11 @@ def _cfg(url='', **kwargs):
 class ClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to do CA
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5

         self.q = nn.Linear(dim, dim, bias=qkv_bias)
         self.k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add CA and LayerScale
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
             mlp_block=Mlp, init_values=1e-4):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = attn_block(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -134,14 +134,14 @@ def forward(self, x, x_cls):
 class TalkingHeadAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()

         self.num_heads = num_heads

         head_dim = dim // num_heads

-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add layerScale
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
             mlp_block=Mlp, init_values=1e-4):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = attn_block(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -202,7 +202,7 @@ class Cait(nn.Module):
     # with slight modifications to adapt to our cait models
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-            num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+            num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0.,
             norm_layer=partial(nn.LayerNorm, eps=1e-6),
             global_pool=None,
@@ -235,14 +235,14 @@ def __init__(
         dpr = [drop_path_rate for i in range(depth)]
         self.blocks = nn.ModuleList([
             block_layers(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
             for i in range(depth)])

         self.blocks_token_only = nn.ModuleList([
             block_layers_token(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block_token_only,
                 mlp_block=mlp_block_token_only, init_values=init_scale)
@@ -270,6 +270,13 @@ def _init_weights(self, m):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}

+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward_features(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
@@ -293,7 +300,6 @@ def forward_features(self, x):
     def forward(self, x):
         x = self.forward_features(x)
         x = self.head(x)
-
         return x

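With get_classifier / reset_classifier now defined, Cait can be used as a pooled feature extractor like the conv nets; a small usage sketch (illustrative only, assumes the cait_xxs24_224 entry point):

import torch
import timm

model = timm.create_model('cait_xxs24_224', pretrained=False)
model.eval()
model.reset_classifier(0)  # head becomes nn.Identity()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: (1, model.num_features), i.e. the class-token embedding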

timm/models/coat.py  (+5, -3)

@@ -335,6 +335,8 @@ def __init__(
         crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
         self.return_interm_layers = return_interm_layers
         self.out_features = out_features
+        self.embed_dims = embed_dims
+        self.num_features = embed_dims[-1]
         self.num_classes = num_classes

         # Patch embeddings.
@@ -441,10 +443,10 @@ def __init__(
             # CoaT series: Aggregate features of last three scales for classification.
             assert embed_dims[1] == embed_dims[2] == embed_dims[3]
             self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
-            self.head = nn.Linear(embed_dims[3], num_classes)
+            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
         else:
             # CoaT-Lite series: Use feature of last scale for classification.
-            self.head = nn.Linear(embed_dims[3], num_classes)
+            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

         # Initialize weights.
         trunc_normal_(self.cls_token1, std=.02)
@@ -471,7 +473,7 @@ def get_classifier(self):

     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def insert_cls(self, x, cls_token):
         """ Insert CLS token. """

timm/models/convit.py  (+10, -13)

@@ -57,13 +57,13 @@ def _cfg(url='', **kwargs):


 class GPSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                  locality_strength=1.):
         super().__init__()
         self.num_heads = num_heads
         self.dim = dim
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5
         self.locality_strength = locality_strength

         self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
@@ -142,11 +142,11 @@ def get_rel_indices(self, num_patches: int) -> torch.Tensor:


 class MHSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -191,19 +191,16 @@ def forward(self, x):

 class Block(nn.Module):

-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.use_gpsa = use_gpsa
         if self.use_gpsa:
             self.attn = GPSA(
-                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
-                proj_drop=drop, **kwargs)
+                dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
         else:
-            self.attn = MHSA(
-                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
-                proj_drop=drop, **kwargs)
+            self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -220,7 +217,7 @@ class ConViT(nn.Module):
     """

     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
                  local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
         super().__init__()
@@ -250,13 +247,13 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
             Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                 use_gpsa=True,
                 locality_strength=locality_strength)
             if i < local_up_to_layer else
             Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                 use_gpsa=False)
             for i in range(depth)])
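The removed qk_scale argument defaulted to the standard 1/sqrt(head_dim) scaling, which these attention modules now hard-code; a standalone toy sketch of that scaling (not timm code):

import math
import torch

dim, num_heads = 64, 8
head_dim = dim // num_heads
scale = head_dim ** -0.5  # what GPSA / MHSA now always use
assert math.isclose(scale, 1.0 / math.sqrt(head_dim))

q = torch.randn(1, num_heads, 16, head_dim)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, num_heads, 16, head_dim)
attn = (q @ k.transpose(-2, -1)) * scale     # scaled dot-product logits
attn = attn.softmax(dim=-1)
print(attn.shape)  # (1, 8, 16, 16)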

timm/models/dla.py  (+4, -2)

@@ -288,6 +288,8 @@ def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chan
         self.num_features = channels[-1]
         self.global_pool, self.fc = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -314,6 +316,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

     def forward_features(self, x):
         x = self.base_layer(x)
@@ -331,8 +334,7 @@ def forward(self, x):
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
-        if not self.global_pool.is_identity():
-            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
+        x = self.flatten(x)
         return x
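The nn.Flatten(1)-or-Identity member replaces the is_identity() check at forward time; a minimal standalone sketch of the conv-classifier head pattern it supports (hypothetical module, not timm code):

import torch
import torch.nn as nn


class ConvClassifierHead(nn.Module):
    """Global pool -> 1x1 conv classifier -> flatten; flatten disabled when pooling is."""

    def __init__(self, num_features, num_classes, global_pool='avg'):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1) if global_pool else nn.Identity()
        self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1)
        # mirrors the dla/dpn change: only flatten when pooling produced a 1x1 map
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

    def forward(self, x):
        return self.flatten(self.fc(self.pool(x)))


head = ConvClassifierHead(512, 1000, global_pool='avg')
print(head(torch.randn(2, 512, 7, 7)).shape)  # (2, 1000)

head = ConvClassifierHead(512, 1000, global_pool='')
print(head(torch.randn(2, 512, 7, 7)).shape)  # (2, 1000, 7, 7), dense logits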

timm/models/dpn.py  (+3, -2)

@@ -237,6 +237,7 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
         # Using 1x1 conv for the FC layer to allow the extra pooling scheme
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

     def get_classifier(self):
         return self.classifier
@@ -245,6 +246,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

     def forward_features(self, x):
         return self.features(x)
@@ -255,8 +257,7 @@ def forward(self, x):
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.classifier(x)
-        if not self.global_pool.is_identity():
-            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
+        x = self.flatten(x)
         return x
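For DPN (and DLA) the practical effect is that num_classes=0 with pooling disabled now yields an unpooled feature map, while the default keeps a flat vector; a hedged usage sketch (dpn68 name assumed, shapes are the expected ones after this change):

import torch
import timm

x = torch.randn(1, 3, 224, 224)
model = timm.create_model('dpn68', pretrained=False)

model.reset_classifier(0)                  # pooling kept, classifier dropped
print(model(x).shape)                      # expected: (1, model.num_features)

model.reset_classifier(0, global_pool='')  # pooling disabled, so flatten is Identity
print(model(x).shape)                      # expected: (1, model.num_features, H/32, W/32)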

timm/models/ghostnet.py  (+5, -4)

@@ -133,7 +133,7 @@ def forward(self, x):


 class GhostNet(nn.Module):
-    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32):
+    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
         super(GhostNet, self).__init__()
         # setting of inverted residual blocks
         assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
@@ -178,9 +178,10 @@ def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, o

         # building last several layers
         self.num_features = out_chs = 1280
-        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(out_chs, num_classes)

     def get_classifier(self):
@@ -190,6 +191,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
@@ -204,8 +206,7 @@ def forward_features(self, x):

     def forward(self, x):
         x = self.forward_features(x)
-        if not self.global_pool.is_identity():
-            x = x.view(x.size(0), -1)
+        x = self.flatten(x)
         if self.dropout > 0.:
             x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.classifier(x)
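GhostNet keeps an "efficient head" (pooling, conv_head and classifier all operate on the pooled map), so pooling is fixed at construction; the new global_pool argument and flatten module make its num_classes=0 path behave like the other models. A short hedged sketch (ghostnet_100 name assumed):

import torch
import timm

model = timm.create_model('ghostnet_100', pretrained=False)
model.eval()
model.reset_classifier(0)  # classifier -> nn.Identity(), pooled pre-logits remain

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: (1, 1280), i.e. (1, model.num_features)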
