-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy path: linear_type.py
122 lines (97 loc) · 4.78 KB
/
linear_type.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
# Alias so callers can request a plain dense layer through the same module API.
DenseLinear = nn.Linear
class SplitLinear(nn.Linear):
    """nn.Linear that carries a fixed binary mask over its weight matrix,
    marking the "active" sub-network for split/slim training.

    Extra keyword arguments (popped before ``nn.Linear.__init__`` runs):
        split_mode: ``'kels'`` builds a structured slice mask;
            ``'wels'`` builds a random element-wise mask.
        split_rate: fraction of the layer kept active (0..1].
        last_layer: if truthy (kels only), mask input dims but keep
            every output dim active — presumably the classifier head;
            TODO confirm against caller.
        in_channels_order: comma-separated channel counts of concatenated
            upstream feature maps; when given (kels only), the first
            ``ceil(split_rate * C)`` channels of EACH segment are active.

    The forward pass is a plain dense linear — the mask is consulted only by
    :meth:`extract_slim` and :meth:`split_reinitialize`.
    """

    def __init__(self, *args, **kwargs):
        # Pop custom kwargs before delegating the rest to nn.Linear.
        self.split_mode = kwargs.pop('split_mode', None)
        split_rate = kwargs.pop('split_rate', None)
        last_layer = kwargs.pop('last_layer', None)
        self.in_channels_order = kwargs.pop('in_channels_order', None)
        self.split_rate = split_rate
        self.bias_split_rate = self.split_rate
        super().__init__(*args, **kwargs)
        ## AT : I am assuming a single FC layer in the network. Typical for most CNNs
        if self.split_mode == 'kels':
            out_dim, in_dim = self.weight.size()
            mask = np.zeros((out_dim, in_dim))
            if self.in_channels_order is None:
                active_in_dim = math.ceil(in_dim * split_rate)
                if last_layer:
                    # Last layer: keep all output units, slice only the inputs.
                    mask[:, :active_in_dim] = 1
                else:
                    active_out_dim = math.ceil(out_dim * split_rate)
                    mask[:active_out_dim, :active_in_dim] = 1
            else:
                # Inputs are a concat of upstream blocks; activate the leading
                # ceil(split_rate * C) channels of each segment independently.
                start_ch = 0
                for conv in self.in_channels_order.split(','):
                    mask[:, start_ch:start_ch + math.ceil(int(conv) * split_rate)] = 1
                    start_ch += int(conv)
        elif self.split_mode == 'wels':
            # Unstructured random mask with expected density ~= split_rate.
            mask = np.random.rand(*list(self.weight.shape))
            threshold = 1 - self.split_rate
            mask[mask < threshold] = 0
            mask[mask >= threshold] = 1
            if self.split_rate != 1:
                assert len(np.unique(mask)) == 2, 'Something is wrong with the mask {}'.format(np.unique(mask))
        else:
            # Fix: `NotImplemented` is a non-callable constant; raising it
            # produced a TypeError rather than the intended error.
            raise NotImplementedError('Invalid split_mode {}'.format(self.split_mode))

        # Non-trainable parameter so the mask follows .to()/.cuda() moves and
        # is persisted in the state_dict alongside the weights.
        self.mask = nn.Parameter(torch.Tensor(mask), requires_grad=False)

    def extract_slim(self, dst_m, src_name, dst_name):
        """Copy the active sub-weights/bias of this layer into the smaller
        destination layer ``dst_m`` (src_name/dst_name are unused here)."""
        d_out, d_in = dst_m.weight.size()
        if self.in_channels_order is None:
            # Active region is the top-left corner of the weight matrix.
            assert dst_m.weight.shape == self.weight[:d_out, :d_in].shape
            dst_m.weight.data = self.weight.data[:d_out, :d_in]
            assert dst_m.bias.data.shape == self.bias.data[:d_out].shape
            dst_m.bias.data = self.bias.data[:d_out]
        else:
            # All rows share one input-channel pattern in this mode, so the
            # first mask row selects the active columns.
            dst_m.weight.data = self.weight[:d_out, self.mask[0, :] == 1]
            dst_m.bias.data = self.bias.data[:d_out]

    def split_reinitialize(self, cfg):
        """Re-initialize the weights OUTSIDE the mask (inactive entries),
        leaving the active sub-network untouched.

        cfg.evolve_mode selects the strategy; only 'rand' is implemented.
        """
        if cfg.evolve_mode == 'rand':
            # Fix: build the random tensor on the weight's own device instead
            # of hard-coding .cuda(), so CPU-only runs no longer crash.
            rand_tensor = torch.zeros_like(self.weight)
            nn.init.kaiming_uniform_(rand_tensor, a=math.sqrt(5))
            self.weight.data = torch.where(self.mask.type(torch.bool), self.weight.data, rand_tensor)
        else:
            # Fix: NotImplemented -> NotImplementedError (see __init__).
            raise NotImplementedError('Invalid KE mode {}'.format(cfg.evolve_mode))

    def forward(self, x):
        # Plain dense forward; the mask only matters for slimming/reinit.
        return F.linear(x, self.weight, self.bias)