generator.py
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.layers import Conv2D, Conv2DTranspose


class Conv2dLSTMCell(tf.keras.Model):
    """Convolutional LSTM cell whose gates are 2-D convolutions.

    The caller is expected to concatenate the previous hidden state into
    `inputs`, so each gate convolves the input tensor only.
    """

    def __init__(self, latent_dim, kernel_size=5):
        super().__init__()
        args = [latent_dim, kernel_size]
        kwargs = {'padding': 'same'}
        self.forget = Conv2D(*args, **kwargs, activation=tf.sigmoid)
        self.inp = Conv2D(*args, **kwargs, activation=tf.sigmoid)
        self.outp = Conv2D(*args, **kwargs, activation=tf.sigmoid)
        self.state = Conv2D(*args, **kwargs, activation=tf.tanh)

    def call(self, inputs, cell):
        forget_gate = self.forget(inputs)
        input_gate = self.inp(inputs)
        output_gate = self.outp(inputs)
        state_gate = self.state(inputs)
        # Standard LSTM update, with convolutions in place of dense layers.
        cell = forget_gate * cell + input_gate * state_gate
        hidden = output_gate * tf.tanh(cell)
        return hidden, cell


class LatentDistribution(tf.keras.Model):
    """Parametrizes a diagonal Normal over z from a feature map."""

    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.parametrize = Conv2D(z_dim * 2, 5, padding='same')

    def call(self, inputs):
        parametrization = self.parametrize(inputs)
        mu, sigma = tf.split(parametrization, [self.z_dim, self.z_dim], -1)
        # Softplus keeps the scale strictly positive.
        return tfp.distributions.Normal(loc=mu, scale=tf.nn.softplus(sigma))


class Generator(tf.keras.Model):
    """Recurrent variational generator: an inference core conditions on the
    target image, a generation core accumulates the output canvas over L steps.
    """

    def __init__(self, x_dim, z_dim, h_dim, L):
        super().__init__()
        self.L = L
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.inference_core = Conv2dLSTMCell(h_dim)
        self.generator_core = Conv2dLSTMCell(h_dim)
        self.posterior_distribution = LatentDistribution(z_dim)
        self.prior_distribution = LatentDistribution(z_dim)
        self.observation_density = Conv2D(
            x_dim, 1, padding='same', activation=tf.sigmoid)
        # Stride-4 resampling between image resolution and the r grid,
        # i.e. im_size is assumed to be 4 * r_size.
        self.upsample = Conv2DTranspose(h_dim, 4, strides=4)
        self.downsample = Conv2D(h_dim, 4, strides=4)
    def call(self, x, v, r):
        batch_size, v_dim = v.shape
        _, im_size, _, x_dim = x.shape
        _, r_size, _, r_dim = r.shape
        # Broadcast the viewpoint vector v across the spatial grid of r.
        v = tf.tile(v, [1, r_size * r_size])
        v = tf.reshape(v, [-1, r_size, r_size, v_dim])
        kl = 0.0
        # Recurrent states for both cores, plus the canvas u that
        # accumulates the generator's output at image resolution.
        c_g = tf.zeros([batch_size, r_size, r_size, self.h_dim])
        h_g = tf.zeros([batch_size, r_size, r_size, self.h_dim])
        u = tf.zeros([batch_size, im_size, im_size, self.h_dim])
        c_i = tf.zeros([batch_size, r_size, r_size, self.h_dim])
        h_i = tf.zeros([batch_size, r_size, r_size, self.h_dim])
        # Bring the target image down to the r_size x r_size grid.
        x = self.downsample(x)
        for _ in range(self.L):
            prior_factor = self.prior_distribution(h_g)
            # Inference step: the inference core sees the target image.
            inputs = tf.concat([h_i, h_g, x, v, r], 3)
            h_i, c_i = self.inference_core(inputs, c_i)
            posterior_factor = self.posterior_distribution(h_i)
            z = posterior_factor.sample()
            # Generation step: the generator core consumes the posterior sample.
            inputs = tf.concat([h_g, z, v, r], 3)
            h_g, c_g = self.generator_core(inputs, c_g)
            u = self.upsample(h_g) + u
            kl += tf.reduce_mean(
                tfp.distributions.kl_divergence(posterior_factor, prior_factor)
            )
        x_mu = self.observation_density(u)
        return x_mu, kl
    def sample(self, v, r, im_size):
        batch_size, v_dim = v.shape
        _, r_size, _, r_dim = r.shape
        v = tf.tile(v, [1, r_size * r_size])
        v = tf.reshape(v, [-1, r_size, r_size, v_dim])
        c_g = tf.zeros([batch_size, r_size, r_size, self.h_dim])
        h_g = tf.zeros([batch_size, r_size, r_size, self.h_dim])
        u = tf.zeros([batch_size, im_size, im_size, self.h_dim])
        for _ in range(self.L):
            # With no target image, z is drawn from the prior instead.
            prior_factor = self.prior_distribution(h_g)
            z = prior_factor.sample()
            inputs = tf.concat([h_g, z, v, r], 3)
            h_g, c_g = self.generator_core(inputs, c_g)
            u = self.upsample(h_g) + u
        x_mu = self.observation_density(u)
        return x_mu
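

# Minimal smoke-test sketch, assuming the stride-4 resampling above
# (so im_size == 4 * r_size, here 32 and 8). All dimension values below
# are illustrative placeholders, not values taken from this repo.
if __name__ == '__main__':
    batch_size, im_size, r_size = 2, 32, 8
    x_dim, v_dim, r_dim = 3, 7, 64

    generator = Generator(x_dim=x_dim, z_dim=16, h_dim=64, L=4)
    x = tf.random.uniform([batch_size, im_size, im_size, x_dim])  # target image
    v = tf.random.uniform([batch_size, v_dim])                    # viewpoint
    r = tf.random.uniform([batch_size, r_size, r_size, r_dim])    # representation

    # Training-style pass: reconstruct x and accumulate the KL term.
    x_mu, kl = generator(x, v, r)
    print('x_mu:', x_mu.shape, 'kl:', float(kl))

    # Generation-style pass: draw an image from the prior alone.
    x_sample = generator.sample(v, r, im_size)
    print('x_sample:', x_sample.shape)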