# models.py
from functools import partial
from typing import Any, Callable, Tuple

from flax import linen as nn
from flax.linen.initializers import he_normal, zeros, ones

from fixup_initializer import fixup

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
ModuleDef = Any

default_kernel_init = he_normal()
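
# `fixup(l, m)` comes from the companion fixup_initializer module. Following the
# Fixup paper (Zhang et al., 2019), such an initializer typically rescales a
# standard (e.g. He) init by l ** (-1 / (2 * m - 2)), where l is the total
# number of residual blocks and m is the number of convolutions per residual
# branch; see fixup_initializer.py for the actual implementation.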


class FixupBias(nn.Module):
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Adds a learned scalar bias to the input.

        Parameters
        ----------
        inputs: Array
            The inputs to the fixup bias layer.

        Returns
        -------
        y: Array
            The outputs of the bias layer.
        """
        bias = self.param('fixup_bias', self.bias_init, (1,))
        y = inputs + bias
        return y


class FixupMultiplier(nn.Module):
    multiplier_init: Callable[[PRNGKey, Shape, Dtype], Array] = ones

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Multiplies the input by a learned scalar.

        Parameters
        ----------
        inputs: Array
            The inputs of the fixup multiplier layer.

        Returns
        -------
        y: Array
            The outputs of the fixup multiplier layer.
        """
        multiplier = self.param('fixup_multiplier', self.multiplier_init, (1,))
        y = inputs * multiplier
        return y
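
# A minimal sketch of the two scalar layers in isolation (the `jax`/`jnp`
# usage below is illustrative and not exercised elsewhere in this file):
#
#   import jax, jax.numpy as jnp
#   layer = FixupMultiplier()
#   variables = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
#   # variables['params']['fixup_multiplier'] has shape (1,); together with
#   # FixupBias this implements the per-branch `scale * x + bias` of Fixup.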


class FixupBasicBlock(nn.Module):
    """
    Fixup basic block.

    This is the basic residual block of the Fixup architecture. The second convolutional layer and the Fixup biases
    are initialized to 0. The first convolutional layer is initialized according to the Fixup initialization. The
    Fixup multiplier is initialized to 1.
    """
    in_planes: int
    out_planes: int
    stride: int
    dropRate: float
    total_num_basic_blocks: int

    @nn.compact
    def __call__(self, x, train):
        if self.in_planes != self.out_planes:
            # When the shape changes, the bias and ReLU are applied to the shortcut path as well.
            x = FixupBias()(x)
            x = nn.relu(x)
            out = nn.Conv(features=self.out_planes, strides=(self.stride, self.stride), use_bias=False,
                          kernel_size=(3, 3), kernel_init=fixup(l=self.total_num_basic_blocks, m=2))(x)
            out = FixupBias()(out)
            out = nn.relu(out)
            out = FixupBias()(out)
        else:
            out = FixupBias()(x)
            out = nn.relu(out)
            out = nn.Conv(features=self.out_planes, use_bias=False, kernel_size=(3, 3),
                          kernel_init=fixup(l=self.total_num_basic_blocks, m=2))(out)
            out = FixupBias()(out)
            out = nn.relu(out)
            out = FixupBias()(out)
        if self.dropRate > 0:
            out = nn.Dropout(self.dropRate, deterministic=not train)(out)
        # The second convolution is zero-initialized, so each residual branch starts out as the identity.
        out = nn.Conv(features=self.out_planes, strides=(1, 1), padding=1, use_bias=False, kernel_size=(3, 3),
                      kernel_init=zeros)(out)
        out = FixupMultiplier()(out)
        out = FixupBias()(out)
        if self.in_planes != self.out_planes:
            # Project the shortcut with a 1x1 convolution to match the output shape.
            x = nn.Conv(features=self.out_planes, strides=(self.stride, self.stride), kernel_size=(1, 1), padding=0,
                        use_bias=False)(x)
        return x + out
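
# Note: when dropRate > 0 and train=True, nn.Dropout is stochastic, so `apply`
# needs a 'dropout' RNG stream (a sketch; the key value is arbitrary):
#
#   block.apply(variables, x, True, rngs={'dropout': jax.random.PRNGKey(1)})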


class FixupNetworkBlock(nn.Module):
    """
    Fixup network block.

    Stacks `nb_layers` Fixup basic blocks in sequence.
    """
    nb_layers: int
    in_planes: int
    out_planes: int
    block_cls: ModuleDef
    stride: int
    dropRate: float
    total_num_basic_blocks: int

    @nn.compact
    def __call__(self, x, train):
        for i in range(self.nb_layers):
            # Only the first block changes the channel count and applies the stride.
            x = self.block_cls(self.in_planes if i == 0 else self.out_planes,
                               self.out_planes,
                               self.stride if i == 0 else 1,
                               self.dropRate,
                               self.total_num_basic_blocks)(x, train)
        return x
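
# For example, FixupNetworkBlock(3, 16, 64, FixupBasicBlock, 2, 0.0, 9) applies
# a 16->64 block with stride 2 followed by two 64->64 blocks with stride 1.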


class FixupWideResNet(nn.Module):
    """
    Fixup Wide ResNet.

    The main class of the Fixup architecture. The final (classification) layer is initialized to 0.
    """
    depth: int
    widen_factor: int
    num_classes: int = 10
    dropRate: float = 0.0

    def setup(self):
        nChannels = [16, 16 * self.widen_factor, 32 * self.widen_factor, 64 * self.widen_factor]
        assert (self.depth - 4) % 6 == 0
        block = FixupBasicBlock
        m = (self.depth - 4) // 6  # The number of basic blocks per network block.
        total_num_basic_blocks = m * 3  # There are 3 network blocks, so 3*m basic blocks in total.
        self.conv1 = nn.Conv(features=nChannels[0], strides=(1, 1), use_bias=False, kernel_size=(3, 3), padding=1)
        self.bias1 = FixupBias()
        self.block1 = FixupNetworkBlock(m, nChannels[0], nChannels[1], block, 1, self.dropRate, total_num_basic_blocks)
        self.block2 = FixupNetworkBlock(m, nChannels[1], nChannels[2], block, 2, self.dropRate, total_num_basic_blocks)
        self.block3 = FixupNetworkBlock(m, nChannels[2], nChannels[3], block, 2, self.dropRate, total_num_basic_blocks)
        self.bias2 = FixupBias()
        self.relu1 = nn.relu
        self.fc = nn.Dense(features=self.num_classes, kernel_init=zeros, bias_init=zeros)

    def __call__(self, x, train):
        out = self.bias1(x)
        out = self.conv1(out)
        out = self.block1(out, train)
        out = self.block2(out, train)
        out = self.block3(out, train)
        out = self.relu1(out)
        out = nn.avg_pool(out, (8, 8))
        out = out.reshape((out.shape[0], -1))
        out = self.bias2(out)
        out = self.fc(out)
        return out


FixupWideResNet16 = partial(FixupWideResNet, depth=16, widen_factor=4, num_classes=10, dropRate=0.3)
FixupWideResNet22 = partial(FixupWideResNet, depth=22, widen_factor=4, num_classes=10, dropRate=0.3)
FixupWideResNet28 = partial(FixupWideResNet, depth=28, widen_factor=4, num_classes=10, dropRate=0.3)

# Used for testing.
_FixupWideResNet10 = partial(FixupWideResNet, depth=10, widen_factor=1, num_classes=10, dropRate=0.3)
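
# Example usage: a minimal smoke test. It assumes CIFAR-10-style 32x32 RGB
# inputs in NHWC layout (the layout flax.linen.Conv expects); with train=False
# the Dropout layers are deterministic, so only a 'params' RNG is needed.
if __name__ == "__main__":
    import jax
    import jax.numpy as jnp

    model = _FixupWideResNet10()
    x = jnp.ones((2, 32, 32, 3))
    variables = model.init(jax.random.PRNGKey(0), x, train=False)
    logits = model.apply(variables, x, train=False)
    print(logits.shape)  # (2, 10)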