Minimal unofficial implementation of consistency models (CM), proposed by Song et al. (2023), on 1D toy tasks.
Install with `pip install -e .`
This repo contains implementations of Consistency Distillation (CD) and Consistency Training (CT). For better performance with Consistency Training, there is an option to pretrain the model with a diffusion objective before switching to the CT objective. Using a diffusion training objective before starting CT stabilizes the training process significantly.
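As a rough illustration of this two-phase schedule, here is a minimal sketch of a denoising warm-up followed by consistency training with an EMA target network. All names here (`model`, `data_loader`, the `model(x, sigma)` signature, and the `sample_adjacent_sigmas` helper) are hypothetical placeholders, not this repo's API:

```python
import copy
import torch

def pretrain_then_ct(model, data_loader, pretrain_steps=1000, ct_steps=1000,
                     ema_decay=0.99, lr=1e-3):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    target = copy.deepcopy(model).requires_grad_(False)  # EMA target for CT

    for step, x in enumerate(data_loader):
        if step < pretrain_steps:
            # Phase 1: plain denoising (diffusion) objective to stabilize CT.
            sigma = torch.exp(torch.randn(x.shape[0], 1))  # log-normal noise levels
            loss = ((model(x + sigma * torch.randn_like(x), sigma) - x) ** 2).mean()
        else:
            # Phase 2: consistency training across two adjacent noise levels.
            sigma_hi, sigma_lo = sample_adjacent_sigmas(x.shape[0])  # hypothetical helper
            z = torch.randn_like(x)
            loss = ((model(x + sigma_hi * z, sigma_hi)
                     - target(x + sigma_lo * z, sigma_lo)) ** 2).mean()

        opt.zero_grad(); loss.backward(); opt.step()
        with torch.no_grad():  # EMA update of the target network
            for p, tp in zip(model.parameters(), target.parameters()):
                tp.mul_(ema_decay).add_(p, alpha=1 - ema_decay)
        if step + 1 >= pretrain_steps + ct_steps:
            break
```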
To try it out:
- Consistency Distillation: `cd_main.py`
- Discrete Consistency Training: `cm_main.py`
- Continuous Consistency Training: `ct_cm_main.py`
Data
I have implemented some simple 1D toy tasks to test the multimodality and expressiveness of consistency models. Just change the input string for the datamanager class to one of the following datasets: 'three_gmm_1D', 'uneven_two_gmm_1D', 'two_gmm_1D', 'single_gaussian_1D'.
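For intuition, here is a minimal sketch of what a dataset like 'two_gmm_1D' could look like: samples from a two-component 1D Gaussian mixture. The means, standard deviations, and weights below are illustrative guesses, not the values used by the datamanager class:

```python
import numpy as np

def sample_two_gmm_1d(n_samples, means=(-1.0, 1.0), stds=(0.1, 0.1),
                      weights=(0.5, 0.5), seed=0):
    rng = np.random.default_rng(seed)
    # Pick a mixture component per sample, then draw from that Gaussian.
    comp = rng.choice(len(means), size=n_samples, p=weights)
    return rng.normal(np.asarray(means)[comp], np.asarray(stds)[comp])

samples = sample_two_gmm_1d(1000)  # shape (1000,)
```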
After training, the results of the trained model are plotted. The plots look like the examples below.
I used 2000 training steps with diffusion pretraining for these results.
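For reference, a minimal way to produce this kind of comparison for 1D samples is to overlay histograms of the data and the model's samples. This is just a generic matplotlib sketch, not the repo's plotting code:

```python
import matplotlib.pyplot as plt

def plot_1d_samples(data, model_samples, bins=100):
    # Overlay normalized histograms of ground-truth data and model samples.
    plt.hist(data, bins=bins, density=True, alpha=0.5, label='data')
    plt.hist(model_samples, bins=bins, density=True, alpha=0.5, label='model')
    plt.xlabel('x')
    plt.legend()
    plt.show()
```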
Two Gaussians
From left to right: EDM diffusion pretraining with Euler sampling, multistep prediction with consistency models, and single-step prediction.
Three Gaussians
From left to right: EDM diffusion pretraining with Euler sampling, multistep prediction with consistency models, and single-step prediction.
- Consistency training is not really stable. This is not surprising, since the authors also discuss its shortcomings in the paper and even recommend using pretrained diffusion models as initialization for the training.
- Image hyperparameters do not translate well to other domains. I had limited success with the parameters recommended for CIFAR-10 and other image-based applications. Significantly reducing the maximum noise level improved results. I also increased both the minimum and the maximum number of discrete noise levels (see the noise-schedule sketch after this list).
- Multistep prediction with consistency models has a certain drift towards the outside, which I cannot explain. I only used a linear noise scheduler for the multistep sampling, so results may improve with a better discretization (a multistep-sampling sketch also follows below).
- Discrete training works a lot better than the continuous version. The authors report similar observations for the high-dimensional image domain.
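The noise-schedule sketch mentioned above: the Karras et al. (EDM) discretization used by consistency models, with a reduced maximum noise level of the kind that worked better on these 1D toys. The concrete values are illustrative guesses, not the repo's tuned settings:

```python
import numpy as np

def karras_sigmas(n, sigma_min=0.002, sigma_max=2.0, rho=7.0):
    # sigma_i = (s_max^(1/rho) + i/(n-1) * (s_min^(1/rho) - s_max^(1/rho)))^rho
    # The CIFAR-10 default is sigma_max = 80, which is far too large for data
    # on the scale of these 1D toys; sigma_max = 2.0 is an illustrative guess.
    ramp = np.linspace(0, 1, n)
    return (sigma_max ** (1 / rho)
            + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

print(karras_sigmas(10))  # descending noise levels from sigma_max to sigma_min
```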
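And the multistep-sampling sketch: the paper's multistep consistency sampling alternates between re-noising to a lower level and mapping back to data with the consistency function. `model(x, sigma)` is a placeholder for the trained network; a better discretization (e.g. `karras_sigmas` above) can be plugged in via `sigmas`:

```python
import math
import torch

@torch.no_grad()
def multistep_sample(model, shape, sigmas, sigma_min=0.002):
    # sigmas: descending noise levels, starting at the maximum.
    x = torch.randn(shape) * sigmas[0]
    x = model(x, sigmas[0])  # single-step prediction from pure noise
    for sigma in sigmas[1:]:
        # Re-noise to level sigma (the paper uses sqrt(sigma^2 - sigma_min^2)),
        # then map back to data with the consistency function.
        x = x + math.sqrt(sigma**2 - sigma_min**2) * torch.randn_like(x)
        x = model(x, sigma)
    return x
```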
- Implement Consistency Distillation Training
- Add new toy tasks
- Check conditional training
- Find good hyperparameters
- Improve plotting method
- The general consistency class is based on the code from OpenAI's consistency models repo.
- Some sampling methods and other functions are from k_diffusion.
- The model is based on the paper Consistency Models:
@article{song2023consistency,
  title={Consistency Models},
  author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
  journal={arXiv preprint arXiv:2303.01469},
  year={2023}
}