Bare-bones, minified versions of some common (and not-so-common) generative models, for pedagogical purposes.
First, install JAX following these instructions. For CPU-only, this is as simple as:
```shell
pip install "jax[cpu]"
```
Additional libraries:
```shell
pip install flax optax diffrax tensorflow_probability scikit-learn tqdm matplotlib
```
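To confirm the install worked, a quick sanity check (on CPU this should list a single `CpuDevice`):

```python
import jax
import jax.numpy as jnp

print(jax.devices())        # available backends, e.g. [CpuDevice(id=0)] on CPU
print(jnp.arange(3).sum())  # 3
```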
- β-VAEs: Variational autoencoders and basic rate-distortion theory.
- Diffusion models: Likelihood-based and score-matching interpretations of diffusion models.
- Normalizing flows (WiP annotations): Discrete-time normalizing flows, specifically RealNVP.
- Continuous normalizing flows: Continuous-time normalizing flows, e.g. Grathwohl et al 2018.
- Consistency models (WiP annotations): Consistency models from Song et al 2023.
- Flow matching (WiP annotations): From Lipman et al 2022; see also Albergo et al 2023.
- Diffusion distillation (WiP): Progressive (Salimans et al 2022) and consistency (Song et al 2023) distillation.
- Discrete walk-jump sampling (WiP): From Frey et al 2023.
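As a flavor of what the notebooks cover, here is a minimal sketch of the standard ε-prediction diffusion objective: noise a clean sample along the variance-preserving forward process and regress the model onto the injected noise. The function names (`forward_noise`, `ddpm_loss`) and the `eps_model(params, x_t, alpha_bar_t)` signature are illustrative, not the repo's actual API.

```python
import jax
import jax.numpy as jnp

def forward_noise(key, x0, alpha_bar_t):
    """Forward noising: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    eps = jax.random.normal(key, x0.shape)  # standard Gaussian noise
    x_t = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1.0 - alpha_bar_t) * eps
    return x_t, eps

def ddpm_loss(eps_model, params, key, x0, alpha_bar_t):
    """Epsilon-prediction loss: MSE between predicted and injected noise.

    `eps_model` is a hypothetical network taking (params, x_t, alpha_bar_t).
    """
    x_t, eps = forward_noise(key, x0, alpha_bar_t)
    return jnp.mean((eps_model(params, x_t, alpha_bar_t) - eps) ** 2)
```

At `alpha_bar_t = 1` the "noised" sample is just `x0`; at `alpha_bar_t = 0` it is pure noise, which makes the forward process easy to sanity-check.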
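Similarly, the flow-matching objective (Lipman et al 2022) reduces to a simple regression when using the linear conditional path `x_t = (1 - t) * x0 + t * x1`, whose conditional velocity target is just `x1 - x0`. A minimal sketch, with illustrative names (`flow_matching_pair`, `fm_loss`) rather than the repo's actual API:

```python
import jax
import jax.numpy as jnp

def flow_matching_pair(key, x1, t):
    """Sample a point on the linear path from noise to data, plus its target.

    x0 ~ N(0, I), x_t = (1 - t) * x0 + t * x1, target velocity = x1 - x0.
    """
    x0 = jax.random.normal(key, x1.shape)  # noise endpoint of the path
    x_t = (1.0 - t) * x0 + t * x1          # interpolant at time t
    v_target = x1 - x0                     # conditional velocity target
    return x_t, v_target

def fm_loss(model_fn, params, key, x1, t):
    """MSE between the model's velocity prediction and the conditional target."""
    x_t, v_target = flow_matching_pair(key, x1, t)
    return jnp.mean((model_fn(params, x_t, t) - v_target) ** 2)
```

Note that at `t = 1` the interpolant equals the data `x1`, and at `t = 0` it equals the noise `x0`, so the path endpoints can be checked directly.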