This repository is an implementation of "Optimizing Millions of Hyperparameters by Implicit Differentiation".
Create a Python 3.7 environment and install the required packages:
conda create -n ift-env python=3.7
source activate ift-env
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
pip install -r requirements.txt
Install JupyterLab:
conda install -c conda-forge jupyterlab
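As a quick sanity check before running the full tests below, you can verify that PyTorch imports and sees the GPU (a minimal sketch; the printed versions will vary with your install):

import torch
import torchvision

# Confirm the packages import and report their versions.
print(torch.__version__, torchvision.__version__)

# True only if the NVIDIA driver and cudatoolkit are set up correctly.
print("CUDA available:", torch.cuda.is_available())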
Consider the following tests to verify that the environment is set up correctly:
python mnist_test.py
--datasize <training set size>
--valsize <validation set size>
--lrh <hyperparameter learning rate; must be negative>
--epochs <minimum number of epochs to train the model>
--hepochs <number of hyperparameter update iterations>
--l2 <initial log weight decay>
--restart <whether to reinitialize model weights after each hyperparameter update>
--model <cnn for a LeNet-like model, mlp for logistic regression and MLPs>
--dataset <CIFAR10 or MNIST>
--num_layers <number of hidden layers for the mlp>
--hessian <KFAC: KFAC estimate; direct: true Hessian and its inverse>
--jacobian <direct: true Jacobian; product: use d_L/d_theta * d_L/d_lambda>
Trained models after each hyperparameter update are stored in the folder defined at line 627 of mnist_test.py.
To use CG to compute the inverse of the Hessian, change the hyperparameter updater at line 660. An example run:
python mnist_test.py --datasize 40000 --valsize 10000 --lrh 0.01 --epochs=100 --hepochs=10 --l2=1e-5 --restart=10 --model=mlp --dataset=MNIST --num_layers=1 --hessian=KFAC --jacobian=direct
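The CG option above solves H x = v using only Hessian-vector products, so the Hessian never has to be formed explicitly. A minimal sketch of that idea (hypothetical helper functions, not the exact updater in mnist_test.py):

import torch

def make_hvp_fn(loss, params):
    # Hessian-vector products via double backprop; params is a list of tensors.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    def hvp_fn(vec):
        grad_dot_vec = flat_grad.dot(vec)
        hvps = torch.autograd.grad(grad_dot_vec, params, retain_graph=True)
        return torch.cat([h.reshape(-1) for h in hvps])
    return hvp_fn

def conjugate_gradient(hvp_fn, v, num_steps=10, tol=1e-10):
    # Solve H x = v with CG, accessing H only through hvp_fn(p) = H @ p.
    x = torch.zeros_like(v)
    r = v.clone()                 # residual v - H x, with x = 0 initially
    p = r.clone()                 # search direction
    rs_old = r.dot(r)
    for _ in range(num_steps):
        Hp = hvp_fn(p)
        alpha = rs_old / p.dot(Hp)
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r.dot(r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x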
First, make sure you are on the master node:
ssh <USERNAME>@q.vectorinstitute.ai
Submit a job to the Slurm scheduler:
srun --partition=gpu --gres=gpu:1 --mem=4GB python mnist_test.py
Or, submit a batch of jobs defined by srun_script.sh:
sbatch --array=0-2 srun_script.sh
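Inside an array job, Slurm exposes the task index through the SLURM_ARRAY_TASK_ID environment variable, which the launched script can use to pick a configuration. A sketch of that pattern (the config list here is hypothetical, not what srun_script.sh actually dispatches):

import os

# Slurm sets SLURM_ARRAY_TASK_ID for each job in an --array submission;
# default to 0 so the script also runs outside of Slurm.
task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))

# Hypothetical per-task settings; indices 0-2 match `sbatch --array=0-2`.
configs = [
    {"l2": 1e-5, "model": "mlp"},
    {"l2": 1e-4, "model": "mlp"},
    {"l2": 1e-3, "model": "cnn"},
]
config = configs[task_id]
print(f"Task {task_id} running with config: {config}")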
View queued jobs for a user:
squeue -u $USERNAME
Cancel jobs for a user:
scancel -u $USERNAME
Cancel a specific job:
scancel $JOBID
Below are commands for deploying the experiments, with and without Slurm.
To deploy the data generation for all experiments:
sbatch run_all.sh
Data Augmentation Network
python train_augment_net2.py --use_augment_net
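Conceptually, the augmentation network is a learned, differentiable transformation applied to each input before the classifier sees it. A rough sketch of that composition (illustrative names, not the exact code in train_augment_net2.py):

import torch.nn as nn

class AugmentedClassifier(nn.Module):
    # Illustrative only: a trainable augmentation net perturbs the input,
    # then the base classifier consumes the augmented image.
    def __init__(self, augment_net, classifier):
        super().__init__()
        self.augment_net = augment_net   # e.g. a small U-Net over images
        self.classifier = classifier

    def forward(self, x):
        x_aug = self.augment_net(x)      # differentiable augmentation
        return self.classifier(x_aug)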
Loss Reweighting Network
python train_augment_net2.py --use_reweighting_net --loss_weight_type=softmax
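With --loss_weight_type=softmax, the idea is that a small network scores each training example and the scores are softmax-normalized into per-example loss weights. A minimal sketch of that weighting (illustrative, not the exact implementation):

import torch.nn.functional as F

def reweighted_loss(logits, targets, weight_logits):
    # Per-example cross-entropy, reweighted by softmax-normalized scores;
    # weight_logits holds one scalar per example from the reweighting network.
    per_example = F.cross_entropy(logits, targets, reduction="none")
    weights = F.softmax(weight_logits, dim=0) * len(weight_logits)  # mean weight ~ 1
    return (weights * per_example).mean()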
The LSTM code in this repository is built on the AWD-LSTM codebase.
These commands should be run from inside the rnn folder.
First, download the PTB dataset by running:
./getdata.sh
Tune the LSTM hyperparameters with 1-step unrolling:
python train.py
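With 1-step unrolling, the hypergradient comes from differentiating the validation loss through a single optimizer step on the training loss. A schematic sketch, assuming plain SGD and a functional loss interface (train.py itself is more involved):

import torch

def one_step_unrolled_hypergrad(train_loss_fn, val_loss_fn, params, hparams, lr=0.1):
    # Differentiate the validation loss through one SGD step on the training loss.
    train_loss = train_loss_fn(params, hparams)
    grads = torch.autograd.grad(train_loss, params, create_graph=True)
    # One unrolled SGD step; keeping the graph lets gradients reach hparams.
    updated = [p - lr * g for p, g in zip(params, grads)]
    val_loss = val_loss_fn(updated)
    return torch.autograd.grad(val_loss, hparams)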
To train an STN, run the following command from inside the stn folder:
python hypertrain.py --tune_all --save
Train a baseline checkpoint:
python train_checkpoint.py --dataset cifar10 --model resnet18 --data_augmentation
Then fine-tune from the saved checkpoint:
python finetune_checkpoint.py --load_checkpoint=baseline_checkpoints/cifar10_resnet18_sgdm_lr0.1_wd0.0005_aug1.pt --num_finetune_epochs=10 --wdecay=1e-4
TODO: explain what this experiment does and which figure in the paper it corresponds to.
To run a Python script:
python script.py
To deploy with Slurm:
srun ...
Repository structure:
.
├── HAM_dataset.py
├── README.md
├── cutout.py
├── data_loaders.py
├── finetune_checkpoint.py
├── finetune_ift_checkpoint.py
├── grid_search.py
├── images
├── inverse_comparison.py
├── isic_config.py
├── isic_loader.py
├── kfac.py
├── kfac_utils.py
├── minst_ref.py
├── mnist_test.py
├── models
│   ├── __init__.py
│   ├── resnet.py
│   ├── resnet_cifar.py
│   ├── simple_models.py
│   ├── unet.py
│   └── wide_resnet.py
├── papers
│   ├── haoping_project
│   │   ├── main.tex
│   │   ├── neurips2019.tex
│   │   ├── neurips_2019.sty
│   │   └── references.bib
│   └── nips
│       ├── main.tex
│       ├── neurips_2019.sty
│       └── references.bib
├── random_search.py
├── requirements.txt
├── rnn
│   ├── config_scripts
│   │   ├── dropoute_ift_no_lrdecay.yaml
│   │   ├── dropouto
│   │   │   ├── dropouto_2layer_lrdecay.yaml
│   │   │   ├── dropouto_2layer_no_lrdecay.yaml
│   │   │   ├── dropouto_ift_lrdecay.yaml
│   │   │   ├── dropouto_ift_neumann_1_lrdecay.yaml
│   │   │   ├── dropouto_ift_neumann_1_no_lrdecay.yaml
│   │   │   ├── dropouto_ift_no_lrdecay.yaml
│   │   │   ├── dropouto_lrdecay.yaml
│   │   │   ├── dropouto_no_lrdecay.yaml
│   │   │   └── dropouto_perparam_ift_no_lrdecay.yaml
│   │   └── wdecay
│   │       ├── ift_wdecay_per_param_no_lrdecay.yaml
│   │       ├── wdecay_ift_lrdecay.yaml
│   │       └── wdecay_ift_neumann_1_lrdecay.yaml
│   ├── create_command_script.py
│   ├── data.py
│   ├── embed_regularize.py
│   ├── getdata.sh
│   ├── locked_dropout.py
│   ├── logger.py
│   ├── model_basic.py
│   ├── plot_utils.py
│   ├── rnn_utils.py
│   ├── run_grid_search.py
│   ├── train.py
│   ├── train2.py
│   └── weight_drop.py
├── search_configs
│   ├── cifar100_wideresnet_bern_dropout_sep.yaml
│   ├── cifar100_wideresnet_gauss_dropout_sep.yaml
│   ├── cifar10_resnet32_data_aug.yaml
│   ├── cifar10_resnet32_grid.yaml
│   ├── cifar10_resnet32_random.yaml
│   ├── cifar10_resnet32_wdecay_per_layer.yaml
│   ├── cifar10_wideresnet_bern_dropout.yaml
│   ├── cifar10_wideresnet_bern_dropout_sep.yaml
│   ├── cifar10_wideresnet_gauss_dropout.yaml
│   ├── cifar10_wideresnet_gauss_dropout_sep.yaml
│   ├── isic_grid.yaml
│   └── isic_random.yaml
├── search_scripts
│   ├── cifar100_wideresnet_bern_dropout_sep
│   ├── cifar100_wideresnet_gauss_dropout_sep
│   ├── cifar100_wideresnet_random
│   ├── cifar10_wideresnet_bern_dropout
│   ├── cifar10_wideresnet_bern_dropout_sep
│   ├── cifar10_wideresnet_gauss_dropout
│   └── cifar10_wideresnet_gauss_dropout_sep
├── srun_script.sh
├── stn
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar.py
│   │   └── loaders.py
│   ├── hypermodels
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── hyperconv2d.py
│   │   ├── hyperlinear.py
│   │   └── small.py
│   ├── hypertrain.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   └── small.py
│   └── util
│       ├── __init__.py
│       ├── cutout.py
│       ├── dropout.py
│       └── hyperparameter.py
├── train.py
├── train_augment_net2.py
├── train_augment_net_graph.py
├── train_augment_net_multiple.py
├── train_augment_net_slurm.py
├── train_baseline.py
├── train_checkpoint.py
└── utils
    ├── csv_logger.py
    ├── discrete_utils.py
    ├── logger.py
    ├── plot_utils.py
    └── util.py

17 directories, 103 files