DDBPN.py
from tensorflow.keras.layers import Concatenate
from Models import *

# The D-DBPN model with densely connected back-projection (BP) layers. This is a
# lightweight version that comprises 7 BP stages instead of the 10-stage version;
# its performance does not differ much from the latter.
class DDBPN:
    def __init__(self, scale_factor=2, bias=True, bias_init='zeros'):
        # Kernel size, stride and padding of the projection convolutions depend
        # on the target scale factor.
        if scale_factor == 2:
            # default scaling factor of 2
            kernel_size, stride, padding = 6, 2, 2
        elif scale_factor == 4:
            kernel_size, stride, padding = 8, 4, 2
        elif scale_factor == 8:
            kernel_size, stride, padding = 12, 8, 2
        else:
            raise ValueError('scale_factor must be 2, 4 or 8')
        # Feature extraction stage
        self.f0 = CONV(256, 3, 1, 1, bias, bias_init, True)
        self.f1 = CONV(64, 1, 1, 0, bias, bias_init, True)
        # Back-projection stage: D-DBPN has a total of 7 BP stages; the last stage only has an up-projection
        self.up1 = UpProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.down1 = DownProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.up2 = UpProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.down2 = DenseDownProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.up3 = DenseUpProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.down3 = DenseDownProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.up4 = DenseUpProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.down4 = DenseDownProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.up5 = DenseUpProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.down5 = DenseDownProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.up6 = DenseUpProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.down6 = DenseDownProjection(64, kernel_size, stride, padding, bias, bias_init)
        self.up7 = DenseUpProjection(64, kernel_size, stride, padding, bias, bias_init)
        # Reconstruction
        self.reconstruction = CONV(3, 3, 1, 1, bias, bias_init, False)

    def __call__(self, x):
        # Feature Extraction
        x = self.f0(x)
        x = self.f1(x)
        # BP stage
        h1 = self.up1(x)
        l1 = self.down1(h1)
        h2 = self.up2(l1)
        h_cat = Concatenate(axis=-1)([h2, h1])
        l2 = self.down2(h_cat)
        l_cat = Concatenate(axis=-1)([l2, l1])
        h = self.up3(l_cat)
        h_cat = Concatenate(axis=-1)([h, h_cat])
        l = self.down3(h_cat)
        l_cat = Concatenate(axis=-1)([l, l_cat])
        h = self.up4(l_cat)
        h_cat = Concatenate(axis=-1)([h, h_cat])
        l = self.down4(h_cat)
        l_cat = Concatenate(axis=-1)([l, l_cat])
        h = self.up5(l_cat)
        h_cat = Concatenate(axis=-1)([h, h_cat])
        l = self.down5(h_cat)
        l_cat = Concatenate(axis=-1)([l, l_cat])
        h = self.up6(l_cat)
        h_cat = Concatenate(axis=-1)([h, h_cat])
        l = self.down6(h_cat)
        l_cat = Concatenate(axis=-1)([l, l_cat])
        h = self.up7(l_cat)
        h_cat = Concatenate(axis=-1)([h, h_cat])
        # Reconstruction stage
        x = self.reconstruction(h_cat)
        return x
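

# Illustrative usage sketch: wraps DDBPN in a tf.keras Model. This assumes the
# factories imported from Models (CONV, UpProjection, DownProjection,
# DenseUpProjection, DenseDownProjection) behave like standard Keras layers, so
# the functional API can trace the graph; the input shape, optimizer and loss
# below are placeholders, not settings taken from this repository.
if __name__ == '__main__':
    from tensorflow.keras import Input, Model

    lr = Input(shape=(None, None, 3))             # low-resolution RGB input
    sr = DDBPN(scale_factor=2)(lr)                # 2x super-resolved output
    model = Model(lr, sr, name='ddbpn_x2')
    model.compile(optimizer='adam', loss='mae')   # L1 loss is common for DBPN-style SR
    model.summary()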