Deep Frank-Wolfe For Neural Network Optimization

This repository contains a PyTorch implementation of the paper "Deep Frank-Wolfe For Neural Network Optimization". If you use this work for your research, please cite the paper:

@Article{berrada2018deep,
  author       = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
  title        = {Deep Frank-Wolfe For Neural Network Optimization},
  journal      = {Under review},
  year         = {2018},
}

The DFW algorithm is a first-order optimization algorithm for deep neural networks. To use it for your learning task, note the following two requirements:

  • the loss function has to be a convex piecewise-linear function (e.g. the multi-class SVM loss as implemented here, or the l1 loss); see the sketch after this list
  • the optimizer needs access to the value of the loss function on the current mini-batch, as shown here
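As an illustration of the first requirement, here is a minimal sketch of the multi-class SVM loss, a convex piecewise-linear function of the scores. This is not the repository's own implementation (linked above), just an example of the kind of loss DFW expects:

```python
import torch
import torch.nn.functional as F

def multiclass_hinge_loss(scores, targets):
    """Multi-class SVM loss: max_j (s_j - s_y + Delta(j, y)), with margin Delta = 1 for j != y.

    scores:  (batch, n_classes) raw model outputs
    targets: (batch,) int64 ground-truth class indices
    """
    correct = scores.gather(1, targets.unsqueeze(1))                 # s_y, shape (batch, 1)
    margins = scores - correct + 1.0                                 # s_j - s_y + 1
    margins = margins - F.one_hot(targets, scores.size(1)).float()   # Delta(y, y) = 0
    return margins.max(dim=1)[0].mean()                              # piecewise linear and convex in the scores
```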

Besides these requirements, the optimizer is plug-and-play, and its self-contained code is available in src/optim/dfw.py. A minimal usage sketch is given below.
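For the second requirement, a minimal training-step sketch is shown below: the loss value of the current mini-batch is passed to the optimizer through a closure. The constructor argument names (e.g. eta for the initial step size) are assumptions for illustration; check src/optim/dfw.py for the exact signature:

```python
import torch
from src.optim.dfw import DFW   # the optimizer shipped in this repository

model = torch.nn.Linear(784, 10)                 # any torch.nn.Module
optimizer = DFW(model.parameters(), eta=0.1)     # eta: initial learning rate (assumed argument name)

def train_step(x, y):
    scores = model(x)
    loss = multiclass_hinge_loss(scores, y)      # convex piecewise-linear loss (sketch above)
    optimizer.zero_grad()
    loss.backward()
    # DFW computes its step size from the loss value of the current mini-batch,
    # hence the closure returning that value.
    optimizer.step(lambda: float(loss))
    return float(loss)
```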

Requirements

This code has been tested with PyTorch 0.4.1 and Python 3. Detailed requirements are available in requirements.txt.

Reproducing the Results

  • To reproduce the CIFAR experiments: VISION_DATA=[path/to/your/cifar/data] python scripts/reproduce_cifar.py
  • To reproduce the SNLI experiments: follow the preparation instructions and run python scripts/reproduce_snli.py

Note that SGD benefits from a hand-designed learning-rate schedule. In contrast, all the other optimizers (including DFW) adapt their step size automatically and only require tuning of the initial learning rate, as illustrated in the sketch below. On average, you should obtain results similar to those reported in the paper (there may be some variance on some instances of the CIFAR experiments):
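As an illustration of this tuning difference (the hyper-parameter values below are hypothetical, not the paper's protocol, and `eta` is an assumed argument name):

```python
import torch
from src.optim.dfw import DFW

model = torch.nn.Linear(784, 10)

# SGD: an initial learning rate *and* a hand-designed decay schedule need to be chosen.
sgd = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, milestones=[60, 120], gamma=0.1)

# DFW: only the initial step size is tuned; the effective step adapts automatically.
dfw = DFW(model.parameters(), eta=0.1)
```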

CIFAR-10:

| Optimizer | Wide Residual Networks, Test Accuracy (%) | Densely Connected Networks, Test Accuracy (%) |
|-----------|-------------------------------------------|-----------------------------------------------|
| Adagrad   | 86.07 | 87.32 |
| Adam      | 84.86 | 88.44 |
| AMSGrad   | 86.08 | 90.53 |
| BPGrad    | 88.62 | 90.85 |
| DFW       | 90.18 | 90.22 |
| SGD       | 90.08 | 92.02 |

CIFAR-100:

| Optimizer | Wide Residual Networks, Test Accuracy (%) | Densely Connected Networks, Test Accuracy (%) |
|-----------|-------------------------------------------|-----------------------------------------------|
| Adagrad   | 57.64 | 56.47 |
| Adam      | 58.46 | 64.61 |
| AMSGrad   | 60.73 | 68.32 |
| BPGrad    | 60.31 | 59.36 |
| DFW       | 67.83 | 69.55 |
| SGD       | 66.78 | 70.33 |

SNLI:

| Optimizer | CE Loss, Test Accuracy (%) | SVM Loss, Test Accuracy (%) |
|-----------|----------------------------|-----------------------------|
| Adagrad   | 83.8 | 84.6 |
| Adam      | 84.5 | 85.0 |
| AMSGrad   | 84.2 | 85.1 |
| BPGrad    | 83.6 | 84.2 |
| DFW       | -    | 85.2 |
| SGD       | 84.7 | 85.2 |
| SGD*      | 84.5 | -    |

Acknowledgments

We use the following third-party implementations:
