-
Notifications
You must be signed in to change notification settings - Fork 392
/
Copy pathpsp_net.py
142 lines (124 loc) · 5.12 KB
/
psp_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from ..utils import initialize_weights
from ..utils.misc import Conv2dDeformable
class _PyramidPoolingModule(nn.Module):
    """Pyramid pooling block: pool at several grid sizes, then concatenate.

    Each entry of ``setting`` creates one branch:
    adaptive average pool to that output size -> 1x1 conv (no bias) -> BatchNorm -> ReLU.
    In ``forward`` every branch output is resized bilinearly back to the input's
    spatial size. The input and all branch outputs are concatenated along dim 1.

    Args:
        in_dim: Number of channels of the incoming feature map.
        reduction_dim: Number of channels each branch produces.
        setting: Iterable of output sizes handed to ``nn.AdaptiveAvgPool2d``,
            one branch per entry, in order.

    The result has ``in_dim + len(setting) * reduction_dim`` channels.

    NOTE(review): PyTorch's BatchNorm ``momentum`` weights the *new* batch
    statistic, so ``momentum=.95`` makes the running stats track almost only
    the latest batch. Confirm this is intended; Caffe's 0.95
    moving-average fraction corresponds to ``momentum=0.05``.
    """

    def __init__(self, in_dim, reduction_dim, setting):
        super().__init__()
        self.features = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(bin_size),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim, momentum=.95),
                nn.ReLU(inplace=True),
            )
            for bin_size in setting
        ])

    def forward(self, x):
        # BatchNorm2d requires a 4-D (N, C, H, W) input; dims 2: are the spatial size.
        spatial = x.size()[2:]
        pooled = [
            F.interpolate(branch(x), spatial, mode='bilinear')
            for branch in self.features
        ]
        return torch.cat([x, *pooled], 1)
class PSPNet(nn.Module):
    """PSPNet semantic-segmentation network on a dilated ResNet-101 backbone.

    Args:
        num_classes: Number of channels produced by the final 1x1 classifier.
        pretrained: Forwarded unchanged to ``torchvision.models.resnet101``.
        use_aux: When True, an auxiliary 1x1 classifier is attached to the
            1024-channel ``layer3`` output. Its prediction is returned
            alongside the main one while the module is in training mode.

    Backbone modifications:

    - In ``layer3`` and ``layer4`` every ``conv2`` gets stride 1 and
      dilation/padding 2 (layer3) or 4 (layer4).
    - The ``downsample.0`` convs of those stages get stride 1.

    Presumably this makes the last two stages keep ``layer2``'s resolution.

    ``forward(image)`` returns logits resized bilinearly to the input's spatial
    size. In training mode with ``use_aux`` it returns ``(logits, aux_logits)``,
    both resized the same way.

    NOTE(review): BatchNorm ``momentum=.95`` weights the newest batch heavily
    in PyTorch's convention. Confirm this is intended.
    """

    def __init__(self, num_classes, pretrained=True, use_aux=True):
        super().__init__()
        self.use_aux = use_aux
        backbone = models.resnet101(pretrained=pretrained)
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Replace striding with dilation in the last two stages (layer3 first, then layer4).
        for stage, rate in ((self.layer3, 2), (self.layer4, 4)):
            for name, module in stage.named_modules():
                if 'conv2' in name:
                    module.dilation = (rate, rate)
                    module.padding = (rate, rate)
                    module.stride = (1, 1)
                elif 'downsample.0' in name:
                    module.stride = (1, 1)

        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        # 4096 input channels = 2048 backbone channels + 4 pyramid branches * 512.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1),
        )
        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)
        initialize_weights(self.ppm, self.final)

    def forward(self, image):
        out_size = image.size()[2:]
        feats = image
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3):
            feats = stage(feats)

        with_aux = self.training and self.use_aux
        if with_aux:
            # Auxiliary head reads the layer3 features, before layer4.
            aux = self.aux_logits(feats)

        feats = self.final(self.ppm(self.layer4(feats)))
        main = F.interpolate(feats, out_size, mode='bilinear')
        if not with_aux:
            return main
        return main, F.interpolate(aux, out_size, mode='bilinear')
# Experimental variant; not recommended for use.
class PSPNetDeform(nn.Module):
    """Experimental PSPNet variant with deformable ``conv2`` layers.

    The ResNet-101 backbone is prepared as follows:

    - In ``layer3`` and ``layer4`` every ``conv2`` gets stride 1 and padding 1
      (no dilation, unlike ``PSPNet``).
    - The ``downsample.0`` convs get stride 1.
    - Every residual block's ``conv2`` in those stages is wrapped in
      ``Conv2dDeformable``.

    Args:
        num_classes: Number of channels produced by the final 1x1 classifier.
        input_size: Size handed to ``F.interpolate`` for every output. All
            predictions are resized to this fixed size, not to the actual
            input's size.
        pretrained: Load pretrained ResNet-101 weights. By default they come
            from torchvision, as in ``PSPNet``.
        use_aux: When True, a 1x1 auxiliary classifier is applied to the
            ``layer3`` output. Its prediction is returned while training.
        weights_path: Optional path to a local ResNet-101 state dict. When
            given together with ``pretrained=True``, it is loaded instead of
            torchvision's weights.

    ``forward`` returns ``logits``, or ``(logits, aux_logits)`` in training
    mode with ``use_aux``.
    """

    def __init__(self, num_classes, input_size, pretrained=True, use_aux=True, weights_path=None):
        super().__init__()
        self.input_size = input_size
        self.use_aux = use_aux

        # The original loaded ``torch.load(res101_path)``, but ``res101_path`` is
        # never defined in this module. The default ``pretrained=True`` therefore
        # crashed with NameError.
        if pretrained and weights_path is None:
            resnet = models.resnet101(pretrained=True)
        else:
            resnet = models.resnet101()
            if pretrained:
                # Developer-supplied checkpoint path. torch.load unpickles, so
                # never point this at untrusted files.
                resnet.load_state_dict(torch.load(weights_path, map_location='cpu'))

        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Remove the stride from the last two stages (layer3 first, then layer4).
        for stage in (self.layer3, self.layer4):
            for name, module in stage.named_modules():
                if 'conv2' in name:
                    module.padding = (1, 1)
                    module.stride = (1, 1)
                elif 'downsample.0' in name:
                    module.stride = (1, 1)

        # Wrap each residual block's conv2 in the deformable wrapper, after the
        # stride/padding patch above so the wrapper sees the patched conv.
        for stage in (self.layer3, self.layer4):
            for block in stage:
                block.conv2 = Conv2dDeformable(block.conv2)

        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        # 4096 input channels = 2048 backbone channels + 4 pyramid branches * 512.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1),
        )
        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)
        initialize_weights(self.ppm, self.final)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.training and self.use_aux:
            aux = self.aux_logits(x)
        x = self.layer4(x)
        x = self.ppm(x)
        x = self.final(x)
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # exactly what it forwarded to.
        out = F.interpolate(x, self.input_size, mode='bilinear', align_corners=False)
        if self.training and self.use_aux:
            return out, F.interpolate(aux, self.input_size, mode='bilinear', align_corners=False)
        return out