Official PyTorch implementation of model ratatouille | paper
Alexandre Ramé, Kartik Ahuja, Jianyu Zhang, Matthieu Cord, Léon Bottou, David Lopez-Paz
We propose a new fine-tuning strategy that improves OOD generalization in computer vision by recycling and averaging weights specialized on diverse auxiliary tasks.
Foundation models are redefining how AI systems are built. Practitioners now follow a standard procedure to build their machine learning solutions: from a pre-trained foundation model, they fine-tune the weights on the target task of interest. As a result, the Internet is swarmed by a handful of foundation models fine-tuned on many diverse tasks: these individual fine-tunings exist in isolation without benefiting from each other. In our opinion, this is a missed opportunity, as these specialized models contain rich and diverse features. In this paper, we thus propose model ratatouille, a new strategy to recycle the multiple fine-tunings of the same foundation model on diverse auxiliary tasks. Specifically, we repurpose these auxiliary weights as initializations for multiple parallel fine-tunings on the target task; then, we average all fine-tuned weights to obtain the final model. This recycling strategy aims at maximizing the diversity in weights by leveraging the diversity in auxiliary tasks. Empirically, it improves the state of the art on the reference DomainBed benchmark for out-of-distribution generalization. Looking forward, this work contributes to the emerging paradigm of updatable machine learning where, akin to open-source software development, the community collaborates to reliably update machine learning models.
Our code is adapted from the open-source DomainBed github, a PyTorch benchmark providing datasets and algorithms to evaluate OOD generalization, introduced in In Search of Lost Domain Generalization, ICLR 2021. More specifically, our code extends the DiWA github, which averages the weights of the models obtained from the hyperparameter search instead of selecting only a single one: this strategy was motivated and explained in the model soups (ICML 2022) and DiWA (NeurIPS 2022) papers.
The code requires the following packages:
- python == 3.7.10
- torch == 1.12.1
- torchvision == 0.13.1
- numpy == 1.21.5
We consider the following datasets:
- VLCS (Fang et al., 2013)
- PACS (Li et al., 2017)
- OfficeHome (Venkateswara et al., 2017)
- A TerraIncognita (Beery et al., 2018) subset
- DomainNet (Peng et al., 2019)
You can download the datasets with the following command:
```sh
python3 -m domainbed.scripts.download --data_dir ${data_dir}
```
Our procedure consists of three stages.
- Auxiliary trainings: create a pool of specialized models on various auxiliary tasks.
- Target trainings: apply the standard hyperparameter search starting from these auxiliary initializations.
- Weight selection: average the fine-tuned weights.
The different experiments are saved in `${expe_dir}`.

For real-world applications, we envision that specialized weights may be downloaded from collaborative open-source repositories of neural networks. In practice in this github, to populate the folder `${expe_dir}/aux`, we perform fine-tunings on DomainBed's datasets. Specifically, we use the `sweep` script with either VLCS, PACS, OfficeHome, TerraIncognita or DomainNet as the `${auxiliary_dataset}`.
```sh
mkdir ${expe_dir}/lp  # dir containing the linear probe runs
mkdir ${expe_dir}/aux # dir containing the auxiliary runs

for auxiliary_dataset in VLCS PACS OfficeHome TerraIncognita DomainNet
do
    python -m domainbed.scripts.sweep launch\
        --data_dir ${data_dir}\
        --dataset ${auxiliary_dataset}\
        --test_env -1\
        --output_dir_lp ${expe_dir}/lp/${auxiliary_dataset}_notest\
        --output_dir ${expe_dir}/aux/${auxiliary_dataset}_notest\
        --n_hparams 4\
        --n_trials 1
done
# --test_env -1: train on all domains simultaneously (no OOD test env for auxiliary trainings).
# --output_dir_lp: where the shared linear probe is saved.
# --output_dir: where the auxiliary hyperparameter sweep is saved.
# --n_hparams 4: only 4 runs are needed in the auxiliary hyperparameter search.
# --n_trials 1: a single data split.
```
First, if `output_dir_lp` does not exist, we linear probe (lp) the classifier (to prevent feature distortion): this classifier initialization is then shared by the subsequent runs. Second, we populate `output_dir` with `n_hparams` ERM runs following DomainBed's hyperparameter distributions. Critically, this procedure is agnostic to the target task, and thus is done only once.
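For readers unfamiliar with linear probing, the sketch below shows the general idea in plain PyTorch: freeze the pre-trained featurizer and train only the classifier head, which then serves as a shared initialization for the subsequent fine-tunings. This is a minimal illustration, not the repository's actual implementation; the backbone, class count and hyperparameters are arbitrary assumptions.

```python
import torch
import torch.nn as nn
import torchvision

# Minimal linear-probing sketch (illustrative only, not this repository's code):
# freeze a pre-trained featurizer and train only the classifier head.
featurizer = torchvision.models.resnet50(pretrained=True)
num_features = featurizer.fc.in_features
featurizer.fc = nn.Identity()            # expose the 2048-d features
for param in featurizer.parameters():
    param.requires_grad = False          # the featurizer stays frozen

num_classes = 65                         # e.g. OfficeHome has 65 classes
classifier = nn.Linear(num_features, num_classes)
optimizer = torch.optim.SGD(classifier.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def linear_probe_step(images, labels):
    """One optimization step that updates only the classifier weights."""
    with torch.no_grad():
        features = featurizer(images)    # no gradients through the featurizer
    loss = criterion(classifier(features), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```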
Now we focus on a given `${target_dataset}` and one domain `${test_env}` kept as the test domain: the other domains are used for training. As previously, we leverage the `sweep` script.
```sh
mkdir ${expe_dir}/target # dir containing the target runs

target_dataset=OfficeHome # or any other DomainBed dataset
test_env=0                # or any integer between 0 and 3

python -m domainbed.scripts.sweep launch\
    --data_dir ${data_dir}\
    --dataset ${target_dataset}\
    --test_env ${test_env}\
    --output_dir_lp ${expe_dir}/lp/${target_dataset}_test${test_env}\
    --output_dir ${expe_dir}/target/${target_dataset}_withaux\
    --aux_dir ${expe_dir}/aux\
    --n_hparams 20\
    --n_trials 1
# --test_env: domain not seen during training, kept apart for OOD evaluation.
# --output_dir_lp: where the shared linear probe is saved.
# --output_dir: where the target hyperparameter sweep is saved.
# --aux_dir: where the pool of auxiliary weights is saved.
# --n_hparams 20: default number of hyperparameter configurations (5 already provides good results).
# --n_trials 1: set to 3 to test different data splits.
```
The argument `aux_dir` is the directory containing the different auxiliary runs used to initialize the featurizer. Obviously, to prevent any kind of information leakage, the code discards from `aux_dir` the models inter-trained on `${target_dataset}`: in brief, we ensure that `${target_dataset}` ≠ `${auxiliary_dataset}`.
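The actual filtering is handled inside the repository's sweep code; purely as an illustration of the idea, a sketch could look like the following, where the `args.json` metadata file and directory layout are assumptions for the example, not the repository's actual format.

```python
import json
from pathlib import Path

def select_auxiliary_runs(aux_dir, target_dataset):
    """Illustrative sketch: keep only auxiliary runs whose dataset differs from
    the target dataset, to avoid information leakage. Assumes (hypothetically)
    that each run directory stores its launch arguments in an `args.json` file."""
    selected = []
    for run_dir in Path(aux_dir).iterdir():
        args_file = run_dir / "args.json"
        if not args_file.is_file():
            continue
        run_args = json.loads(args_file.read_text())
        if run_args.get("dataset") != target_dataset:
            selected.append(run_dir)
    return selected
```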
Ratatouille's main theoretical contribution establishes linear mode connectivity across models fine-tuned on the target task from different initializations. Thus, we average the weights obtained from the previous sweep.
```sh
python -m domainbed.scripts.inference\
    --data_dir ${data_dir}\
    --dataset ${target_dataset}\
    --test_env ${test_env}\
    --input_dir ${expe_dir}/target/${target_dataset}_withaux\
    --weight_selection uniform\
    --trial_seed 0
# --weight_selection: `uniform` averages all weights; use `greedy` for greedy selection.
```
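Conceptually, the `uniform` weight selection boils down to averaging the state dicts of all fine-tuned models, which is possible because they all share the same architecture. The sketch below illustrates this averaging; it is not the code of `domainbed/scripts/inference.py`, and the checkpoint names in the usage comment are placeholders.

```python
import torch

def average_state_dicts(state_dicts):
    """Uniformly average the parameters of several fine-tuned models.
    Assumes all state dicts share the same keys and shapes, which holds here
    since every run fine-tunes the same architecture."""
    averaged = {}
    for key in state_dicts[0]:
        averaged[key] = torch.stack(
            [sd[key].float() for sd in state_dicts], dim=0
        ).mean(dim=0)
    return averaged

# Hypothetical usage (checkpoint paths are placeholders):
# state_dicts = [torch.load(p, map_location="cpu") for p in ["run_0.pkl", "run_1.pkl"]]
# model.load_state_dict(average_state_dicts(state_dicts))
```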
If you want to obtain standard deviations across different data splits, set `--n_trials 3` in the sweep command. Then you can set `trial_seed` to either `0`, `1` or `2`; you can also average all 60 weights from the 3 trials by setting `trial_seed` to `-1`, which we call uniform$^{\dagger}$ in the results below.
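The `greedy` weight selection mentioned above follows the greedy soup procedure popularized by model soups: rank the models by ID validation accuracy and add them to the average one at a time, keeping a candidate only if it does not degrade the validation accuracy of the averaged model. Below is a minimal sketch of that procedure, reusing the hypothetical `average_state_dicts` helper from the previous sketch and assuming a `validation_accuracy(state_dict)` evaluation function; it is not the repository's implementation.

```python
def greedy_weight_selection(state_dicts, validation_accuracy):
    """Greedy soup sketch. `validation_accuracy` is a hypothetical callable that
    loads a state dict into the network and returns its ID validation accuracy."""
    # Rank the individual fine-tuned models by their ID validation accuracy.
    ranked = sorted(state_dicts, key=validation_accuracy, reverse=True)
    soup = [ranked[0]]
    best_acc = validation_accuracy(average_state_dicts(soup))
    for candidate in ranked[1:]:
        trial_acc = validation_accuracy(average_state_dicts(soup + [candidate]))
        if trial_acc >= best_acc:   # keep the candidate only if it does not hurt
            soup.append(candidate)
            best_acc = trial_acc
    return average_state_dicts(soup)
```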
Inter-training selects the single best model among the previous runs based on ID validation accuracy. To reproduce these results, call:
```sh
python -m domainbed.scripts.collect_results --input_dir ${expe_dir}/target/${target_dataset}_withaux
```
To reproduce the vanilla fine-tuning and Soups/DiWA baselines, you first need to launch a new sweep without specifying `aux_dir`.
```sh
python -m domainbed.scripts.sweep launch\
    ... # same arguments as before
    --output_dir ${expe_dir}/target/${target_dataset}_noaux\
    --aux_dir none
# --output_dir: note the changed output dir.
```
Then call `collect_results.py` (for vanilla fine-tuning) or `inference.py` (for Soups/DiWA) with `--input_dir ${expe_dir}/target/${target_dataset}_noaux`. In brief, model ratatouille is to inter-training as model soups is to vanilla fine-tuning.
Add `--fusing_range 4` to the previous sweep command to perform linear interpolation at initialization, as in fusing: rather than selecting one single auxiliary checkpoint as initialization, multiple auxiliary featurizers are linearly interpolated.
```sh
python -m domainbed.scripts.sweep launch\
    ... # same arguments as before
    --output_dir ${expe_dir}/target/${target_dataset}_withaux_fusing4\
    --fusing_range 4
# --output_dir: note the changed output dir.
# --fusing_range 4: the value `4` specifies how the interpolating coefficients are sampled.
```
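Conceptually, fusing replaces the single auxiliary checkpoint used as initialization by a convex combination of several auxiliary featurizers. The sketch below only illustrates this interpolation; how `--fusing_range` actually controls the sampling of the coefficients is specific to this repository, so the Dirichlet sampling in the comment is an assumption for illustration.

```python
import torch

def interpolate_featurizers(featurizer_state_dicts, coefficients):
    """Illustrative sketch: linearly interpolate several auxiliary featurizers
    at initialization, with convex coefficients (non-negative, summing to 1)."""
    assert abs(sum(coefficients) - 1.0) < 1e-6
    fused = {}
    for key in featurizer_state_dicts[0]:
        fused[key] = sum(
            coef * sd[key].float()
            for coef, sd in zip(coefficients, featurizer_state_dicts)
        )
    return fused

# Hypothetical coefficient sampling (NOT the repository's actual scheme):
# coefficients = torch.distributions.Dirichlet(torch.ones(num_featurizers)).sample().tolist()
```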
Ratatouille sets a new state of the art on DomainBed.
| Algorithm | Selection | PACS | VLCS | OfficeHome | TerraInc | DomainNet | Avg |
|---|---|---|---|---|---|---|---|
| Vanilla fine-tuning | ID val | 85.5 | 77.5 | 66.5 | 46.1 | 40.9 | 63.3 |
| Coral | ID val | 86.2 | 78.8 | 68.7 | 47.6 | 41.5 | 64.6 |
| SWAD | Loss-aware | 88.1 | 79.1 | 70.6 | 50.0 | 46.5 | 66.9 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ERM | ID val | 85.9 | 78.1 | 69.4 | 50.4 | 44.3 | 65.6 |
| Soups/DiWA | Greedy | 88.0 | 78.5 | 71.5 | 51.6 | 47.7 | 67.5 |
| Soups/DiWA | Uniform | 88.7 | 78.4 | 72.1 | 51.4 | 47.4 | 67.6 |
| Soups/DiWA$^{\dagger}$ | Uniform$^{\dagger}$ | 89.0 | 78.6 | 72.8 | 51.9 | 47.7 | 68.0 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Inter-training | ID val | 89.0 | 77.7 | 69.9 | 46.7 | 44.5 | 65.6 |
| Fusing | ID val | 88.0 | 78.5 | 71.5 | 46.7 | 44.4 | 65.8 |
| Ratatouille | Uniform | 89.5 | 78.5 | 73.1 | 51.8 | 47.5 | 68.1 |
| Ratatouille | Greedy | 90.5 | 78.7 | 73.4 | 49.2 | 47.7 | 67.9 |
| Ratatouille$^{\dagger}$ | Uniform$^{\dagger}$ | 89.8 | 78.3 | 73.5 | 52.0 | 47.7 | 68.3 |

$^{\dagger}$: weights uniformly averaged across all 3 data splits (i.e., obtained with `trial_seed -1`).
This source code is released under the MIT license, included in this repository.
We named our method after this traditional French dish for two main reasons. Firstly, the ratatouille is often used as a way to recycle leftover vegetables. Secondly, the ratatouille is better prepared by cooking each ingredient separately before mixing them: this technique ensures that each ingredient “will taste truly of itself”, as noted by chef Joël Robuchon.
If you find this code useful for your research, please consider citing our work:
@article{rame2022recycling,
title={Model Ratatouille: Recycling Diverse Models for Out-of-Distribution Generalization},
author={Ram{\'e}, Alexandre and Ahuja, Kartik and Zhang, Jianyu and Cord, Matthieu and Bottou, L{\'e}on and Lopez-Paz, David},
journal={arXiv preprint arXiv:2212.10445},
year={2022}
}
Correspondence to alexandre.rame at isir.upmc.fr