With this project, you can train Generative Adversarial Networks (GANs). While you can train on any type of image, this repository focuses on generating images from video games.
- PyTorch 2 compile (`torch.compile`)
- Mixed Precision training (fp16 or bf16)
- Gradient Accumulation
- Inception Score and FID evaluation
- HF🤗 Accelerate - Adds Multi-GPU, TPU, and distributed support
- Easy to start training
- Testing
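
As an illustration of the gradient accumulation feature listed above, here is a minimal sketch in plain PyTorch. It is not the repository's training loop; the function name `accumulated_step` and the toy linear model are assumptions made for this example. Each micro-batch loss is divided by the number of micro-batches so that the summed gradients match a single large-batch step.

```python
import torch
from torch import nn


def accumulated_step(model, optimizer, loss_fn, micro_batches):
    """Hypothetical helper: one optimizer step spread over several micro-batches."""
    optimizer.zero_grad()
    n = len(micro_batches)
    for inputs, targets in micro_batches:
        # Scale each loss so the accumulated gradient is the average, not the sum.
        loss = loss_fn(model(inputs), targets) / n
        # backward() adds into .grad, so gradients accumulate across micro-batches.
        loss.backward()
    optimizer.step()


torch.manual_seed(0)
model = nn.Linear(4, 1)  # toy stand-in for a generator or discriminator
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# An effective batch of 8 samples, processed as 4 micro-batches of 2.
x, y = torch.randn(8, 4), torch.randn(8, 1)
micro_batches = list(zip(x.chunk(4), y.chunk(4)))
accumulated_step(model, optimizer, loss_fn, micro_batches)
```

When training through Accelerate, the same effect is usually obtained by passing `gradient_accumulation_steps` to `Accelerator` and wrapping the step in `accelerator.accumulate(model)`; whether this project does so is not stated here.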
A sample of the COIL-100 dataset is included in the code and is used for testing. You can easily replace it with your own dataset. The table below lists popular datasets used for image generation.
Dataset | Number of Images | Number of Labels | Resolutions Available | Description |
---|---|---|---|---|
ImageNet | 1m | 1k | varying | Real world objects |
CIFAR | 60k | 10 or 100 | 32x32 | Real world objects |
ArtBench-10 | 60k | 10 | 32x32, 256x256, or original | Art in 10 distinctive styles |
FFHQ | 70k | 1 | 1024x1024 | High-quality images of people's faces |
LSUN | 1m | 30 | varying | 10 scenes and 20 objects |
The following are the Python packages needed.
- PyTorch 2.0+
- torchvision 0.15+
- SciPy 1.7+
- TorchMetrics
- torchinfo
- torch-ema
- tqdm
The following are the current models that are available. Changing models is as easy as specifying which model to use in the configuration file.
From the parent folder, you can run this command to start training a DCGAN model:
python3 -m src.train_gan configs/dcgan_128_96.ini
Starter model configuration files and a configuration README can be found in the configs directory.
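
To show how a model choice could be read from an `.ini` configuration file, here is a small sketch using Python's standard `configparser`. The section names (`MODEL`, `TRAIN`) and keys (`name`, `batch_size`) are hypothetical placeholders, not the project's actual schema; the configuration README in the configs directory is the reference for the real options.

```python
import configparser

# Hypothetical config contents; the real files in configs/ may use other names.
SAMPLE_CONFIG = """
[MODEL]
name = dcgan

[TRAIN]
batch_size = 64
"""


def read_model_name(config_text):
    """Return the model name from an .ini-style string (assumed layout)."""
    parser = configparser.ConfigParser()
    parser.read_string(config_text)
    return parser["MODEL"]["name"]


def read_batch_size(config_text):
    """Return the batch size as an int (assumed layout)."""
    parser = configparser.ConfigParser()
    parser.read_string(config_text)
    return parser.getint("TRAIN", "batch_size")


model_name = read_model_name(SAMPLE_CONFIG)
batch_size = read_batch_size(SAMPLE_CONFIG)
print(model_name, batch_size)
```

Under this assumed layout, switching models would only mean editing the `name` value in the file and rerunning the training command.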
models/Deep-biggan-bs64-ch128-mxp-n64-trunc0.75
- 194,460 images (84 GB)
- 20 games
- Resolution: 128 x 96
[Figure: Training Batch (sample images)]
[Figure: Generated Images (sample images)]