Bare-bones implementations of some generative models in Jax: diffusion, normalizing flows, consistency models, flow matching, (beta)-VAEs, etc

smsharma/minified-generative-models

Minified generative models

License: MIT

Bare-bones, minified versions of some common (and not-so-common) generative models, for pedagogical purposes.

Installation

First, install JAX by following the official JAX installation instructions. For a CPU-only install, this is as simple as:

pip install "jax[cpu]"

Additional libraries:

pip install flax optax diffrax tensorflow_probability scikit-learn tqdm matplotlib
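After installing, a quick check like the following can confirm that JAX runs and that the extra libraries import. This is a minimal sketch, not part of the repository; the helper name check_extras is hypothetical, and note that scikit-learn is imported as sklearn.

```python
import importlib

import jax
import jax.numpy as jnp

# Extra libraries from the install command above (scikit-learn imports as "sklearn").
EXTRAS = ["flax", "optax", "diffrax", "tensorflow_probability", "sklearn", "tqdm", "matplotlib"]


def check_extras():
    """Return {library_name: True/False} depending on whether each one imports."""
    status = {}
    for name in EXTRAS:
        try:
            importlib.import_module(name)
            status[name] = True
        except Exception:  # a broken install can raise more than ImportError
            status[name] = False
    return status


# JAX itself: report the version and available devices, then jit-compile a tiny function.
print("jax", jax.__version__, jax.devices())
sq_sum = jax.jit(lambda v: jnp.sum(v ** 2))(jnp.arange(3.0))
print("jit check:", float(sq_sum))  # 0^2 + 1^2 + 2^2 = 5.0

for lib, ok in check_extras().items():
    print(f"{lib}: {'ok' if ok else 'MISSING'}")
```

Any library reported as MISSING can be installed individually with pip before running the notebooks.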

List of notebooks

  1. β-VAEs: Variational autoencoders and basic rate-distortion theory.
  2. Diffusion models: Diffusion models, covering likelihood-based and score-matching interpretations.
  3. Normalizing flows (WiP annotations): Normalizing flows, specifically RealNVP.
  4. Continuous normalizing flows: Continuous-time normalizing flows, e.g., from Grathwohl et al. 2018.
  5. Consistency models (WiP annotations): Consistency models from Song et al. 2023.
  6. Flow matching (WiP annotations): From Lipman et al. 2022; see also Albergo et al. 2023.
  7. Diffusion distillation (WiP): Progressive (Salimans et al. 2022) and consistency (Song et al. 2023) distillation.
  8. Discrete walk-jump sampling (WiP): From Frey et al. 2023.
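The β-VAE objective from notebook 1 can be read in rate-distortion terms as distortion + β × rate, where the rate is the KL divergence from the approximate posterior q(z|x) to the prior. The sketch below assumes a diagonal-Gaussian encoder, a standard-normal prior, a squared-error distortion, and toy linear encoder/decoder weights; the parameter names (enc_w, dec_w, ...) and helper functions are hypothetical and not taken from the notebook.

```python
import jax
import jax.numpy as jnp


def gaussian_kl(mu, log_var):
    """Rate: KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * jnp.sum(jnp.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)


def beta_vae_loss(params, x, key, beta=1.0):
    """Mean of distortion + beta * rate for a toy linear encoder/decoder."""
    # Encoder (hypothetical linear map): x -> [mu, log_var].
    h = x @ params["enc_w"] + params["enc_b"]
    mu, log_var = jnp.split(h, 2, axis=-1)
    # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * log_var) * eps
    # Decoder (hypothetical linear map); squared error equals a fixed-variance
    # Gaussian negative log-likelihood up to scale and constants.
    x_hat = z @ params["dec_w"] + params["dec_b"]
    distortion = jnp.sum((x - x_hat) ** 2, axis=-1)
    rate = gaussian_kl(mu, log_var)
    return jnp.mean(distortion + beta * rate), (jnp.mean(distortion), jnp.mean(rate))


def init_params(key, x_dim=4, z_dim=2):
    """Small random weights for the toy encoder/decoder (hypothetical shapes)."""
    k_enc, k_dec = jax.random.split(key)
    return {
        "enc_w": 0.1 * jax.random.normal(k_enc, (x_dim, 2 * z_dim)),
        "enc_b": jnp.zeros(2 * z_dim),
        "dec_w": 0.1 * jax.random.normal(k_dec, (z_dim, x_dim)),
        "dec_b": jnp.zeros(x_dim),
    }


params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
(loss, (distortion, rate)), grads = jax.value_and_grad(beta_vae_loss, has_aux=True)(
    params, x, jax.random.PRNGKey(2), 4.0
)
```

Raising beta above 1 penalizes the rate more heavily, pushing the model toward more compressed latents at the cost of higher distortion; tracking both terms separately (the aux output above) traces out the rate-distortion trade-off.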
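For the diffusion models in notebook 2, the likelihood-based and score-matching views meet in the noise-prediction loss: a network that predicts the added noise ε also gives the score, since ∇ log p_t(x_t) ≈ -ε̂ / √(1 - ᾱ_t). The sketch below assumes a variance-preserving forward process with a linear β schedule (β_min = 0.1, β_max = 20) on t ∈ [0, 1]; the function names are illustrative, and in place of a learned network it uses the exact optimal predictor for standard-Gaussian data so the result can be checked.

```python
import jax
import jax.numpy as jnp


def alpha_bar(t, beta_min=0.1, beta_max=20.0):
    """abar(t) = exp(-int_0^t beta(s) ds) for beta(s) = beta_min + s (beta_max - beta_min) (assumed schedule)."""
    return jnp.exp(-(beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2))


def q_sample(x0, t, eps):
    """Forward noising: x_t = sqrt(abar_t) x0 + sqrt(1 - abar_t) eps."""
    ab = alpha_bar(t)[:, None]
    return jnp.sqrt(ab) * x0 + jnp.sqrt(1.0 - ab) * eps


def eps_loss(eps_model, x0, key):
    """Noise-prediction loss E_{t, eps} ||eps - eps_model(x_t, t)||^2."""
    k_t, k_eps = jax.random.split(key)
    t = jax.random.uniform(k_t, (x0.shape[0],), minval=1e-3, maxval=1.0)
    eps = jax.random.normal(k_eps, x0.shape)
    x_t = q_sample(x0, t, eps)
    return jnp.mean(jnp.sum((eps - eps_model(x_t, t)) ** 2, axis=-1))


def score_from_eps(eps_hat, t):
    """Score-matching view: grad_x log p_t(x) is approximated by -eps_hat / sqrt(1 - abar_t)."""
    return -eps_hat / jnp.sqrt(1.0 - alpha_bar(t))[:, None]


# With data x0 ~ N(0, I), x_t ~ N(0, I) for every t and the optimal predictor is
# E[eps | x_t] = sqrt(1 - abar_t) x_t (exact, no training needed).
def optimal_eps(x_t, t):
    return jnp.sqrt(1.0 - alpha_bar(t))[:, None] * x_t


x0 = jax.random.normal(jax.random.PRNGKey(0), (20000, 2))
key = jax.random.PRNGKey(1)
loss_opt = eps_loss(optimal_eps, x0, key)
loss_zero = eps_loss(lambda x_t, t: jnp.zeros_like(x_t), x0, key)
```

For this Gaussian toy case the optimal predictor's score is exactly -x_t, the score of N(0, I), which is a useful consistency check when swapping in a trained network.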
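RealNVP (notebook 3) is built from affine coupling layers: half of the coordinates pass through unchanged and parameterize a scale and shift for the other half, so the layer is exactly invertible and its Jacobian is triangular. A minimal single-layer sketch, assuming a hypothetical two-layer conditioner network (the names conditioner, coupling_forward, etc. are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def conditioner(params, x1):
    """Hypothetical small network mapping x1 to a log-scale s and shift t."""
    h = jnp.tanh(x1 @ params["w1"] + params["b1"])
    s = jnp.tanh(h @ params["ws"])  # tanh keeps the log-scale bounded for stability
    t = h @ params["wt"]
    return s, t


def coupling_forward(params, x):
    """y1 = x1, y2 = x2 * exp(s(x1)) + t(x1); returns (y, log|det J|)."""
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    s, t = conditioner(params, x1)
    y2 = x2 * jnp.exp(s) + t
    # Triangular Jacobian: log|det| is just the sum of the log-scales.
    return jnp.concatenate([x1, y2], axis=-1), jnp.sum(s, axis=-1)


def coupling_inverse(params, y):
    """Exact inverse: x1 = y1, x2 = (y2 - t(y1)) * exp(-s(y1))."""
    d = y.shape[-1] // 2
    y1, y2 = y[..., :d], y[..., d:]
    s, t = conditioner(params, y1)
    return jnp.concatenate([y1, (y2 - t) * jnp.exp(-s)], axis=-1)


def log_prob(params, x):
    """Change of variables with a standard-normal base density."""
    y, logdet = coupling_forward(params, x)
    return jnp.sum(norm.logpdf(y), axis=-1) + logdet


def init_params(key, dim=4, hidden=16):
    d = dim // 2
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w1": jax.random.normal(k1, (d, hidden)) / jnp.sqrt(d),
        "b1": jnp.zeros(hidden),
        "ws": 0.1 * jax.random.normal(k2, (hidden, dim - d)),
        "wt": 0.1 * jax.random.normal(k3, (hidden, dim - d)),
    }
```

A full RealNVP stacks several such layers, alternating or permuting which half is transformed, and sums their log-determinants.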
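Continuous normalizing flows (notebook 4) replace the discrete log-determinant with the instantaneous change of variables, d log p(z(t))/dt = -tr(∂f/∂z). The sketch below integrates both equations with fixed-step Euler and an exact Jacobian trace; the vector field and its parameters are hypothetical. A practical implementation would use an adaptive ODE solver (e.g. from diffrax, which the install list includes) and, in higher dimensions, a stochastic trace estimator such as Hutchinson's, as in Grathwohl et al. 2018.

```python
import jax
import jax.numpy as jnp


def vector_field(params, z, t):
    """Hypothetical time-dependent dynamics f(z, t) = tanh(W z + b t)."""
    return jnp.tanh(params["W"] @ z + params["b"] * t)


def cnf_forward(f, params, z0, t1=1.0, n_steps=100):
    """Euler-integrate dz/dt = f(z, t) and d(log p)/dt = -tr(df/dz) for one sample."""
    dt = t1 / n_steps

    def step(carry, i):
        z, delta_logp = carry
        t = i * dt
        jac = jax.jacfwd(f, argnums=1)(params, z, t)  # exact Jacobian: fine in low dimension
        return (z + dt * f(params, z, t), delta_logp - dt * jnp.trace(jac)), None

    init = (z0, jnp.zeros((), dtype=z0.dtype))
    (z1, delta_logp), _ = jax.lax.scan(step, init, jnp.arange(n_steps))
    # Instantaneous change of variables: log p_1(z1) = log p_0(z0) + delta_logp.
    return z1, delta_logp
```

For a linear field f(z) = A z the trace term is constant, so the accumulated change is exactly -tr(A) · t1, which makes a convenient check; batches can be handled with jax.vmap over z0.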
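Consistency models (notebook 5) enforce the boundary condition f(x, ε) = x through a skip parameterization f(x, t) = c_skip(t) x + c_out(t) F(x, t). The sketch below assumes the EDM-style constants σ_data = 0.5, ε = 0.002, T = 80 and a ρ = 7 time grid as reported in Song et al. 2023, a hypothetical toy network toy_net, and a squared L2 distance for consistency training (the paper also considers other metrics such as LPIPS). In training, ema_params would be an exponential moving average of params.

```python
import jax
import jax.numpy as jnp

# Assumed constants (EDM-style values reported in Song et al. 2023).
SIGMA_DATA, EPS, T_MAX, RHO = 0.5, 0.002, 80.0, 7.0


def c_skip(t):
    return SIGMA_DATA ** 2 / ((t - EPS) ** 2 + SIGMA_DATA ** 2)


def c_out(t):
    return SIGMA_DATA * (t - EPS) / jnp.sqrt(SIGMA_DATA ** 2 + t ** 2)


def consistency_fn(net, params, x, t):
    """f(x, t) = c_skip(t) x + c_out(t) F(x, t); reduces to x at t = EPS for any F."""
    return c_skip(t)[:, None] * x + c_out(t)[:, None] * net(params, x, t)


def karras_grid(n):
    """n noise levels from EPS to T_MAX, spaced uniformly in t^(1/RHO)."""
    i = jnp.arange(n)
    return (EPS ** (1 / RHO) + i / (n - 1) * (T_MAX ** (1 / RHO) - EPS ** (1 / RHO))) ** RHO


def ct_loss(net, params, ema_params, x0, key, t_grid):
    """Consistency training: match f at adjacent levels t_n < t_{n+1} with one shared noise draw."""
    k_n, k_z = jax.random.split(key)
    n = jax.random.randint(k_n, (x0.shape[0],), 0, t_grid.shape[0] - 1)
    t_lo, t_hi = t_grid[n], t_grid[n + 1]
    z = jax.random.normal(k_z, x0.shape)
    pred = consistency_fn(net, params, x0 + t_hi[:, None] * z, t_hi)
    # Target network (EMA parameters) with gradients stopped.
    target = jax.lax.stop_gradient(consistency_fn(net, ema_params, x0 + t_lo[:, None] * z, t_lo))
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))  # squared L2 as the distance


def sample_one_step(net, params, key, shape):
    """Single-step generation: x = f(T_MAX * z, T_MAX) with z ~ N(0, I)."""
    z = jax.random.normal(key, shape)
    return consistency_fn(net, params, T_MAX * z, jnp.full((shape[0],), T_MAX))


def toy_net(params, x, t):
    """Hypothetical stand-in for the network F."""
    return jnp.tanh(x @ params["w"] + params["b"] * t[:, None])
```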
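Flow matching (notebook 6) regresses a vector field onto a known per-sample target. A minimal sketch of the conditional objective with the linear (optimal-transport) conditional path of Lipman et al. 2022, where x0 is Gaussian noise and x1 is data, plus Euler sampling from t = 0 to t = 1; v and its params are hypothetical stand-ins for a neural network, and sigma_min defaults to 0 here.

```python
import jax
import jax.numpy as jnp


def cfm_loss(v, params, x1, key, sigma_min=0.0):
    """Conditional flow matching with x_t = (1 - (1 - sigma_min) t) x0 + t x1
    and regression target u_t = x1 - (1 - sigma_min) x0."""
    k_t, k_0 = jax.random.split(key)
    t = jax.random.uniform(k_t, (x1.shape[0], 1))
    x0 = jax.random.normal(k_0, x1.shape)  # source samples from N(0, I)
    x_t = (1.0 - (1.0 - sigma_min) * t) * x0 + t * x1
    u_t = x1 - (1.0 - sigma_min) * x0
    return jnp.mean(jnp.sum((v(params, x_t, t) - u_t) ** 2, axis=-1))


def sample(v, params, key, shape, n_steps=100):
    """Generate by Euler-integrating dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data)."""
    dt = 1.0 / n_steps
    x = jax.random.normal(key, shape)

    def step(x, i):
        t = jnp.full((shape[0], 1), i * dt)
        return x + dt * v(params, x, t), None

    x, _ = jax.lax.scan(step, x, jnp.arange(n_steps))
    return x
```

As a sanity check, if all data sit at a single point c, the exact field is (c - x) / (1 - t): it attains near-zero loss, and Euler sampling with it lands on c.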
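Walk-jump sampling (notebook 8) runs MCMC on a Gaussian-smoothed version of the data density (the "walk") and then denoises with the empirical Bayes estimate x̂ = y + σ² ∇ log p_σ(y) (the "jump"). To keep the sketch self-contained, the learned score is replaced by the analytic score of one-hot data smoothed at a hypothetical noise level σ = 0.5 over K = 4 categories, and the walk uses plain overdamped Langevin dynamics, which may differ from the sampler used by Frey et al. 2023; treat it as a simplified illustration.

```python
import jax
import jax.numpy as jnp

K, SIGMA = 4, 0.5  # hypothetical: 4 categories, noise level sigma


def smoothed_score(y):
    """Exact score of p_sigma(y) = (1/K) sum_k N(y; e_k, sigma^2 I) for one-hot data e_k.
    In walk-jump sampling this would be a learned network; here it is analytic."""
    centers = jnp.eye(K)
    logits = -jnp.sum((y[:, None, :] - centers[None]) ** 2, axis=-1) / (2 * SIGMA ** 2)
    w = jax.nn.softmax(logits, axis=-1)  # posterior over which one-hot vector produced y
    return (w @ centers - y) / SIGMA ** 2


def walk(key, y, n_steps=200, step_size=0.01):
    """'Walk': overdamped Langevin MCMC on the smoothed density (a simplification)."""
    def step(y, k):
        noise = jax.random.normal(k, y.shape)
        return y + step_size * smoothed_score(y) + jnp.sqrt(2 * step_size) * noise, None

    y, _ = jax.lax.scan(step, y, jax.random.split(key, n_steps))
    return y


def jump(y):
    """'Jump': denoise with x_hat = y + sigma^2 * score(y) = E[x | y], then snap to a category."""
    x_hat = y + SIGMA ** 2 * smoothed_score(y)
    return x_hat, jnp.argmax(x_hat, axis=-1)
```

Because the smoothed density is easier to explore than the discrete one, the walk can mix across categories, and the jump maps each noisy sample back to a clean one-hot prediction.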

Inspiration

[Image: assets/midwit.png]
