There is an MNIST CNN example within the template, in files prefixed with example, which shows how to create and use the template. Delete these files if you don't need them.
Look for #TODO tags in these classes as hints of what you need to update to create your model.
In a nutshell, here's how to use this template. For example, to implement a VGG model you should do the following:
- Update the model script by creating a class called VGGModel that inherits from "BaseModel".
    class VGGModel(BaseModel):
        def __init__(self, config: dict) -> None:
            """
            :param config: global configuration
            """
- Update the create model function in this script to build the graph for your model architecture.
    def _create_model(x: tf.Tensor, is_training: bool) -> tf.Tensor:
        """
        :param x: input data
        :param is_training: flag if currently training
        :return: completely constructed model
        """
- Update the prediction and serving outputs for your model.
    # TODO: update model predictions
    predictions = {
        "classes": tf.argmax(input=logits, axis=1),
        "probabilities": tf.nn.softmax(logits),
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        # TODO: update output during serving
        export_outputs = {
            "labels": tf.estimator.export.PredictOutput(
                {"label": predictions["classes"], "id": features["id"]}
            )
        }
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs
        )
- Add any summaries you want to have on tensorboard for training and evaluation.
    # TODO: update summaries for tensorboard
    tf.summary.scalar("loss", loss)
    tf.summary.image("input", tf.reshape(x, [-1, 28, 28, 1]))
    # if mode is evaluation
    if mode == tf.estimator.ModeKeys.EVAL:
        # TODO: update evaluation metrics
        summaries_dict = {
            "val_accuracy": tf.metrics.accuracy(
                labels, predictions=predictions["classes"]
            )
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=summaries_dict
        )
- Change your optimizer to suit your architecture. If you are using Batch Normalisation, ensure that you are using control dependencies (see the MNIST example).
    # TODO: update optimiser
    optimizer = tf.train.AdamOptimizer(lr)
- Update the train script, making sure to inherit from the "BaseTrain" class and to use your new model and data loaders.
    class VGGTrainer(BaseTrain):
        def __init__(
            self,
            config: dict,
            model: VGGModel,
            train: TFRecordDataLoader,
            val: TFRecordDataLoader,
            pred: TFRecordDataLoader,
        ) -> None:
            """
            This function will generally remain unchanged, it is used to train and
            export the model. The only part which may change is the run
            configuration, and possibly which execution to use (training, eval etc)
            :param config: global configuration
            :param model: input function used to initialise model
            :param train: the training dataset
            :param val: the evaluation dataset
            :param pred: the prediction dataset
            """
            super().__init__(config, model, train, val, pred)
- In the same script update the export function to match the inputs for your model.
    def _export_model(
        self, estimator: tf.estimator.Estimator, save_location: str
    ) -> None:
        """
        Used to export your model in a format that can be used with
        Tensorflow Serving
        :param estimator: your estimator function
        """
        # this should match the input shape of your model
        # TODO: update this to your input used in prediction/serving
        x1 = tf.feature_column.numeric_column(
            "input", shape=[self.config["batch_size"], 28, 28, 1]
        )
        # create a list in case you have more than one input
        feature_columns = [x1]
        feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
        export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
            feature_spec
        )
        # export the saved model
        estimator.export_savedmodel(save_location, export_input_fn)
- Update the data loader to correctly load your input data, adding or removing augmentation as needed. This example is for tfrecords, but it is possible to use other input data types. Doing data loading and pre-processing on the CPU helps reduce GPU bottlenecks.
    def _parse_example(
        self, example: tf.Tensor
    ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
        """
        Used to read in a single example from a tfrecord file and do any augmentations necessary
        :param example: the tfrecord to read the data from
        :return: a parsed input example and its respective label
        """
        # do parsing on the cpu
        with tf.device("/cpu:0"):
            # define input shapes
            # TODO: update this for your data set
            features = {
                "image": tf.FixedLenFeature(shape=[28, 28, 1], dtype=tf.float32),
                "label": tf.FixedLenFeature(shape=[1], dtype=tf.int64),
            }
            example = tf.parse_single_example(example, features=features)
            # only augment during training
            if self.mode == "train":
                input_data = self._augment(example["image"])
            else:
                input_data = example["image"]
        return {"input": input_data}, example["label"]
- Update the task script ensuring your new model and trainer are used. These scripts are used to initialise and train your model.
    def init() -> None:
        """
        The main function of the project used to initialise all the required functions for training the model
        """
        # get input arguments
        args = get_args()
        # get static config information
        config = process_config()
        # combine both into dictionary
        config = {**config, **args}
        # initialise model
        model = VGGModel(config)
        # create your data generators for each mode
        train_data = TFRecordDataLoader(config, mode="train")
        val_data = TFRecordDataLoader(config, mode="val")
        test_data = TFRecordDataLoader(config, mode="test")
        # initialise the estimator
        trainer = VGGTrainer(config, model, train_data, val_data, test_data)
        # start training
- Update the utils script, adding any input arguments you need for your model. These will be added to the global config. For variables that are unlikely to change, you can add them to the static config dictionary.
    def process_config() -> dict:
        """
        Add in any static configuration that is unlikely to change very often
        :return: a dictionary of static configuration data
        """
        config = {"exp_name": "example_model_train"}
        return config
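To make the split between static configuration and input arguments concrete, here is a minimal sketch of how the two could be merged into one global config. Only process_config, get_args, and the {**config, **args} merge come from the template; the body of get_args, the two flags, and the build_config helper are assumptions made for illustration.

```python
import argparse
from typing import List, Optional


def get_args(argv: Optional[List[str]] = None) -> dict:
    """Parse command-line arguments into a plain dictionary.

    The flags below are hypothetical examples, not the template's real ones.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=512)
    return vars(parser.parse_args(argv))


def process_config() -> dict:
    """Static configuration that is unlikely to change very often."""
    return {"exp_name": "example_model_train"}


def build_config(argv: Optional[List[str]] = None) -> dict:
    """Merge static config with input arguments (hypothetical helper).

    Later keys win in a dict merge, so an input argument overrides a
    static value that uses the same key.
    """
    return {**process_config(), **get_args(argv)}


if __name__ == "__main__":
    print(build_config(["--batch_size", "64"]))
```

Keeping rarely-changed values in the static dictionary and everything else behind flags means a training run can be reproduced from its command line alone.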
To train your model there is a series of bash scripts covering several different training environments. All of the local scripts create log files. A log file is created, or appended to, each time you run one of the scripts. You can use it as a scratch file to track your experiment details.
The stdout of your model will be written to files in runlogs/, where each respective process will have its own log. The scripts also create .pid files, which can be used to kill the process if need be. An example is shown below:
Example Training Job
Learning Rate: 0.001
Epochs: 100
Batch Size (train/eval): 512/ 512
Model will converge quickly
Model diverged even quicker
If you are not comfortable with vim or do not want to use this, you can remove it from the scripts.
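The .pid files mentioned above can also be used from Python. The sketch below assumes each .pid file holds a single integer process id, which is an assumption about what the training scripts write; stop_process and the example path are hypothetical names.

```python
import os
import signal
from pathlib import Path


def stop_process(pid_file: str) -> int:
    """Send SIGTERM to the process whose id is stored in pid_file.

    Assumes the file contains one integer pid. Returns the pid signalled.
    """
    pid = int(Path(pid_file).read_text().strip())
    os.kill(pid, signal.SIGTERM)
    return pid


if __name__ == "__main__":
    # hypothetical path: replace with the .pid file created by your run
    pid_path = "runlogs/example.pid"
    if os.path.exists(pid_path):
        print("stopped process", stop_process(pid_path))
```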
For each of the scripts you will need to update the hyper-parameters you want to use for this training run. Cloud-based file paths won't work on Windows. Add any additional input arguments that you have added to your model.
# where to write tfevents
# experiment settings
# create a job name for this run
now=$(date +"%Y%m%d_%H_%M_%S")
# locations locally or on the cloud for your files
Train CPU
This script will train the model without using any GPUs and you can optionally
specify a python environment to run the project from.
Usage: ./ [ENV_NAME]
Train GPU
This script will train the model on a specific GPU, and you can optionally
specify a python environment to run the project from. It will also check to ensure
you have set up the CUDA environment variables. To find out GPU usage or which
GPU_ID to use you can run this in your terminal.
Usage: ./ <GPU_ID> [ENV_NAME]
Train distributed GPU
This script will allow you to simulate a distributed training environment locally on as many GPUs
as your machine has. In order to do this you must split the GPUs into workers, masters and parameter servers.
GPUs can be allocated to each of these types. You can also set this up directly in Python. The script contains an example using 3 GPUs:
\"master\": [\"localhost:27182\"],
\"ps\": [\"localhost:27183\"],
\"worker\": [
}, \"environment\": \"cloud\""
# ensure parameter server doesn't use any of the GPUs in this case
# Parameter Server can be run on cpu
task="{\"type\": \"ps\", \"index\": 0}"
export TF_CONFIG="{\"cluster\":${config}, \"task\":${task}}"
run ps
# Master should be run on GPU as it runs the evaluation
task="{\"type\": \"master\", \"index\": 0}"
export TF_CONFIG="{\"cluster\":${config}, \"task\":${task}}"
run master
# Workers (number of GPUs - 1, as one is used by the master)
for gpu in 0 1
do
    task="{\"type\": \"worker\", \"index\": $gpu}"
    export TF_CONFIG="{\"cluster\":${config}, \"task\":${task}}"
    run "worker${gpu}"
done
This setup has one master, one parameter server, and two workers. The master is allocated one GPU and the workers also have one GPU each. The parameter server will be run on the CPU. When defining new configurations you have to ensure that the ports used in the config are not already in use.
Usage: ./ [ENV_NAME]
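The same cluster description can be built in Python instead of hand-escaped bash strings. This is a minimal sketch, not the template's own code: the master and parameter server addresses come from the script above, the two worker addresses are hypothetical placeholders, and make_tf_config and port_is_free are illustrative helper names. The port check covers the requirement that ports in the config must not already be in use.

```python
import json
import os
import socket
from typing import Dict, List


def port_is_free(port: int, host: str = "localhost") -> bool:
    """Return True if nothing is currently bound to host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return False
    return True


def make_tf_config(cluster: Dict[str, List[str]], task_type: str, index: int) -> str:
    """Serialise a cluster and task description into a TF_CONFIG JSON string."""
    return json.dumps(
        {
            "cluster": cluster,
            "environment": "cloud",
            "task": {"type": task_type, "index": index},
        }
    )


# master and ps addresses are from the bash script; the worker addresses
# are hypothetical placeholders
cluster = {
    "master": ["localhost:27182"],
    "ps": ["localhost:27183"],
    "worker": ["localhost:27184", "localhost:27185"],
}

if __name__ == "__main__":
    for addresses in cluster.values():
        for address in addresses:
            port = int(address.split(":")[1])
            print(address, "free" if port_is_free(port) else "in use")
    # each process sets its own TF_CONFIG before the estimator is created
    os.environ["TF_CONFIG"] = make_tf_config(cluster, "worker", 0)
```

Every process in the cluster shares the same cluster dictionary and differs only in its task entry, which is why the bash script re-exports TF_CONFIG before each run call.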
Train Cloud
This script requires that you have the Google Cloud SDK installed, and a Google Cloud Platform account
with access to ml-engine. Trial GCP accounts come with credit if you want to try this out. Training on the cloud does cost money, but it is very simple once set up.
The hptuning_config.yaml file is used to specify the resources you are requesting for this job.
You can scale this to your needs, and it will behave the same as the local distributed training.
The data must be stored in a bucket on GCP, and you also need to specify which bucket to export your model and checkpoints to. Ensure that any additional packages your model needs are defined, and make sure you aren't specifying packages that are already part of ml-engine.
Usage: ./
├── base
│ ├── - this script contains the abstract class of the tfrecord data loader.
│ ├── - this script contains the abstract class of the model.
│ └── - this script contains the abstract class of the model trainer.
├── data - this folder contains any data your project may need.
├── data_loader
│ └── - this script is responsible for all data handling.
├── initialisers
│ └── - this script is used to start training the model
├── models
│ └── - this script is where your model is defined for each training phase.
├── trainers
│ └── - this script is where your estimator configuration is defined.
└── utils
├── - this script is an example of how to create tfrecords from numpy or csv files
└── - this script handles your input variables and defines a global config.
Base model is an abstract class that must be inherited by any model you create. The only requirement is that you implement a model function which is compliant with the Estimator API.
- Model: this function is where your experiment is defined
- Create Model: this function is where your model architecture is defined
Here's where you implement your model. So you should:
- Create your model class and inherit the base_model class
- Override "model" where you write the tensorflow Estimator experiment
- Override "create_model" where you write your model architecture
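The inheritance contract described above can be sketched without TensorFlow. The classes below are pure-Python stand-ins, not the template's real base class: the method names follow the text, but the signatures, the TinyModel class, and its toy "architecture" are assumptions chosen so the example runs on its own.

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseModel(ABC):
    """Hypothetical stand-in for the template's abstract base model."""

    def __init__(self, config: dict) -> None:
        self.config = config

    @abstractmethod
    def model(self, features: Any, labels: Any, mode: str) -> Any:
        """Define the experiment: what is computed for each mode."""

    @abstractmethod
    def _create_model(self, x: Any, is_training: bool) -> Any:
        """Define the model architecture."""


class TinyModel(BaseModel):
    """Toy subclass that overrides both required functions."""

    def _create_model(self, x: Any, is_training: bool) -> Any:
        # placeholder architecture: double every input value
        return [2.0 * value for value in x]

    def model(self, features: Any, labels: Any, mode: str) -> Any:
        outputs = self._create_model(features["input"], is_training=(mode == "train"))
        return {"mode": mode, "predictions": outputs}


if __name__ == "__main__":
    tiny = TinyModel({"exp_name": "demo"})
    print(tiny.model({"input": [1.0, 2.0]}, None, "train"))
```

Because both methods are abstract, a subclass that forgets to override either one cannot be instantiated, so the mistake surfaces before training starts.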
Base trainer is an abstract class that just wraps the training process.
- Run: this function sets up the Estimator configuration for the different training stages and runs your training loop
- Export Model: this function exports the model to a given location with compatibility for Tensorflow Serving
- Predict: this function takes a prediction dataset and returns predicted values using your model based off its last saved checkpoint
Here's what you should implement in your trainer:
- Override the export_model function to match the inputs of your model
- Override the predict function for your project's requirements
This class is responsible for all data handling and processing and provides an easy interface that can be used by the trainer.
The current loader uses tfrecords which are the recommended way of loading data into a Tensorflow model.
If you are using tfrecords you should:
- Update the parse_example function so the input features match those your model expects
- Add or remove any augmentation methods
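As a framework-free illustration of the two steps above, the sketch below applies augmentation only when the loader mode is "train". It uses plain Python lists instead of tensors, and augment, parse_example, and the left-to-right flip are all hypothetical stand-ins rather than the template's own code.

```python
import random
from typing import Any, Dict, List, Tuple

Image = List[List[float]]


def augment(image: Image, rng: random.Random) -> Image:
    """Flip the image left-to-right half of the time (stand-in augmentation)."""
    if rng.random() < 0.5:
        return [list(reversed(row)) for row in image]
    return image


def parse_example(
    example: Dict[str, Any], mode: str, rng: random.Random
) -> Tuple[Dict[str, Image], int]:
    """Return (features, label), augmenting the image only in train mode."""
    image = example["image"]
    if mode == "train":
        image = augment(image, rng)
    return {"input": image}, example["label"]


if __name__ == "__main__":
    sample = {"image": [[1.0, 2.0], [3.0, 4.0]], "label": 7}
    print(parse_example(sample, "val", random.Random(0)))
```

Passing the random generator in explicitly keeps the augmentation reproducible, which makes it easier to debug a data pipeline.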
Add any static configuration variables to the dict in utils, otherwise it is recommended to use input variables to handle the configuration settings.
Here's where you combine all the previous parts.
- Update the model to use your new model class
- Update the datasets to point to your data loader function for each respective training step (train, eval, prediction etc)
- Update the trainer to use your new trainer class
This can now be used as the entry point to run your experiment.
- This project includes settings for the python linter Flake8 which are compatible with black and mypy. Update .flake8 if you would like to change these settings
- This project includes settings for the python formatter Black. The settings for black should be defined in the pyproject.toml file
- This project includes settings for the optional Python static type checker Mypy. The settings are defined in mypy.ini, where you can also define files to ignore.
- This project includes settings for pre-commit hooks to run flake8, black and mypy. The settings are defined in .pre-commit-config.yaml. To maintain good code quality it is recommended you install and use pre-commit, which runs these tools each time you commit, ensuring they pass before you can push. If you want to use this, make sure the other tools are installed using pip, then run the following within your model folder:
pip install pre-commit
pre-commit install
- Update the details and requirements.txt for your project, specifying any packages that you need for your project.
Any kind of enhancement or contribution is welcomed.