- Ubuntu 16.04
- NVIDIA GPU
- Python >= 3.6
- Install PyTorch >= v1.1.0 following the official instructions.
- Clone this repo:
git clone https://github.com/VITA-Group/Sandwich-Batch-Normalization
cd Sandwich-Batch-Normalization/Adv
- Install dependencies:
pip install -r requirements.txt
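
Before training, it can help to confirm that the environment matches the requirements listed above. The following is a minimal sketch, not part of this repo: the helper names `parse_version`, `meets_minimum`, and `check_environment` are hypothetical, and the thresholds are taken from the list above (Python 3.6, PyTorch 1.1.0). It only relies on `torch.__version__` and `torch.cuda.is_available()`.

```python
import sys


def parse_version(text):
    """Turn a version string such as '1.1.0+cu90' into a tuple of ints."""
    core = text.split("+")[0]  # drop local build tags like '+cu90'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:  # stop at non-numeric pieces such as 'dev2023'
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(found, required):
    """Return True if version string `found` is at least `required`."""
    a, b = parse_version(found), parse_version(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so '1.1' compares equal to '1.1.0'
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Check Python, PyTorch, and CUDA availability (hypothetical helper)."""
    python_version = ".".join(str(n) for n in sys.version_info[:3])
    report = {"python": meets_minimum(python_version, "3.6")}
    try:
        import torch
    except ImportError:
        report["pytorch"] = False
        report["cuda"] = False
    else:
        report["pytorch"] = meets_minimum(torch.__version__, "1.1.0")
        report["cuda"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(name, "OK" if ok else "missing or too old")
```

If any entry is reported as missing, revisit the corresponding requirement before launching the training scripts.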
Train the BN, AuxBN, and SaAuxBN models:
bash scripts/train_bn.sh
bash scripts/train_auxbn.sh
bash scripts/train_saauxbn.sh
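
To run the three training scripts back to back, a small wrapper can be used. This is a sketch under assumptions: it is meant to be run from the `Adv` directory, the script paths are copied from the commands above, and the names `TRAIN_SCRIPTS`, `build_commands`, and `run_all` are hypothetical helpers that do not exist in the repo. It defaults to a dry run that only prints the commands.

```python
import subprocess

# Script paths as listed in this README, relative to the Adv directory.
TRAIN_SCRIPTS = [
    "scripts/train_bn.sh",
    "scripts/train_auxbn.sh",
    "scripts/train_saauxbn.sh",
]


def build_commands(scripts=TRAIN_SCRIPTS):
    """Return one `bash <script>` argument list per training script."""
    return [["bash", path] for path in scripts]


def run_all(dry_run=True):
    """Print each command; when dry_run is False, also execute it.

    check=True makes subprocess raise on a non-zero exit code, so the
    sequence stops at the first script that fails.
    """
    for command in build_commands():
        print("running:", " ".join(command))
        if not dry_run:
            subprocess.run(command, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Call `run_all(dry_run=False)` to actually start training.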
Check TensorBoard:
tensorboard --logdir output --port 6001
The evaluation results (SA: standard accuracy on clean inputs; RA: robust accuracy under PGD attack):
Evaluation | BN | AuxBN (clean branch) | SaAuxBN (clean branch) (ours) |
---|---|---|---|
Clean (SA) | 84.84 | 94.47 | 94.62 |

Evaluation | BN | AuxBN (adv branch) | SaAuxBN (adv branch) (ours) |
---|---|---|---|
Clean (SA) | 84.84 | 83.42 | 84.08 |
PGD-10 (RA) | 41.57 | 43.05 | 44.93 |
PGD-20 (RA) | 40.02 | 41.60 | 43.14 |
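
The gaps between methods are easier to read as differences. The sketch below copies the numbers from the tables above and subtracts a baseline from the SaAuxBN column; the dictionary layout and the `gain` helper are illustrative assumptions, and the differences are in the same units as the table entries.

```python
# Values copied from the evaluation tables in this README.
RESULTS = {
    "Clean (SA), clean branch": {"BN": 84.84, "AuxBN": 94.47, "SaAuxBN": 94.62},
    "Clean (SA), adv branch": {"BN": 84.84, "AuxBN": 83.42, "SaAuxBN": 84.08},
    "PGD-10 (RA), adv branch": {"BN": 41.57, "AuxBN": 43.05, "SaAuxBN": 44.93},
    "PGD-20 (RA), adv branch": {"BN": 40.02, "AuxBN": 41.60, "SaAuxBN": 43.14},
}


def gain(row, baseline, method="SaAuxBN"):
    """Difference between `method` and `baseline` for one table row."""
    return round(RESULTS[row][method] - RESULTS[row][baseline], 2)


if __name__ == "__main__":
    for row in RESULTS:
        print(row, "| vs AuxBN:", gain(row, "AuxBN"), "| vs BN:", gain(row, "BN"))
```

On the adversarial branch, for example, `gain("PGD-10 (RA), adv branch", "AuxBN")` gives 1.88 and `gain("PGD-20 (RA), adv branch", "AuxBN")` gives 1.54.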
The visualization of test loss:
Test loss | Adversarial test loss |
---|---|
![Test loss]() | ![Adversarial test loss]() |