A JAX implementation of Consistency Models, a class of diffusion-adjacent generative models introduced in Song et al. (2023).
When used as standalone generative models, consistency models achieve state-of-the-art performance in one- and few-step generation, outperforming existing distillation techniques for diffusion models.
A minimal, self-contained implementation of the discrete-time version of the model, trained on MNIST, is in the notebook mnist-example.ipynb.
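As background for the discrete-time version, the consistency function in Song et al. (2023) wraps the backbone network with EDM-style skip/output scalings so the boundary condition f_θ(x, ε) = x holds exactly. A minimal JAX sketch follows; the `apply_fn`/`params` signature and the NHWC layout are assumptions for illustration, not this repo's actual code:

```python
import jax.numpy as jnp

SIGMA_DATA = 0.5   # data standard deviation used in the EDM-style scalings
EPS = 0.002        # minimum noise level t_1 = epsilon

def c_skip(t):
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)

def c_out(t):
    return SIGMA_DATA * (t - EPS) / jnp.sqrt(SIGMA_DATA**2 + t**2)

def consistency_fn(params, apply_fn, x, t):
    """Consistency function f_theta(x, t) = c_skip(t) * x + c_out(t) * F_theta(x, t).

    `apply_fn` is the raw backbone network (here assumed to take (params, x, t));
    the skip/output scalings enforce f_theta(x, eps) = x at the boundary.
    """
    t = jnp.asarray(t)
    skip = c_skip(t).reshape(-1, 1, 1, 1)  # broadcast per-sample scalings over (N, H, W, C)
    out = c_out(t).reshape(-1, 1, 1, 1)
    return skip * x + out * apply_fn(params, x, t)
```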
- This repo uses a simple MLP-Mixer as the backbone for the consistency function.
- I only implement what the paper calls consistency training (CT), where the model is trained from scratch, rather than consistency distillation (CD), where the model is distilled from a pre-trained diffusion model; a sketch of the CT objective follows this list.
- The continuous-time objective is implemented, but I have not gotten this to work well for consistency training. In the paper, the authors note, "For consistency training (CT), we find it important to initialize consistency models from a pre-trained EDM model in order to stabilize training when using continuous-time objectives. We hypothesize that this is caused by the large variance in our continuous-time loss functions", so this may not be surprising.
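As a companion to the CT note above, here is a minimal sketch of the discrete-time consistency training objective, reusing `consistency_fn`, `EPS`, and `SIGMA_DATA` from the snippet above. Uniform λ(t_n) = 1 weighting and a squared-L2 distance are simplifying assumptions; the paper also uses LPIPS on natural images:

```python
import jax
import jax.numpy as jnp

RHO, T_MAX = 7.0, 80.0  # Karras-schedule parameters used in the paper

def karras_timesteps(n_steps):
    """Discretized noise levels t_1 = eps < ... < t_N = T (Karras et al. schedule)."""
    i = jnp.arange(n_steps)
    return (EPS ** (1 / RHO)
            + i / (n_steps - 1) * (T_MAX ** (1 / RHO) - EPS ** (1 / RHO))) ** RHO

def ct_loss(params, ema_params, apply_fn, x, rng, n_steps):
    """Discrete-time consistency training loss from Song et al. (2023).

    The online network is evaluated at the noisier level t_{n+1}; the EMA
    "target" network at the adjacent level t_n; both share the same noise z.
    """
    ts = karras_timesteps(n_steps)
    rng_n, rng_z = jax.random.split(rng)
    n = jax.random.randint(rng_n, (x.shape[0],), 0, n_steps - 1)  # n in [1, N-1] (0-indexed)
    z = jax.random.normal(rng_z, x.shape)

    t_n, t_np1 = ts[n], ts[n + 1]
    x_np1 = x + t_np1.reshape(-1, 1, 1, 1) * z
    x_n = x + t_n.reshape(-1, 1, 1, 1) * z

    pred = consistency_fn(params, apply_fn, x_np1, t_np1)
    target = jax.lax.stop_gradient(
        consistency_fn(ema_params, apply_fn, x_n, t_n))
    # Squared-L2 distance as a simple stand-in for the paper's metric d(.,.).
    return jnp.mean((pred - target) ** 2)
```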
Train, with optional logging through wandb:
```bash
python train.py --config ./config/cifar10.py
```
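The exact contents of ./config/cifar10.py are repo-specific; the command-line pattern suggests an ml_collections-style config file, so a purely hypothetical sketch (all field names and values are assumptions) might look like:

```python
# Hypothetical sketch of what ./config/cifar10.py could contain; the actual
# fields in this repo may differ. Assumes an ml_collections-style config.
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()
    config.dataset = "cifar10"
    config.batch_size = 512          # batch size quoted for the samples below
    config.n_train_steps = 900_000   # roughly the CIFAR-10 training length below
    config.learning_rate = 1e-4      # hypothetical value
    config.use_wandb = False         # flip to True for optional wandb logging
    return config
```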
Samples with 5-step (left) and 2-step (right) generation from a model trained on MNIST for 100k steps with a batch size of 512; see consistency-mnist.ipynb.
Samples with 5-step (left) and 2-step (right) generation from a model trained on CIFAR-10 for ~900k steps with a batch size of 512. These don't look... great, likely because of the MLP-Mixer backbone.
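For context, few-step samples like the ones above come from multistep consistency sampling (Algorithm 1 in Song et al. (2023)), which alternates re-noising to a chosen level with a jump back to data via the consistency function. A minimal sketch under the same assumptions as the snippets above, reusing `consistency_fn` and `EPS`:

```python
import jax
import jax.numpy as jnp

def multistep_sample(params, apply_fn, rng, shape, ts):
    """Multistep consistency sampling (Algorithm 1 in Song et al. 2023).

    `ts` is a decreasing sequence of noise levels whose first entry is the
    maximum noise level T; its length is the number of sampling steps
    (5 or 2 for the samples shown above).
    """
    rng, rng_init = jax.random.split(rng)
    # Start from pure noise at the maximum noise level and denoise in one jump.
    x = ts[0] * jax.random.normal(rng_init, shape)
    x = consistency_fn(params, apply_fn, x, jnp.full((shape[0],), ts[0]))
    for t in ts[1:]:
        rng, rng_z = jax.random.split(rng)
        z = jax.random.normal(rng_z, shape)
        # Re-noise to level t, then map back to data with the consistency function.
        x_t = x + jnp.sqrt(t**2 - EPS**2) * z
        x = consistency_fn(params, apply_fn, x_t, jnp.full((shape[0],), t))
    return x
```

The intermediate noise levels are free hyperparameters (the paper selects them with a greedy search against FID); any hand-chosen decreasing schedule with five or two entries reproduces the 5-step and 2-step settings shown above.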