This repository implements the debiasing method proposed in our paper, built on an unofficial PyTorch implementation of FixMatch. It reproduces the results of our paper on CIFAR-10 as well as on CIFAR-100.
As explained in the paper, we modified the supervised loss of FixMatch to also include strong augmentations:
$$
L(\theta;x,y) = \frac{1}{2}\left(\mathbb{E}_{x_1\sim\textit{weak}(x)}[-\log(p_{\theta}(y|x_1))] + \mathbb{E}_{x_2\sim\textit{strong}(x)}[-\log(p_{\theta}(y|x_2))]\right),
$$
where $\textit{weak}(x)$ and $\textit{strong}(x)$ denote the distributions of weakly and strongly augmented versions of $x$.
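The symmetrized supervised loss above can be sketched as follows. This is a minimal single-example NumPy illustration (the function names and the unbatched form are ours, not the repository's actual PyTorch code):

```python
import numpy as np

def cross_entropy(logits, y):
    # -log p_theta(y | x): log-softmax of the logits, evaluated at label y
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return -log_probs[y]

def symmetrized_loss(logits_weak, logits_strong, y):
    # L(theta; x, y): average of the cross-entropy on a weakly and a
    # strongly augmented view of the same labeled example
    return 0.5 * (cross_entropy(logits_weak, y)
                  + cross_entropy(logits_strong, y))
```

In practice both views pass through the same network; only the input augmentation differs.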
The training objective for the Complete Case is $$ \hat{\mathcal{R}}_{CC}(\theta) = \frac{1}{n_l}\sum_{i=1}^{n_l}L(\theta; x_i,y_i). $$
The training objective for FixMatch is $$ \hat{\mathcal{R}}_{SSL}(\theta) = \frac{1}{n_l}\sum_{i=1}^{n_l}L(\theta; x_i,y_i) \color{red}{+ \frac{\lambda}{n_u}\sum_{i=1}^{n_u}H(\theta; x_i)} . $$
The training objective for DeFixmatch is $$ \hat{\mathcal{R}}_{DeSSL}(\theta) = \frac{1}{n_l}\sum_{i=1}^{n_l}L(\theta; x_i,y_i) \color{red}{+ \frac{\lambda}{n_u}\sum_{i=1}^{n_u}H(\theta; x_i)} \color{blue}{- \frac{\lambda}{n_l}\sum_{i=1}^{n_l}H(\theta; x_i)}. $$
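The three estimators differ only in which averages they combine, which a short NumPy sketch makes explicit. Here `L_labeled`, `H_labeled`, and `H_unlabeled` are hypothetical arrays of per-example values of $L(\theta; x_i, y_i)$ and $H(\theta; x_i)$ (the actual training code computes these per batch):

```python
import numpy as np

def objectives(L_labeled, H_labeled, H_unlabeled, lam):
    # L_labeled[i]  = L(theta; x_i, y_i) on the n_l labeled examples
    # H_unlabeled[j] = H(theta; x_j) on the n_u unlabeled examples
    # H_labeled[i]  = H(theta; x_i) evaluated on the labeled inputs
    risk_cc = np.mean(L_labeled)                      # Complete Case
    risk_ssl = risk_cc + lam * np.mean(H_unlabeled)   # FixMatch (biased)
    risk_dessl = risk_ssl - lam * np.mean(H_labeled)  # DeFixmatch (debiased)
    return risk_cc, risk_ssl, risk_dessl
```

The blue correction term subtracts the same unsupervised criterion evaluated on the labeled data, which is what removes the bias of the SSL estimator in expectation.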
To install requirements:
```
pip install -r requirements.txt
```
We recommend using distributed training for high performance.
With V100x4 GPUs, CIFAR10 training takes about 16 hours (0.7 days), and CIFAR100 training takes about 62 hours (2.6 days).
To train the Complete Case model on CIFAR-10 with 4000 labels, run:
```
python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --dataset cifar10 --num_classes 10 --overwrite --modified_fixmatch --ulb_loss_ratio 0
```
To train FixMatch on CIFAR-10 with 4000 labels, run:
```
python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --dataset cifar10 --num_classes 10 --overwrite --modified_fixmatch --ulb_loss_ratio 0.5
```
To train DeFixmatch on CIFAR-10 with 4000 labels, run:
```
python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --dataset cifar10 --num_classes 10 --overwrite --debiased --ulb_loss_ratio 0.5
```
Trained models are saved in the `saved_models/` directory.
To evaluate a trained model on CIFAR-10 from a checkpoint, run:
```
python eval.py --load_path model.pth --dataset cifar10 --num_classes 10
```
You can find pretrained models on CIFAR-10 using:
Our models achieve the following performance on CIFAR-10:
| Model name | Accuracy | Cross-entropy | Worst-class accuracy |
|---|---|---|---|
| Complete Case | 87.27 | 0.60 | 70.08 |
| FixMatch | 93.87 | 0.27 | 82.25 |
| DeFixmatch | 95.44 | 0.20 | 87.16 |