Example of training a model (and saving and restoring checkpoints) using the TensorFlow Java API.

Train for a few steps (`model/graph.pb` is the graph to train and `checkpoint` is the directory the checkpoint is written to):

    mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint"

Resume training from the previous checkpoint and train some more:

    mvn -q exec:java -Dexec.args="model/graph.pb checkpoint"

Delete the checkpoint:

    rm -rf checkpoint
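The train/resume/delete cycle above comes down to one decision: restore the variables if a checkpoint exists, otherwise initialize them. The plain-Java sketch below illustrates only that control flow, without TensorFlow. The `CheckpointSketch` class, the `params.txt` file name, the text file format and the initial values are hypothetical stand-ins. They are not what `Train.java` writes, because real TensorFlow checkpoints use their own format.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

/**
 * Hypothetical illustration of the checkpoint workflow: restore W and b if a
 * checkpoint exists, otherwise start from initial values.
 * This is NOT the TensorFlow checkpoint format.
 */
public class CheckpointSketch {
    // Hypothetical file name inside the checkpoint directory.
    static final String FILE_NAME = "params.txt";

    /** Returns {W, b}: restored from checkpointDir if present, else the given initial values. */
    static float[] restoreOrInit(Path checkpointDir, float initW, float initB) throws IOException {
        Path file = checkpointDir.resolve(FILE_NAME);
        if (Files.exists(file)) {
            List<String> lines = Files.readAllLines(file, StandardCharsets.UTF_8);
            return new float[] {Float.parseFloat(lines.get(0)), Float.parseFloat(lines.get(1))};
        }
        return new float[] {initW, initB};
    }

    /** Writes {W, b} to checkpointDir, creating the directory if needed. */
    static void save(Path checkpointDir, float[] params) throws IOException {
        Files.createDirectories(checkpointDir);
        Files.write(checkpointDir.resolve(FILE_NAME),
                List.of(Float.toString(params[0]), Float.toString(params[1])),
                StandardCharsets.UTF_8);
    }

    public static void main(String[] args) throws IOException {
        Path dir = Paths.get(args.length > 0 ? args[0] : "checkpoint");
        boolean resuming = Files.exists(dir.resolve(FILE_NAME));
        float[] params = restoreOrInit(dir, 0.0f, 0.0f); // hypothetical initial values
        System.out.println((resuming ? "Restored" : "Initialized")
                + ": W=" + params[0] + " b=" + params[1]);
        // ... training steps would update params here ...
        save(dir, params);
    }
}
```

Running it twice with the same directory takes the "Initialized" branch first and the "Restored" branch second. Deleting the directory resets it, which mirrors the three commands above.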

The model in `model/graph.pb` represents a very simple linear model:

    y = x * W + b
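Assume the graph minimizes the squared error between the prediction and the target. The loss actually used is whatever `create_graph.py` defines. Under that assumption, one gradient descent step with learning rate $\eta$ on a single example $(x, y)$ updates the parameters as follows:

```latex
\begin{aligned}
L(W, b) &= (xW + b - y)^2 \\
\frac{\partial L}{\partial W} &= 2\,(xW + b - y)\,x, \qquad
\frac{\partial L}{\partial b} = 2\,(xW + b - y) \\
W &\leftarrow W - \eta\,\frac{\partial L}{\partial W}, \qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}
\end{aligned}
```

If the data satisfy $y = 3.0x + 2.0$ exactly, both gradients are zero for every example at $W = 3.0$, $b = 2.0$. Training therefore stops moving the parameters once they reach those values.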

The `graph.pb` file is generated by running the Python script `create_graph.py`.

Training is orchestrated by `src/main/java/Train.java`, which generates training data of the form `y = 3.0 * x + 2.0`. Over time, gradient descent should make the model "learn" these coefficients: `W` should converge to 3.0 and `b` to 2.0.
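The sketch below shows the same convergence without TensorFlow. It runs plain-Java stochastic gradient descent on the squared error of `y = x * W + b`, one example at a time, with samples drawn from `y = 3.0 * x + 2.0`. Several details are assumptions for illustration: the squared-error loss, the learning rate (0.1), the zero initial values, the number of steps, the input range [0, 1) and the `LinearSgdSketch` class itself. `Train.java` instead runs the optimizer defined in `model/graph.pb` through the TensorFlow Java API.

```java
import java.util.Random;

/** Plain-Java illustration (no TensorFlow): fit y = x * W + b by single-example SGD. */
public class LinearSgdSketch {
    /** Returns {W, b} after `steps` gradient descent updates on random examples. */
    static double[] train(int steps, double learningRate, long seed) {
        Random rnd = new Random(seed);
        double w = 0.0, b = 0.0;               // hypothetical initial values
        for (int i = 0; i < steps; i++) {
            double x = rnd.nextDouble();       // input drawn from [0, 1)
            double target = 3.0 * x + 2.0;     // training data: y = 3.0 * x + 2.0
            double error = (x * w + b) - target;
            // Gradients of the squared error (x*W + b - y)^2 with respect to W and b.
            double gradW = 2.0 * error * x;
            double gradB = 2.0 * error;
            w -= learningRate * gradW;
            b -= learningRate * gradB;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] p = train(5000, 0.1, 42L);
        System.out.printf("W = %.4f, b = %.4f%n", p[0], p[1]);
    }
}
```

The data have no noise, so every example pulls `W` and `b` toward the same point. With a small enough learning rate, the printed values settle very close to 3.0 and 2.0.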