-
Notifications
You must be signed in to change notification settings - Fork 392
/
Copy pathduc_hdc.py
97 lines (85 loc) · 3.48 KB
/
duc_hdc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from torch import nn
from torchvision import models
class _DenseUpsamplingConvModule(nn.Module):
def __init__(self, down_factor, in_dim, num_classes):
super(_DenseUpsamplingConvModule, self).__init__()
upsample_dim = (down_factor ** 2) * num_classes
self.conv = nn.Conv2d(in_dim, upsample_dim, kernel_size=3, padding=1)
self.bn = nn.BatchNorm2d(upsample_dim)
self.relu = nn.ReLU(inplace=True)
self.pixel_shuffle = nn.PixelShuffle(down_factor)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.pixel_shuffle(x)
return x
class ResNetDUC(nn.Module):
    """ResNet-152 backbone with dilated layer3/layer4 and a DUC head.

    The stride of every ``conv2`` and ``downsample.0`` module in ``layer3``
    and ``layer4`` is set to 1, and each ``conv2`` there is dilated (rate 2
    in ``layer3``, rate 4 in ``layer4``) with matching padding. The
    ``_DenseUpsamplingConvModule`` head then upsamples the 2048-channel
    features by a factor of 8.

    Input height and width should be multiples of 8.

    Args:
        num_classes: number of output channels produced by the DUC head.
        pretrained: forwarded unchanged to ``torchvision.models.resnet152``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet152(pretrained=pretrained)
        self.layer0 = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool
        )
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Trade the downsampling of the last two stages for dilation. For a
        # 3x3 kernel at stride 1, padding == dilation preserves H and W
        # (assumes conv2 is the 3x3 conv of each bottleneck -- confirm
        # against the torchvision version in use).
        for stage, rate in ((self.layer3, 2), (self.layer4, 4)):
            for name, module in stage.named_modules():
                if 'conv2' in name:
                    module.dilation = (rate, rate)
                    module.padding = (rate, rate)
                    module.stride = (1, 1)
                elif 'downsample.0' in name:
                    # Shortcut projection must match the now-unstrided branch.
                    module.stride = (1, 1)

        self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes)

    def forward(self, x):
        """Run the backbone stages in order, then the DUC head.

        Returns a ``num_classes``-channel score map; it is expected to have the
        input's spatial size when H and W are multiples of 8.
        """
        for stage in (self.layer0, self.layer1, self.layer2,
                      self.layer3, self.layer4, self.duc):
            x = stage(x)
        return x
class ResNetDUCHDC(nn.Module):
    """ResNet-152 backbone with per-block dilation rates and a DUC head.

    The stride of every ``conv2`` and ``downsample.0`` module in ``layer3``
    and ``layer4`` is set to 1. Each residual block's ``conv2`` then gets its
    own dilation (with equal padding):

    * ``layer3`` cycles through rates ``[1, 2, 5, 9]``;
    * ``layer4`` uses rates ``[5, 9, 17]`` (cycled if there are more blocks).

    (Presumably the "HDC" -- hybrid dilated convolution -- scheme; confirm
    against the reference paper before changing the rates.) The
    ``_DenseUpsamplingConvModule`` head upsamples the 2048-channel features
    by a factor of 8.

    Input height and width should be multiples of 8.

    Args:
        num_classes: number of output channels produced by the DUC head.
        pretrained: whether to start from ImageNet-pretrained weights.
        weights_path: optional path to a local ResNet-152 ``state_dict``.
            Used only when ``pretrained`` is true. When omitted, pretrained
            weights come from ``torchvision.models.resnet152(pretrained=True)``.
    """

    def __init__(self, num_classes, pretrained=True, weights_path=None):
        super(ResNetDUCHDC, self).__init__()
        if pretrained and weights_path is not None:
            resnet = models.resnet152()
            # map_location='cpu' lets a GPU-saved checkpoint load on any host;
            # load_state_dict copies the values into the module's parameters.
            resnet.load_state_dict(torch.load(weights_path, map_location='cpu'))
        else:
            # No local checkpoint: let torchvision provide the weights, exactly
            # as ResNetDUC does. (A module-level `res152_path` was previously
            # referenced here but is not defined in this file.)
            resnet = models.resnet152(pretrained=pretrained)

        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Remove downsampling from the last two stages, on both the main
        # branch (conv2) and the shortcut projection (downsample.0).
        for layer in (self.layer3, self.layer4):
            for name, module in layer.named_modules():
                if 'conv2' in name or 'downsample.0' in name:
                    module.stride = (1, 1)

        self._set_block_dilations(self.layer3, [1, 2, 5, 9])
        self._set_block_dilations(self.layer4, [5, 9, 17])

        self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes)

    @staticmethod
    def _set_block_dilations(layer, rates):
        """Set ``conv2`` dilation and padding of block ``i`` to ``rates[i % len(rates)]``."""
        for idx, block in enumerate(layer):
            rate = rates[idx % len(rates)]
            # padding == dilation keeps H and W for a 3x3 conv at stride 1
            # (assumes conv2 is the bottleneck's 3x3 conv -- confirm).
            block.conv2.dilation = (rate, rate)
            block.conv2.padding = (rate, rate)

    def forward(self, x):
        """Run the backbone stages in order, then the DUC head.

        Returns a ``num_classes``-channel score map; it is expected to have the
        input's spatial size when H and W are multiples of 8.
        """
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.duc(x)
        return x