- Ubuntu 16.04
- NVIDIA GPU
- Python >= 3.6
- Install PyTorch >= 1.1.0 following the official instructions.
- Clone this repo and enter its GAN directory:
git clone https://github.com/VITA-Group/Sandwich-Batch-Normalization
cd Sandwich-Batch-Normalization/GAN
- Install dependencies:
pip install -r requirements.txt
- Prepare the dataset.
- Create a `zoo` directory and put the pretrained model weights from our model zoo (Google Drive) in it:
mkdir zoo
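Before running any of the scripts, the prerequisites above can be sanity-checked from Python. This is a minimal sketch that is not part of the repository: `version_tuple` and `check_environment` are hypothetical helper names, and the check covers only the Python version, the installed PyTorch version, and whether PyTorch can see a CUDA GPU (it does not check the OS, the dataset, or the `zoo` weights).

```python
import importlib
import sys
from typing import List, Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Convert a version string such as '1.13.1+cu117' to (1, 13, 1).

    Local build tags after '+' and non-numeric suffixes such as 'a0' are
    ignored; missing components are padded with zeros, so '1.1' -> (1, 1, 0).
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_environment(min_python=(3, 6), min_torch=(1, 1, 0)) -> List[str]:
    """Return a list of problems found; an empty list means every check passed."""
    problems = []
    if sys.version_info[:2] < tuple(min_python):
        problems.append(
            f"Python {sys.version_info[0]}.{sys.version_info[1]} found, "
            f"need >= {min_python[0]}.{min_python[1]}"
        )
    try:
        torch = importlib.import_module("torch")
    except (ImportError, OSError):
        problems.append("PyTorch is not installed (or failed to import)")
        return problems
    if version_tuple(torch.__version__) < tuple(min_torch):
        needed = ".".join(str(n) for n in min_torch)
        problems.append(f"PyTorch {torch.__version__} found, need >= {needed}")
    if not torch.cuda.is_available():
        problems.append("PyTorch cannot see a CUDA-capable NVIDIA GPU")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    for issue in issues:
        print("WARNING:", issue)
    if not issues:
        print("Environment looks OK")
```

The thresholds default to the versions listed above; pass different `min_python` or `min_torch` tuples if your setup needs them.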
Experiments on CIFAR-10:
bash scripts/train_autogan_ccbn_cifar10.sh
bash scripts/train_autogan_sabn_cifar10.sh
bash scripts/train_sngan_ccbn_cifar10.sh
bash scripts/train_sngan_sabn_cifar10.sh
Experiments on ImageNet (cats and dogs):
bash scripts/sngan_ccbn_imagenet.sh
bash scripts/sngan_sabn_imagenet.sh
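To queue several of the training scripts above in one go, a small Python driver can call them with `bash` one after another. This is a sketch, not a tool shipped with the repo: `run_scripts`, `CIFAR10_TRAIN_SCRIPTS`, and `IMAGENET_TRAIN_SCRIPTS` are hypothetical names, the paths are copied from the commands above, and it assumes you run it from the GAN directory.

```python
import subprocess
from pathlib import Path
from typing import List

# Script paths copied from the training commands above (relative to the GAN directory).
CIFAR10_TRAIN_SCRIPTS = [
    "scripts/train_autogan_ccbn_cifar10.sh",
    "scripts/train_autogan_sabn_cifar10.sh",
    "scripts/train_sngan_ccbn_cifar10.sh",
    "scripts/train_sngan_sabn_cifar10.sh",
]
IMAGENET_TRAIN_SCRIPTS = [
    "scripts/sngan_ccbn_imagenet.sh",
    "scripts/sngan_sabn_imagenet.sh",
]


def run_scripts(scripts: List[str], dry_run: bool = False) -> List[List[str]]:
    """Run each script with bash, one after another, stopping at the first failure.

    Returns the commands that were executed (or, with dry_run=True, would be).
    """
    commands = []
    for script in scripts:
        if not dry_run and not Path(script).is_file():
            raise FileNotFoundError(f"{script} not found; run this from the GAN directory")
        cmd = ["bash", script]
        commands.append(cmd)
        if not dry_run:
            # check=True raises subprocess.CalledProcessError if the script exits non-zero
            subprocess.run(cmd, check=True)
    return commands
```

For example, `run_scripts(CIFAR10_TRAIN_SCRIPTS, dry_run=True)` only lists the four CIFAR-10 commands without running them. The runs are sequential because each script presumably occupies the GPU.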
Monitor training with TensorBoard:
tensorboard --logdir output --port 6001
Test with the pretrained weights from the model zoo (Google Drive):
bash scripts/test_autogan_sabn_cifar10.sh
bash scripts/test_sngan_sabn_cifar10.sh
bash scripts/test_sngan_sabn_imagenet.sh
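Evaluation results (↑ higher is better, ↓ lower is better; numbers in parentheses are changes relative to the corresponding baseline):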
Model | Inception Score ↑ | FID ↓ |
---|---|---|
AutoGAN | 8.43 | 10.51 |
BigGAN | 8.91 | 8.57 |
SNGAN | 8.76 | 10.18 |
AutoGAN-SaBN (ours) | 8.72 (+0.29) | 9.11 (−1.40) |
BigGAN-SaBN (ours) | 9.01 (+0.10) | 8.03 (−0.54) |
SNGAN-SaBN (ours) | 8.89 (+0.13) | 8.97 (−1.21) |
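The parenthesized numbers in the table are the SaBN result minus the corresponding baseline. The short sketch below recomputes them from the table values; the `RESULTS` dictionary and `deltas` function are illustrative names, not repository code. A positive IS delta and a negative FID delta both mean the SaBN variant scored better.

```python
# IS and FID values copied from the table above, as (baseline, SaBN variant) pairs.
RESULTS = {
    "AutoGAN": {"IS": (8.43, 8.72), "FID": (10.51, 9.11)},
    "BigGAN": {"IS": (8.91, 9.01), "FID": (8.57, 8.03)},
    "SNGAN": {"IS": (8.76, 8.89), "FID": (10.18, 8.97)},
}


def deltas(results):
    """Return {model: {metric: SaBN - baseline}}, rounded to two decimals."""
    return {
        model: {metric: round(sabn - base, 2) for metric, (base, sabn) in metrics.items()}
        for model, metrics in results.items()
    }


for model, d in deltas(RESULTS).items():
    print(f"{model}-SaBN: IS {d['IS']:+.2f}, FID {d['FID']:+.2f}")
# AutoGAN-SaBN: IS +0.29, FID -1.40
# BigGAN-SaBN: IS +0.10, FID -0.54
# SNGAN-SaBN: IS +0.13, FID -1.21
```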
[Figure: Visual results on ImageNet (128×128 resolution), SNGAN vs. SNGAN-SaBN (ours)]