example_model.py
from typing import Dict

import tensorflow as tf

from base.model import BaseModel


class Mnist(BaseModel):
    def __init__(self, config: dict) -> None:
        """
        Create a model used to classify handwritten digits from the MNIST dataset.

        :param config: global configuration
        """
        super().__init__(config)

    def model(
        self, features: Dict[str, tf.Tensor], labels: tf.Tensor, mode: str
    ) -> tf.estimator.EstimatorSpec:
        """
        Define the model architecture and metrics; the logic depends on the mode.

        :param features: a dictionary of potential inputs for the model
        :param labels: input label set
        :param mode: current estimator mode (train, eval, predict)
        :return: an EstimatorSpec consumed by the higher-level Estimator API
        """
        # set a flag if the model is currently training
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # get the input data
        image = features["input"]

        # initialise the model architecture
        logits = _create_model(image, self.config["keep_prob"], is_training)

        # define the model predictions
        predictions = {
            "class": tf.argmax(input=logits, axis=1),
            "probabilities": tf.nn.softmax(logits),
        }

        if mode == tf.estimator.ModeKeys.PREDICT:
            # define what to output during serving
            export_outputs = {
                "labels": tf.estimator.export.PredictOutput(
                    {"id": features["id"], "label": predictions["class"]}
                )
            }
            return tf.estimator.EstimatorSpec(
                mode, predictions=predictions, export_outputs=export_outputs
            )

        # calculate the loss; labels are integer class ids, so use the sparse variant
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

        # add summaries for TensorBoard
        tf.summary.scalar("loss", loss)
        tf.summary.image("input", tf.reshape(image, [-1, 28, 28, 1]))

        if mode == tf.estimator.ModeKeys.EVAL:
            # create an evaluation metric
            summaries_dict = {
                "val_accuracy": tf.metrics.accuracy(
                    labels, predictions=predictions["class"]
                )
            }
            return tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=summaries_dict
            )

        # we should only reach this point in training mode
        assert mode == tf.estimator.ModeKeys.TRAIN

        # collect operations which need updating before back-prop, e.g. batch norm
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # create a learning rate variable for hyper-parameter tuning
        lr = tf.Variable(
            initial_value=self.config["learning_rate"], name="learning-rate"
        )

        # initialise the optimiser
        optimizer = tf.train.AdamOptimizer(lr)

        # run the minimise step after the extra update ops so batch norm
        # statistics are refreshed before each training step
        with tf.control_dependencies(extra_update_ops):
            train_op = optimizer.minimize(
                loss,
                global_step=tf.train.get_global_step(),
                colocate_gradients_with_ops=True,
            )

        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
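

# A minimal serving sketch, not part of the original file: one possible
# serving_input_receiver_fn matching the "input" and "id" features consumed
# by the PREDICT branch above. The placeholder shapes and dtypes here are
# assumptions for illustration only.
def _serving_input_receiver_fn() -> tf.estimator.export.ServingInputReceiver:
    features = {
        "input": tf.placeholder(tf.float32, [None, 784], name="input"),
        "id": tf.placeholder(tf.string, [None], name="id"),
    }
    # use the raw placeholders both as model features and receiver tensors
    return tf.estimator.export.ServingInputReceiver(features, features)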


def _fc_block(x: tf.Tensor, size: int, is_training: bool, drop: float) -> tf.Tensor:
    """
    Create a fully connected block using batch norm and dropout.

    :param x: input layer which precedes this block
    :param size: number of nodes in the layer
    :param is_training: flag indicating whether we are currently training
    :param drop: fraction of units to drop
    :return: fully connected block of layers
    """
    x = tf.layers.Dense(size)(x)
    x = tf.layers.BatchNormalization(fused=True)(x, training=is_training)
    x = tf.nn.relu(x)
    return tf.layers.Dropout(drop)(x, training=is_training)


def _conv_block(
    x: tf.Tensor, layers: int, filters: int, is_training: bool
) -> tf.Tensor:
    """
    Create a convolutional block using batch norm.

    :param x: input layer which precedes this block
    :param layers: number of conv layers to create
    :param filters: number of filters in each conv layer
    :param is_training: flag indicating whether we are currently training
    :return: stack of conv layers followed by max pooling
    """
    for _ in range(layers):
        x = tf.layers.Conv2D(filters, 3, padding="same")(x)
        x = tf.layers.BatchNormalization(fused=True)(x, training=is_training)
        x = tf.nn.relu(x)
    return tf.layers.MaxPooling2D(2, 2, padding="valid")(x)


def _create_model(x: tf.Tensor, drop: float, is_training: bool) -> tf.Tensor:
    """
    A basic deep CNN used to train the MNIST classifier.

    :param x: input data
    :param drop: fraction of units to drop during dropout
    :param is_training: flag indicating whether we are currently training
    :return: the fully constructed model
    """
    x = tf.reshape(x, [-1, 28, 28, 1])

    # create the convolutional blocks
    _layers = [1, 1]
    _filters = [32, 64]
    for i, l in enumerate(_layers):
        x = _conv_block(x, l, _filters[i], is_training)

    x = tf.layers.Flatten()(x)

    # create the fully connected blocks
    _fc_size = [1024]
    for s in _fc_size:
        x = _fc_block(x, s, is_training, drop)

    # add an output layer (10 classes, one logit for each)
    return tf.layers.Dense(10)(x)
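

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original file: it assumes
    # BaseModel simply stores `config` and that the bound method `Mnist.model`
    # can be passed directly as an Estimator model_fn. The config values and
    # model_dir below are illustrative assumptions only.
    config = {"learning_rate": 1e-3, "keep_prob": 0.4}
    mnist = Mnist(config)
    estimator = tf.estimator.Estimator(
        model_fn=mnist.model, model_dir="/tmp/mnist_example"
    )
    # estimator.train(input_fn=...) would then consume batches of
    # {"input": images} features and integer labels.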