This code was used for the experiments in Wide Residual Networks (http://arxiv.org/abs/1605.07146) by Sergey Zagoruyko and Nikos Komodakis.
Deep residual networks were shown to be able to scale up to thousands of layers and still have improving performance. However, each fraction of a percent of improved accuracy costs nearly doubling the number of layers, and so training very deep residual networks has a problem of diminishing feature reuse, which makes these networks very slow to train.
To tackle these problems, in this work we conduct a detailed experimental study on the architecture of ResNet blocks, based on which we propose a novel architecture where we decrease depth and increase width of residual networks. We call the resulting network structures wide residual networks (WRNs) and show that these are far superior to their commonly used thin and very deep counterparts.
For example, we demonstrate that even a simple 16-layer-deep wide residual network outperforms in accuracy and efficiency all previous deep residual networks, including thousand-layer-deep networks. We further show that WRNs achieve incredibly good results (e.g., achieving new state-of-the-art results on CIFAR-10, CIFAR-100 and SVHN) and train several times faster than pre-activation ResNets.
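The WRN-d-k names used below (WRN-16-2, WRN-28-10, WRN-40-4) encode the depth d and the widening factor k. As an illustration, the hypothetical helper below (not part of this repository) derives the number of basic blocks per group and the channel widths, assuming the layout of Table 1 in the paper: an initial 16-channel convolution followed by three groups of N blocks with 16k, 32k and 64k channels, and a depth of the form 6N + 4.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class WRNConfig:
    """Hypothetical description of a WRN-d-k network (illustration only)."""

    depth: int         # the "d" in WRN-d-k
    widen_factor: int  # the "k" in WRN-d-k

    def blocks_per_group(self) -> int:
        # Assumed convention: depth = 6N + 4, with N blocks in each of the three groups.
        if self.depth < 10 or (self.depth - 4) % 6 != 0:
            raise ValueError("depth must be of the form 6N + 4 with N >= 1")
        return (self.depth - 4) // 6

    def widths(self) -> list[int]:
        # Channels of the initial conv, then of groups conv2, conv3 and conv4.
        k = self.widen_factor
        return [16, 16 * k, 32 * k, 64 * k]

    def name(self) -> str:
        return f"WRN-{self.depth}-{self.widen_factor}"
```

For example, `WRNConfig(28, 10)` gives 4 blocks per group and widths `[16, 160, 320, 640]`, while `WRNConfig(16, 2)` gives 2 blocks per group: a shallower network made "wide" by multiplying channel counts rather than adding layers.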
Test error (%, flip/translation augmentation) on CIFAR:
Method | CIFAR-10 | CIFAR-100 |
---|---|---|
pre-ResNet-164 | 5.46 | 24.33 |
pre-ResNet-1001 | 4.92 | 22.71 |
WRN-28-10 | 4.17 | 20.5 |
WRN-28-10-dropout | 4.39 | 20.0 |
See http://arxiv.org/abs/1605.07146 for details.
The code depends on Torch (http://torch.ch). Follow the installation instructions there, then run:
luarocks install optnet
luarocks install iterm
We recommend installing cuDNN v5 for speed. Alternatively, you can run on CPU, or on GPU with OpenCL (coming soon).
For visualizing training curves we used an IPython notebook with pandas and bokeh, and suggest installing them via Anaconda.
The code supports loading simple datasets in Torch format. We provide the following:
- MNIST data preparation script
- CIFAR-10 (coming)
- CIFAR-10 whitened (using pylearn2) preprocessed dataset
- CIFAR-100 (coming)
- CIFAR-100 whitened (using pylearn2) preprocessed dataset
- SVHN data preparation script
To whiten CIFAR-10 and CIFAR-100 we used pylearn2 scripts such as https://github.com/lisa-lab/pylearn2/blob/master/pylearn2/scripts/datasets/make_cifar10_gcn_whitened.py, and then converted the result to Torch format using https://gist.github.com/szagoruyko/ad2977e4b8dceb64c68ea07f6abf397b.
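The pylearn2 script's name refers to global contrast normalization (GCN) followed by ZCA whitening. The NumPy sketch below illustrates both steps on flattened images. It is not the pylearn2 code: the function names are hypothetical, and the defaults (`scale=55.0`, `eps=1e-2`) are illustrative and not necessarily what the script uses.

```python
import numpy as np


def global_contrast_normalize(x, scale=55.0, min_divisor=1e-8):
    """Per-sample GCN: subtract each row's mean, then rescale the row to L2 norm `scale`.

    x has shape (num_samples, num_features), e.g. flattened 32x32x3 images.
    """
    x = x - x.mean(axis=1, keepdims=True)
    norms = np.sqrt((x ** 2).sum(axis=1, keepdims=True))
    return scale * x / np.maximum(norms, min_divisor)


def fit_zca(x, eps=1e-2):
    """Fit a ZCA whitening transform on training data.

    Returns (mean, w) such that (x - mean) @ w has approximately identity covariance;
    `eps` regularizes small eigenvalues (a larger eps gives weaker whitening).
    """
    mean = x.mean(axis=0)
    xc = x - mean
    cov = xc.T @ xc / xc.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)  # covariance is symmetric, so eigh applies
    eigvals = np.clip(eigvals, 0.0, None)   # guard against tiny negative round-off
    # Rotate into the eigenbasis, rescale each direction, rotate back (the "zero-phase" part).
    w = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return mean, w


def apply_zca(x, mean, w):
    """Apply a previously fitted ZCA transform."""
    return (x - mean) @ w
```

The transform would be fitted on the training set only, and the same `mean` and `w` then applied to the test set, so both splits go through identical preprocessing.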
We are running ImageNet experiments and will update the paper and this repo soon.
We provide several scripts for reproducing the results in the paper. Below are a few examples.
model=wide-resnet widen_factor=4 depth=40 ./scripts/train_cifar.sh
This trains WRN-40-4 on whitened CIFAR-10 (expected to be in the datasets folder). This network achieves about the same accuracy as ResNet-1001 and trains in 6 hours on a single Titan X.
Logs are saved to the logs/wide-resnet_$RANDOM$RANDOM folder, with a JSON entry for each epoch, and can be visualized later with iTorch or IPython.
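To inspect such a log outside the notebook, a small parser is enough. The sketch below assumes each epoch is logged as one text line made of a prefix followed by a JSON object; the prefix `json_stats: `, the log file name and the keys `epoch` and `test_acc` are all assumptions to check against an actual log, not a documented format.

```python
import json
from typing import Dict, Iterable, List

# Hypothetical prefix marking a per-epoch JSON entry; verify it against a real log.
STATS_PREFIX = "json_stats: "


def parse_epoch_stats(lines: Iterable[str], prefix: str = STATS_PREFIX) -> List[Dict]:
    """Collect one dict per epoch from lines of the form '<prefix>{...json...}'."""
    stats = []
    for line in lines:
        line = line.strip()
        if line.startswith(prefix):
            stats.append(json.loads(line[len(prefix):]))
    return stats


def best_epoch(stats: List[Dict], key: str = "test_acc") -> Dict:
    """Return the epoch entry with the highest value of `key` (hypothetical key name)."""
    return max(stats, key=lambda s: s[key])
```

With a log file open as `f`, `parse_epoch_stats(f)` returns a list of dicts that `pandas.DataFrame(...)` turns directly into one row per epoch, ready for plotting with bokeh.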
For reference, we provide the logs for this experiment and an IPython notebook to visualize the results. After running it you should see these training curves:
Another example:
model=wide-resnet widen_factor=10 depth=28 dropout=0.3 dataset=./datasets/cifar100_whitened.t7 ./scripts/train_cifar.sh
This network achieves 20.0% error on CIFAR-100 in about a day on a single Titan X.
As WRNs are much faster to train than ResNets, we don't provide any multi-GPU code; https://github.com/facebook/fb.resnet.torch should be straightforward to modify to run WRNs.
Additional models in this repo:
- NIN (7.4% on CIFAR-10)
- VGG (modified from cifar.torch, 6.3% on CIFAR-10)
- pre-activation ResNet (from https://github.com/KaimingHe/resnet-1k-layers)
The code evolved from https://github.com/szagoruyko/cifar.torch. To reduce memory usage we use @fmassa's optimize-net, which automatically shares output and gradient tensors between modules. This keeps memory usage below 4 GB even for our best networks. It can also generate network graph plots, such as the one for WRN-16-2 below.
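The idea behind sharing tensors is that two tensors whose lifetimes do not overlap can live in the same buffer. The sketch below is a generic illustration of that idea for a forward pass, not optnet's actual algorithm (which also handles gradient tensors). Each tensor is described by a hypothetical `(name, first_step, last_step)` lifetime, and buffers are assigned first-fit in order of first use, which for interval lifetimes uses the minimum possible number of buffers.

```python
from typing import Dict, List, Tuple


def assign_buffers(lifetimes: List[Tuple[str, int, int]]) -> Dict[str, int]:
    """Map each tensor to a buffer id so that overlapping lifetimes never share a buffer.

    lifetimes: (name, first_step, last_step), where first_step is when the tensor is
    written and last_step is the last step that reads it.
    """
    assignment: Dict[str, int] = {}
    busy_until: List[int] = []  # busy_until[b] = last step at which buffer b is read
    for name, start, end in sorted(lifetimes, key=lambda t: t[1]):
        for b, last in enumerate(busy_until):
            if last < start:  # buffer b is no longer needed when this tensor is written
                busy_until[b] = end
                assignment[name] = b
                break
        else:  # no free buffer: allocate a new one
            busy_until.append(end)
            assignment[name] = len(busy_until) - 1
    return assignment
```

For a plain chain of modules, where output i is produced at step i and consumed at step i + 1, `assign_buffers([(f"out{i}", i, i + 1) for i in range(5)])` alternates between just two buffers instead of allocating five.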