@@ -74,11 +74,11 @@ def _cfg(url='', **kwargs):
 class ClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to do CA
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5

         self.q = nn.Linear(dim, dim, bias=qkv_bias)
         self.k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add CA and LayerScale
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
             mlp_block=Mlp, init_values=1e-4):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = attn_block(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -134,14 +134,14 @@ def forward(self, x, x_cls):
 class TalkingHeadAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()

         self.num_heads = num_heads

         head_dim = dim // num_heads

-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add layerScale
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
             mlp_block=Mlp, init_values=1e-4):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = attn_block(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -202,7 +202,7 @@ class Cait(nn.Module):
     # with slight modifications to adapt to our cait models
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-            num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+            num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0.,
             norm_layer=partial(nn.LayerNorm, eps=1e-6),
             global_pool=None,
@@ -235,14 +235,14 @@ def __init__(
         dpr = [drop_path_rate for i in range(depth)]
         self.blocks = nn.ModuleList([
             block_layers(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
             for i in range(depth)])

         self.blocks_token_only = nn.ModuleList([
             block_layers_token(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
                 act_layer=act_layer, attn_block=attn_block_token_only,
                 mlp_block=mlp_block_token_only, init_values=init_scale)
@@ -270,6 +270,13 @@ def _init_weights(self, m):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}

+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward_features(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
@@ -293,7 +300,6 @@ def forward_features(self, x):
     def forward(self, x):
         x = self.forward_features(x)
         x = self.head(x)
-
         return x

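For reference, a minimal usage sketch of the `get_classifier` / `reset_classifier` helpers added in this diff. The constructor arguments and the 10-class target are illustrative assumptions, not taken from this commit.

```python
import torch

# Hypothetical instantiation of the Cait class from this file; argument
# values are illustrative (a small XXS-style configuration).
model = Cait(img_size=224, patch_size=16, embed_dim=192, depth=24, num_heads=4)

head = model.get_classifier()           # the current nn.Linear classification head
model.reset_classifier(num_classes=10)  # swap in a fresh 10-way head, e.g. for fine-tuning

out = model(torch.randn(1, 3, 224, 224))
print(out.shape)                        # torch.Size([1, 10])
```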