Networks.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as backbone_

# Fix random seeds and make cuDNN deterministic for reproducible runs.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class InceptionV3_Network(nn.Module):
    """Inception-v3 feature extractor with a frozen stem and trainable Mixed_* blocks."""

    def __init__(self):
        super(InceptionV3_Network, self).__init__()
        backbone = backbone_.inception_v3(pretrained=True)
        # Stem convolutions, copied from the pretrained backbone.
        self.Conv2d_1a_3x3 = backbone.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = backbone.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = backbone.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = backbone.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = backbone.Conv2d_4a_3x3
        # Freeze everything registered so far (the stem); the Mixed_* blocks
        # added below keep requires_grad=True and remain trainable.
        for param in self.parameters():
            param.requires_grad = False
        self.Mixed_5b = backbone.Mixed_5b
        self.Mixed_5c = backbone.Mixed_5c
        self.Mixed_5d = backbone.Mixed_5d
        self.Mixed_6a = backbone.Mixed_6a
        self.Mixed_6b = backbone.Mixed_6b
        self.Mixed_6c = backbone.Mixed_6c
        self.Mixed_6d = backbone.Mixed_6d
        self.Mixed_6e = backbone.Mixed_6e
        self.Mixed_7a = backbone.Mixed_7a
        self.Mixed_7b = backbone.Mixed_7b
        self.Mixed_7c = backbone.Mixed_7c

    def forward(self, x, size_num):
        """Return L2-normalized features: the Mixed_6e map if size_num == 0,
        the Mixed_7c map if size_num == 1."""
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        if size_num == 0:
            # N x 768 x 17 x 17
            return F.normalize(x)
        elif size_num == 1:
            # N x 768 x 17 x 17
            x = self.Mixed_7a(x)
            # N x 1280 x 8 x 8
            x = self.Mixed_7b(x)
            # N x 2048 x 8 x 8
            x = self.Mixed_7c(x)
            return F.normalize(x)
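

# Minimal usage sketch (not part of the original file): shows the two feature
# maps the backbone can return. The helper name `_demo_backbone` and the dummy
# input are illustrative assumptions, not project API.
def _demo_backbone():
    model = InceptionV3_Network().eval()
    dummy = torch.randn(1, 3, 299, 299)
    with torch.no_grad():
        mid = model(dummy, size_num=0)    # expected shape: 1 x 768 x 17 x 17
        deep = model(dummy, size_num=1)   # expected shape: 1 x 2048 x 8 x 8
    return mid.shape, deep.shape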


class Attention(nn.Module):
    """Spatial attention pooling over a 2048-channel feature map."""

    def __init__(self):
        super(Attention, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=1),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU(),
                                 nn.Conv2d(512, 1, kernel_size=1))
        self.pool_method = nn.AdaptiveMaxPool2d(1)  # as default

    def forward(self, x):
        # Predict one attention logit per spatial location, then softmax over all locations.
        attn_mask = self.net(x)
        attn_mask = attn_mask.view(attn_mask.size(0), -1)
        attn_mask = F.softmax(attn_mask, dim=1)
        attn_mask = attn_mask.view(attn_mask.size(0), 1, x.size(2), x.size(3))
        # Residual re-weighting: keep the original features and add the attended ones.
        x = x + (x * attn_mask)
        x = self.pool_method(x).view(-1, 2048)
        return F.normalize(x)


class Linear(nn.Module):
    def __init__(self, feature_num):
        super(Linear, self).__init__()
        self.head_layer = nn.Linear(2048, feature_num)

    def forward(self, x):
        return F.normalize(self.head_layer(x))


class residual_block(nn.Module):
    def __init__(self, strides=1, same_shape=True, bottle=True):
        super(residual_block, self).__init__()
        self.same_shape = same_shape
        self.bottle = bottle  # kept for compatibility; not used below
        if not same_shape:
            strides = 2
        self.strides = strides
        # Bottleneck path: 768 -> 512 -> 512 -> 2048 channels.
        self.block = nn.Sequential(
            nn.Conv2d(768, 512, kernel_size=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=strides, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 2048, kernel_size=1, bias=False),
            nn.BatchNorm2d(2048)
        )
        # 1x1 projection shortcut (no stride), so spatial sizes only match when strides == 1.
        self.shortcut = nn.Sequential(
            nn.Conv2d(768, 2048, kernel_size=1, bias=False),
            nn.BatchNorm2d(2048)
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.block(x)
        identity = self.shortcut(x)
        out = self.relu(out + identity)
        return F.normalize(out)
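

# Minimal end-to-end sketch (an assumption, not part of the original training code):
# backbone features -> attention pooling -> linear embedding head. The embedding
# size of 64 and the random input are arbitrary illustrative choices.
if __name__ == "__main__":
    backbone = InceptionV3_Network().eval()
    attention = Attention().eval()
    head = Linear(feature_num=64).eval()
    images = torch.randn(2, 3, 299, 299)
    with torch.no_grad():
        feat = backbone(images, size_num=1)   # 2 x 2048 x 8 x 8
        pooled = attention(feat)              # 2 x 2048, L2-normalized
        embedding = head(pooled)              # 2 x 64, L2-normalized
    print(embedding.shape)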