This repository contains a PyTorch implementation of the paper Deep Frank-Wolfe For Neural Network Optimization. If you use this work for your research, please cite the paper:
@Article{berrada2018deep,
author = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
title = {Deep Frank-Wolfe For Neural Network Optimization},
journal = {Under review},
year = {2018},
}
The DFW algorithm is a first-order optimization algorithm for deep neural networks. To use it for your learning task, note the following two requirements:
- the loss function has to be a convex piecewise-linear function (e.g. the multi-class SVM loss as implemented here, or the l1 loss)
- the optimizer needs access to the value of the loss function on the current mini-batch, as shown here
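To illustrate the first requirement, here is a minimal plain-Python sketch of the multi-class SVM (Crammer-Singer hinge) loss for a single example; it is convex and piecewise linear in the scores. The function name and signature are illustrative only, and the repository's implementation operates on batched tensors instead:

```python
def multiclass_svm_loss(scores, label, margin=1.0):
    """Crammer-Singer multi-class hinge loss for one example:
    max(0, max_{j != y} (s_j - s_y + margin)).
    Convex and piecewise linear in the scores."""
    correct = scores[label]
    violations = [scores[j] - correct + margin
                  for j in range(len(scores)) if j != label]
    return max(0.0, max(violations))

# Correct class wins by more than the margin -> zero loss
print(multiclass_svm_loss([3.0, 0.5, 1.0], label=0))  # 0.0
# Class 2 violates the margin: 2.0 - 1.0 + 1.0 = 2.0
print(multiclass_svm_loss([1.0, 0.0, 2.0], label=0))  # 2.0
```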
Besides these requirements, the optimizer can be used as a plug-and-play component, and its self-contained code is available in src/optim/dfw.py
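As a rough sketch of what a training loop with DFW might look like (the constructor arguments `eta` and `momentum`, the closure-based `step` call, and the `MultiClassHingeLoss` name are assumptions for illustration; see src/optim/dfw.py for the authoritative interface):

```python
import torch.nn as nn
from src.optim.dfw import DFW  # standalone optimizer file in this repository

model = nn.Linear(10, 3)
# 'eta' (initial learning rate) and 'momentum' are assumed hyper-parameter names
optimizer = DFW(model.parameters(), eta=0.1, momentum=0.9)

svm = MultiClassHingeLoss()  # hypothetical name for the repo's SVM loss
for x, y in loader:  # 'loader' stands for your (input, target) mini-batches
    optimizer.zero_grad()
    loss = svm(model(x), y)  # convex piecewise-linear loss (requirement 1)
    loss.backward()
    # the closure hands DFW the loss value of the current mini-batch
    # (requirement 2), which it needs to compute its step size
    optimizer.step(lambda: float(loss))
```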
This code has been tested with PyTorch 0.4.1 in Python 3. Detailed requirements are available in requirements.txt.
- To reproduce the CIFAR experiments:
VISION_DATA=[path/to/your/cifar/data] python scripts/reproduce_cifar.py
- To reproduce the SNLI experiments: follow the preparation instructions and run
python scripts/reproduce_snli.py
Note that SGD benefits from a hand-designed learning rate schedule. In contrast, all the other optimizers (including DFW) adapt their steps automatically and require tuning of the initial learning rate only. On average, you should obtain results similar to the ones reported in the paper (there may be some variance on individual CIFAR runs):
[Results table: test accuracies on CIFAR with Wide Residual Networks and Densely Connected Networks; the numerical entries were lost in extraction and are not recoverable here.]
[Results table: comparison of the CE loss and the SVM loss; the numerical entries were lost in extraction and are not recoverable here.]
We use the following third-party implementations: