-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmodel.py
71 lines (62 loc) · 2.84 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import tensorflow as tf
from layers.encoder import Encoder
from layers.decoder import Decoder
from layers.vae import VariationalAutoencoder
class Model(tf.keras.models.Model):
    """Segmentation model: a cross between the 3D U-Net and the 2018
    BraTS Challenge top model with VAE regularization.

    References:
        - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
          Annotation (https://arxiv.org/pdf/1606.06650.pdf)
        - 3D MRI brain tumor segmentation using autoencoder
          regularization (https://arxiv.org/pdf/1810.11654.pdf)
    """

    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 dropout=0.2,
                 downsampling='conv',
                 upsampling='conv',
                 base_filters=16,
                 depth=4,
                 in_ch=2,
                 out_ch=3):
        """Builds the encoder, decoder, and VAE sub-models.

        Args:
            data_format: Keras data format string
                ('channels_last' or 'channels_first').
            groups: group count forwarded to the sub-models
                (presumably for group normalization — confirm in layers/).
            reduction: reduction factor forwarded to the sub-models.
            l2_scale: L2 regularization scale forwarded to the sub-models.
            dropout: dropout rate forwarded to the encoder only.
            downsampling: downsampling strategy name for the encoder.
            upsampling: upsampling strategy name for the decoder and VAE.
            base_filters: base filter count forwarded to all sub-models.
            depth: number of levels forwarded to all sub-models.
            in_ch: number of input channels; used as the VAE branch's
                output channels (it reconstructs the input).
            out_ch: number of output channels of the decoder
                (segmentation classes).
        """
        super(Model, self).__init__()
        # Non-trainable counter so training progress survives
        # checkpoint save/restore.
        self.epoch = tf.Variable(0, name='epoch', trainable=False)
        self.encoder = Encoder(
            data_format=data_format,
            groups=groups,
            reduction=reduction,
            l2_scale=l2_scale,
            dropout=dropout,
            downsampling=downsampling,
            base_filters=base_filters,
            depth=depth)
        self.decoder = Decoder(
            data_format=data_format,
            groups=groups,
            reduction=reduction,
            l2_scale=l2_scale,
            upsampling=upsampling,
            base_filters=base_filters,
            depth=depth,
            out_ch=out_ch)
        # The VAE reconstructs the model *input*, hence out_ch=in_ch.
        self.vae = VariationalAutoencoder(
            data_format=data_format,
            groups=groups,
            reduction=reduction,
            l2_scale=l2_scale,
            upsampling=upsampling,
            base_filters=base_filters,
            depth=depth,
            out_ch=in_ch)

    def call(self, inputs, training=None, inference=None):
        """Runs the forward pass.

        Args:
            inputs: input tensor fed to the encoder.
            training: training-mode flag forwarded to all sub-models.
            inference: if truthy, the VAE branch is skipped.

        Returns:
            A 4-tuple ``(y_pred, y_vae, z_mean, z_logvar)``. In
            inference mode the last three elements are ``None``.

        Raises:
            ValueError: if both ``training`` and ``inference`` are truthy.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently allow the invalid combination.
        if inference and training:
            raise ValueError(
                'Cannot run training and inference modes simultaneously.')
        inputs = self.encoder(inputs, training=training)
        # Deepest encoder output feeds the decoder; the shallower outputs
        # are passed alongside (presumably as skip connections — confirm
        # against Decoder's call signature).
        y_pred = self.decoder((inputs[-1], inputs[:-1]), training=training)
        if inference:
            # Inference mode does not evaluate the VAE branch.
            return (y_pred, None, None, None)
        y_vae, z_mean, z_logvar = self.vae(inputs[-1], training=training)
        return (y_pred, y_vae, z_mean, z_logvar)