Allow arbitrary dimension numbers in stax layers.
* Note that FC / GlobalAvgPool layers will still enforce `NC` dimension numbers for now, and batching still only works with the leading batch (`N`) dimension, so keeping `N` as the leading dimension is highly recommended.
* `NHWC` is currently recommended for attention; `NCHW` or `NHWC` for CNNs.

Also:
1) Remove some warnings that I feel are redundant given our codebase - please let me know if I'm wrong.
2) Make the code slightly more generic, hopefully facilitating future ND-CNN cases.
3) Make `Flatten` work on batches of size 0.
4) Make `GeneralConv` public, as it now supports different dimension numbers (see the sketch below).
PiperOrigin-RevId: 290168022
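
A minimal sketch of the new dimension-number support, assuming `nt.stax.GeneralConv` follows the `(lhs_spec, rhs_spec, out_spec)` convention of `jax.experimental.stax.GeneralConv` (widths and filter shape are illustrative):

```python
from neural_tangents import stax

# NCHW inputs; the batch dimension `N` is kept leading, as recommended above.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.GeneralConv(('NCHW', 'OIHW', 'NCHW'), 64, (3, 3)),  # arbitrary dimension numbers
    stax.Relu(),
    stax.Flatten(),  # produces `NC` outputs, which FC layers still expect
    stax.Dense(1)
)
```
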
romanngg committed Jan 27, 2020
1 parent 818678a commit 98ce60d
Showing 6 changed files with 265 additions and 5,172 deletions.
README.md: 29 changes (8 additions & 21 deletions)
@@ -48,7 +48,7 @@ Then either run
```
pip install neural-tangents
```
or, to use the bleeding-edge version from GitHub source,
or, to build the bleeding-edge version from source,
```
git clone https://github.com/google/neural-tangents
pip install -e neural-tangents
@@ -79,7 +79,6 @@ colab examples:
- [Neural Tangents Cookbook](https://colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb)
- [Weight Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/weight_space_linearization.ipynb)
- [Function Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/function_space_linearization.ipynb)
- [Neural Network Phase Diagram](https://colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/phase_diagram.ipynb)


## 5-Minute intro
@@ -168,7 +167,7 @@ y_test_ntk = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,

### Infinitely WideResnet

We can define a more complex, (infinitely) Wide Residual Network [[14]](#14-wide-residual-networks-bmvc-2018-sergey-zagoruyko-nikos-komodakis) using the same `nt.stax` building blocks:
We can define a more complex, (infinitely) Wide Residual Network [[14]](#8-wide-residual-networks-bmvc-2018-sergey-zagoruyko-nikos-komodakis) using the same `nt.stax` building blocks:

```python
from neural_tangents import stax
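# As a hedged sketch (not the README's full WideResnet definition), one residual
# unit could combine a convolutional branch with an identity branch using the same
# building blocks; widths, filter shape and padding here are illustrative:
ResidualUnitSketch = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(stax.Relu(), stax.Conv(128, (3, 3), padding='SAME'),
                    stax.Relu(), stax.Conv(128, (3, 3), padding='SAME')),
        stax.Identity()),
    stax.FanInSum())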
@@ -246,7 +245,7 @@ import neural_tangents as nt # 64-bit precision enabled
We note the following differences between our library and the JAX one (a short sketch follows the list).

* All `nt.stax` layers are instantiated with a function call, i.e. `nt.stax.Relu()` vs `jax.experimental.stax.Relu`.
* All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument (see [[15]](#15-on-the-infinite-width-limit-of-neural-networks-with-a-standard-parameterization)).
* All layers with trainable parameters use the _NTK parameterization_ by default (see [[10]](#5-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler), Remark 1). However, Dense and Conv layers also support the _standard parameterization_ via a `parameterization` keyword argument. <!-- TODO(jaschasd) add link to note deriving NTK for standard parameterization -->
* `nt.stax` and `jax.experimental.stax` may have different layers and options available (for example `nt.stax` layers support `CIRCULAR` padding, but only `NHWC` data format).
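
A minimal sketch of the points above (layer widths are illustrative): layers are instantiated with a call, the NTK parameterization is the default, and the `parameterization` keyword switches a `Dense` or `Conv` layer to the standard parameterization.

```python
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),                             # NTK parameterization by default
    stax.Relu(),                                 # note the call: Relu(), not Relu
    stax.Dense(1, parameterization='standard'),  # standard parameterization
)
```
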

### Python 2 is not supported
@@ -259,7 +258,7 @@ The kernel of an infinite network `kernel_fn(x1, x2).ntk` combined with `nt.pre

### Weight space

Continuous gradient descent in an infinite network has been shown in [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
Continuous gradient descent in an infinite network has been shown in [[11]](#6-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.

For this, we provide two convenient methods:

@@ -298,7 +297,7 @@ logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray
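
A minimal sketch of `nt.linearize`, one of the two methods mentioned above; the network and input shapes are illustrative assumptions.

```python
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

# Illustrative finite-width network and inputs.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(2))

key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, (3, 32))
_, params = init_fn(key2, x.shape)

# First-order Taylor expansion of `apply_fn` around the initial parameters.
apply_fn_lin = nt.linearize(apply_fn, params)
logits = apply_fn_lin(params, x)  # agrees with apply_fn(params, x) at `params`
```
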

### Function space:

Outputs of a linearized model evolve identically to those of an infinite one [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given specific `params_0` that the network is initialized with. For this, we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that allows one to compute the empirical NTK and NNGP kernels on specific `params`.
Outputs of a linearized model evolve identically to those of an infinite one [[11]](#6-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#5-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given specific `params_0` that the network is initialized with. For this, we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that allows one to compute the empirical NTK and NNGP kernels on specific `params`.
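
A minimal sketch of `nt.empirical_kernel_fn` as described above; the architecture and shapes are illustrative assumptions, and how the NTK vs. NNGP kernel is selected may vary between library versions.

```python
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

# A toy finite-width network whose empirical kernels are evaluated at `params`.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (4, 32))
x2 = random.normal(key2, (8, 32))
_, params = init_fn(key3, x1.shape)

kernel_fn = nt.empirical_kernel_fn(apply_fn)
k = kernel_fn(x1, x2, params)  # empirical kernel(s) at these particular `params`
```
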

#### Example:

@@ -356,25 +355,15 @@ a small dataset using a small learning rate.

## Papers

Neural Tangents has been used in the following papers:


* [Disentangling Trainability and Generalization in Deep Learning.](https://arxiv.org/abs/1912.13053) \
Lechao Xiao, Jeffrey Pennington, Samuel S. Schoenholz

* [Information in Infinite Ensembles of Infinitely-Wide Neural Networks.](https://arxiv.org/abs/1911.09189) \
Ravid Shwartz-Ziv, Alexander A. Alemi

* [Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.](https://arxiv.org/abs/1905.13654) \
Soufiane Hayou, Arnaud Doucet, Judith Rousseau
Neural tangents has been used in the following papers:

* [Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient
Descent.](https://arxiv.org/abs/1902.06720) \
Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha
Sohl-Dickstein, Jeffrey Pennington

* [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) \
Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
* [Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.](https://arxiv.org/abs/1905.13654) \
Soufiane Hayou, Arnaud Doucet, Judith Rousseau

Please let us know if you make use of the code in a publication and we'll add it
to the list!
@@ -427,5 +416,3 @@ If you use the code in a publication, please cite the repo using the .bib,
##### [13] [Mean Field Residual Networks: On the Edge of Chaos.](https://arxiv.org/abs/1712.08969) *NeurIPS 2017.* Greg Yang, Samuel S. Schoenholz

##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2018.* Sergey Zagoruyko, Nikos Komodakis

##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee