TADLib - Tiny Automatic Differentiation Library

What is TADLib?

TADLib was created to show how 'autograd' and basic neural networks can be implemented. It provides automatic differentiation (AD) in the spirit of TensorFlow and PyTorch, but is implemented in pure Java. Hardware acceleration is supported via OpenCL.

The main code works with tensors (multi-dimensional arrays). A separate autograd implementation that uses only scalar values is also included; that package should make the concept of autograd even easier to understand.

Examples

A fully connected Neural Network for the MNIST dataset

First we need some weights:

// hidden layer: 28*28 = 784 inputs -> 32 units
hiddenW = tensor(randWeight(weightRnd, shape(28 * 28, 32)));
hiddenB = tensor(randWeight(weightRnd, shape(32)));

// output layer: 32 units -> 10 classes
outW = tensor(randWeight(weightRnd, shape(32, 10)));
outB = tensor(randWeight(weightRnd, shape(10)));

Then we need the forward pass:

// relu(inputs @ hiddenW + hiddenB)
Tensor firstLayer = relu(add(
        matmul(xTrain, hiddenW),
        hiddenB));
// (firstLayer @ outW + outB)
Tensor prediction = add(
        matmul(firstLayer, outW),
        outB);

The prediction tensor holds the output logits for each of the output classes.
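
To turn a row of logits into a class label you take the index of the largest logit (softmax is monotonic, so it does not change the argmax). A small, hypothetical helper in plain Java, assuming the logits for one example have already been read into a double array (reading values out of the Tensor is not shown here):

// Hypothetical helper, not part of TADLib: argmax over one example's logits.
static int predictedClass(double[] logits) {
    int best = 0;
    for (int i = 1; i < logits.length; i++) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    return best;
}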

We then calculate the loss, averaged over the number of examples in the batch:

Tensor totalSoftmaxCost = sumSoftmaxCrossEntropy(toOneHot(trainingData.output, 10), prediction);
Tensor avgSoftmaxCost = div(totalSoftmaxCost, constant(trainingData.getBatchSize()));

Then we trigger backpropagation of the gradients:

avgSoftmaxCost.backward();

And, finally, update the weights (plain SGD):

hiddenW.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
hiddenB.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
outW.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
outB.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
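
Putting the steps together, one training step per batch could look like the sketch below. Only the calls already shown above are used; the loop structure, batchCount, and how xTrain/trainingData are obtained for each batch are illustrative placeholders, since data loading is not covered here.

// Sketch of a training loop built from the steps above. The loop structure,
// batchCount and the per-batch xTrain/trainingData are placeholders.
for (int batch = 0; batch < batchCount; batch++) {
    // forward pass: relu(xTrain @ hiddenW + hiddenB) @ outW + outB
    Tensor firstLayer = relu(add(matmul(xTrain, hiddenW), hiddenB));
    Tensor prediction = add(matmul(firstLayer, outW), outB);

    // softmax cross entropy, averaged over the batch
    Tensor totalSoftmaxCost = sumSoftmaxCrossEntropy(toOneHot(trainingData.output, 10), prediction);
    Tensor avgSoftmaxCost = div(totalSoftmaxCost, constant(trainingData.getBatchSize()));

    // backpropagate and apply plain SGD
    avgSoftmaxCost.backward();
    hiddenW.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
    hiddenB.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
    outW.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
    outB.update((values, gradient) -> values.sub(gradient.mul(learningRate)));
}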


OpenCL support

OpenCL support can be enabled by assigning the provider:

ProviderStore.setProvider(new OpenCLProvider());

Operations will run a lot faster using OpenCL.

The OpenCL integration, like the Java code, is kept naive and minimal to allow for (hopefully) better readability. Performance will certainly not reach the level of TensorFlow or PyTorch, but it is fast enough to allow more experimentation and less time waiting.

See the opencl package for more details.
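
As a minimal usage sketch (the ordering is an assumption, not documented here: the provider is simply set once at program startup, before any tensors are created):

public static void main(String[] args) {
    // Assumed ordering: select the OpenCL provider first, then build the model.
    ProviderStore.setProvider(new OpenCLProvider());

    // ...create the weight tensors and run the training loop exactly as in the
    // MNIST example above...
}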

Scalar implementation

An even simpler implementation using scalar values can be found in the singlevalue package.
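
As a rough illustration of the idea only (this is not the actual singlevalue API), a scalar reverse-mode autograd value can be sketched in a few dozen lines of Java:

import java.util.*;

// Illustrative only: a minimal scalar reverse-mode autograd value,
// not the actual API of TADLib's singlevalue package.
final class Scalar {
    final double data;                      // value from the forward pass
    double grad;                            // d(output)/d(this), filled in by backward()
    private final List<Scalar> parents;     // inputs this value was computed from
    private Runnable backprop = () -> {};   // pushes this.grad to the parents

    Scalar(double data) { this(data, List.of()); }
    private Scalar(double data, List<Scalar> parents) {
        this.data = data;
        this.parents = parents;
    }

    Scalar add(Scalar other) {
        Scalar out = new Scalar(this.data + other.data, List.of(this, other));
        out.backprop = () -> { this.grad += out.grad; other.grad += out.grad; };
        return out;
    }

    Scalar mul(Scalar other) {
        Scalar out = new Scalar(this.data * other.data, List.of(this, other));
        out.backprop = () -> {
            this.grad += other.data * out.grad;   // d(out)/d(this) = other
            other.grad += this.data * out.grad;   // d(out)/d(other) = this
        };
        return out;
    }

    void backward() {
        // visit in reverse topological order so each gradient is fully
        // accumulated before it is pushed further back
        List<Scalar> order = new ArrayList<>();
        topoSort(this, new HashSet<>(), order);
        this.grad = 1.0;
        for (int i = order.size() - 1; i >= 0; i--) {
            order.get(i).backprop.run();
        }
    }

    private static void topoSort(Scalar node, Set<Scalar> seen, List<Scalar> order) {
        if (seen.add(node)) {
            for (Scalar p : node.parents) topoSort(p, seen, order);
            order.add(node);
        }
    }
}

With this sketch, after y = a.mul(b).add(a) and y.backward(), a.grad ends up as b.data + 1 and b.grad as a.data, matching the hand-computed derivatives of y = a*b + a.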

About

What is the point/goal of TADLib?

The focus of TADLib is to show how a neural network works under the hood. Conceptually it runs like TensorFlow or PyTorch in eager/immediate mode. TADLib is of course much simpler and runs orders of magnitude slower. The advantage is that it lets you follow/debug/trace the flow of each value, since it is implemented with plain double arrays and normal Java math operations.

The code is meant to be simple to read and not too difficult to follow. Some limitations are:

  • limited set of math ops
  • no/minimal optimizations
  • immutable (mostly)
  • ...which means it is slow :)

What can it do?

It provides all the primitives needed to implement a standard multi-layered convolutional neural net for MNIST-class problems. Using TADLib is like coding a neural network in TensorFlow using Variables and math ops to manually create the layers and structure of the model, but with the added verbosity of Java.

It is possible to create larger models with TADLib, but they will run too slowly to be practically usable.

There is support for parallel execution of some math operations. It helps when training on MNIST-like datasets, but it is still too slow for real-world problems.

References

The main autograd structure of TADLib is heavily inspired by Joel Grus' autograd tutorial:

https://www.youtube.com/playlist?list=PLeDtc0GP5ICldMkRg-DkhpFX1rRBNHTCs
(Livecoding an Autograd Library)

A huge thanks to Joel :)


License

Copyright © 2021, Ping Ng. Released under the MIT License.