forked from cyrilli/Generative-Model-for-Video-Compression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_compression.py
84 lines (71 loc) · 4.32 KB
/
model_compression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import tensorflow as tf
import sys
sys.path.append("./dcgan/")
import tensorlayer as tl
from model import *
from tensorlayer.layers import *
def vanilla_encoder(inputs, z_dim=100, is_train=True, reuse=False):
    """
    Build a convolutional encoder that maps `inputs` to a `z_dim`-unit code.

    The network is created inside the ``"encoder"`` variable scope and is a
    stack of four 5x5, stride-2, SAME-padded convolutions with 64, 128, 256
    and 512 filters. The first convolution applies ``tl.act.lrelu(x, 0.2)``
    directly; the other three are linear and each is followed by a batch-norm
    layer that applies the same activation. The result is flattened and fed
    to a dense layer with `z_dim` units and identity activation.

    :param inputs: tensor wrapped by the input layer (presumably a batch of
        images, NHWC -- TODO confirm against the caller).
    :param z_dim: number of units of the final dense layer.
    :param is_train: forwarded to every batch-norm layer.
    :param reuse: forwarded to ``tf.variable_scope`` and
        ``tl.layers.set_name_reuse``.
    :return: ``(net, logits)`` where ``logits`` is the raw dense-layer output
        and ``net`` is the final layer whose ``outputs`` attribute has been
        replaced by ``sigmoid(logits)``.
    """
    base_filters = 64  # Number of filters in the first conv layer. [64]
    kernel, stride = (5, 5), (2, 2)
    weight_init = tf.random_normal_initializer(stddev=0.02)
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)
    # One shared activation for all hidden layers: lrelu with parameter 0.2.
    leaky_relu = lambda x: tl.act.lrelu(x, 0.2)

    # (filters, conv layer name, batch-norm layer name) for stages h1..h3.
    bn_stages = (
        (base_filters * 2, 'enc/h1/conv2d', 'enc/h1/batch_norm'),
        (base_filters * 4, 'enc/h2/conv2d', 'enc/h2/batch_norm'),
        (base_filters * 8, 'enc/h3/conv2d', 'enc/h3/batch_norm'),
    )

    with tf.variable_scope("encoder", reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        # NOTE(review): the input layer is named 'd/in' while every other
        # layer here uses the 'enc/' prefix -- looks like a leftover; confirm
        # it cannot clash with another 'd/in' layer before renaming.
        net = InputLayer(inputs, name='d/in')

        # h0: the only convolution without batch norm.
        net = Conv2d(net, base_filters, kernel, stride, act=leaky_relu,
                     padding='SAME', W_init=weight_init, name='enc/h0/conv2d')

        # h1..h3: linear convolution followed by batch norm + activation.
        for n_filters, conv_name, bn_name in bn_stages:
            net = Conv2d(net, n_filters, kernel, stride, act=None,
                         padding='SAME', W_init=weight_init, name=conv_name)
            net = BatchNormLayer(net, act=leaky_relu, is_train=is_train,
                                 gamma_init=bn_gamma_init, name=bn_name)

        # h4: flatten and project to the code dimension.
        net = FlattenLayer(net, name='enc/h4/flatten')
        net = DenseLayer(net, n_units=z_dim, act=tf.identity,
                         W_init=weight_init, name='enc/h4/lin_sigmoid')

        # Keep the pre-activation tensor, then expose the sigmoid on the layer.
        logits = net.outputs
        net.outputs = tf.nn.sigmoid(logits)
    return net, logits
def dcgan_decoder(inputs, image_size = 64, c_dim=3, batch_size=64, is_train=False, reuse=False):
    """
    Build a DCGAN-style transposed-convolution decoder.

    The network is created inside the ``"generator"`` variable scope with
    ``g/...`` layer names (presumably so its variables line up with a DCGAN
    generator defined elsewhere -- TODO confirm). `inputs` goes through a
    dense layer with ``512 * s16 * s16`` units, is reshaped to
    ``[-1, s16, s16, 512]`` and batch-normalised with ReLU. Three 5x5,
    stride-2 transposed convolutions with 256, 128 and 64 filters follow,
    each with batch norm + ReLU, producing spatial sizes ``s8``, ``s4`` and
    ``s2``. A last transposed convolution produces `c_dim` channels at
    ``image_size x image_size``. Here ``sK = int(image_size / K)``.

    :param inputs: tensor wrapped by the input layer (presumably a latent
        code of shape (batch, z_dim) -- TODO confirm against the caller).
    :param image_size: side length of the output; also determines the
        intermediate spatial sizes.
    :param c_dim: number of channels of the last transposed convolution.
    :param batch_size: forwarded to every ``DeConv2d`` layer.
    :param is_train: forwarded to every batch-norm layer.
    :param reuse: forwarded to ``tf.variable_scope`` and
        ``tl.layers.set_name_reuse``.
    :return: ``(net, logits)`` where ``logits`` is the raw output of the last
        transposed convolution and ``net`` is that layer with its ``outputs``
        attribute replaced by ``tanh(logits)``.
    """
    # Spatial side lengths at 1/2, 1/4, 1/8 and 1/16 of the output size.
    s2, s4, s8, s16 = (int(image_size / factor) for factor in (2, 4, 8, 16))
    base_filters = 64  # Number of filters in the last hidden deconv layer. [64]
    kernel, stride = (5, 5), (2, 2)
    weight_init = tf.random_normal_initializer(stddev=0.02)
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)

    # (filters, output side, deconv layer name, batch-norm layer name) for h1..h3.
    upsample_stages = (
        (base_filters * 4, s8, 'g/h1/decon2d', 'g/h1/batch_norm'),
        (base_filters * 2, s4, 'g/h2/decon2d', 'g/h2/batch_norm'),
        (base_filters, s2, 'g/h3/decon2d', 'g/h3/batch_norm'),
    )

    with tf.variable_scope("generator", reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        net = InputLayer(inputs, name='g/in')

        # h0: project the input and reshape it into an s16 x s16 feature map.
        net = DenseLayer(net, n_units=base_filters * 8 * s16 * s16,
                         W_init=weight_init, act=tf.identity, name='g/h0/lin')
        net = ReshapeLayer(net, shape=[-1, s16, s16, base_filters * 8],
                           name='g/h0/reshape')
        net = BatchNormLayer(net, act=tf.nn.relu, is_train=is_train,
                             gamma_init=bn_gamma_init, name='g/h0/batch_norm')

        # h1..h3: linear transposed convolution followed by batch norm + ReLU.
        for n_filters, side, deconv_name, bn_name in upsample_stages:
            net = DeConv2d(net, n_filters, kernel, out_size=(side, side),
                           strides=stride, padding='SAME', batch_size=batch_size,
                           act=None, W_init=weight_init, name=deconv_name)
            net = BatchNormLayer(net, act=tf.nn.relu, is_train=is_train,
                                 gamma_init=bn_gamma_init, name=bn_name)

        # h4: final transposed convolution to the requested size and channels.
        net = DeConv2d(net, c_dim, kernel, out_size=(image_size, image_size),
                       strides=stride, padding='SAME', batch_size=batch_size,
                       act=None, W_init=weight_init, name='g/h4/decon2d')

        # Keep the pre-activation tensor, then expose the tanh on the layer.
        logits = net.outputs
        net.outputs = tf.nn.tanh(logits)
    return net, logits