This repository contains the official PyTorch implementation for the paper Generative Marginalization Models (arXiv:2310.12920), by Sulin Liu, Peter J. Ramadge, and Ryan P. Adams.
We introduce marginalization models (MaMs), a new family of generative models for high-dimensional discrete data.
MaMs directly model the marginal distribution $p_\theta(x_s)$ for any subset of variables $x_s$ of $x$.
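The learned marginals should satisfy the "marginalization self-consistency":
$$p_\theta(x_s) = \sum_{x_{s^c}} p_\theta(x_s, x_{s^c}),$$
where $x_{s^c}$ denotes the variables that are marginalized out.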
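To learn self-consistent marginals, we propose scalable training objectives that minimize the error of the following one-step self-consistency constraints, imposed on marginals and conditionals over all possible orderings:
$$p_\theta(x_{\sigma(<d)})\, p_\theta(x_{\sigma(d)} \mid x_{\sigma(<d)}) = p_\theta(x_{\sigma(\le d)}),$$
for every ordering $\sigma$ of the $D$ variables, every $x$, and every $d \in \{1, \dots, D\}$.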
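The one-step constraint can be checked numerically on a small problem. The following pure-Python sketch is illustrative only and not code from this repository: it builds a hypothetical joint over three binary variables, measures violations as squared differences in log space (an assumption chosen for illustration), and compares exact marginals, which satisfy the constraint by construction, with a deliberately miscalibrated marginal model.

```python
import itertools
import math

D = 3  # number of binary variables in this toy example

# Hypothetical toy joint p(x) over {0,1}^3, used as ground truth.
weights = {x: 1.0 + x[0] + 2 * x[1] * x[2] for x in itertools.product([0, 1], repeat=D)}
total = sum(weights.values())
joint = {x: w / total for x, w in weights.items()}

def exact_log_marginal(assign):
    """log p(x_s): sum the joint over the unassigned variables x_{s^c}."""
    return math.log(sum(p for x, p in joint.items()
                        if all(x[i] == v for i, v in assign.items())))

def exact_log_conditional(i, v, context):
    """log p(x_i = v | x_context), derived from the exact marginals."""
    return exact_log_marginal({**context, i: v}) - exact_log_marginal(context)

def one_step_error(log_marginal, log_conditional, x, sigma):
    """Sum over d of the squared log-space violation of
    p(x_{sigma(<d)}) p(x_{sigma(d)} | x_{sigma(<d)}) = p(x_{sigma(<=d)})."""
    err = 0.0
    for d in range(len(sigma)):
        prefix = {i: x[i] for i in sigma[:d]}
        step = {**prefix, sigma[d]: x[sigma[d]]}
        gap = (log_marginal(prefix) + log_conditional(sigma[d], x[sigma[d]], prefix)
               - log_marginal(step))
        err += gap ** 2
    return err

def average_error(log_marginal):
    """Average over every x and ordering sigma (conditionals are the exact ones)."""
    cases = [(x, s) for x in itertools.product([0, 1], repeat=D)
             for s in itertools.permutations(range(D))]
    return sum(one_step_error(log_marginal, exact_log_conditional, x, s)
               for x, s in cases) / len(cases)

# A deliberately miscalibrated marginal model: log m(x_s) = log p(x_s) + 0.1 * |s|.
def perturbed_log_marginal(assign):
    return exact_log_marginal(assign) + 0.1 * len(assign)

exact_err = average_error(exact_log_marginal)
perturbed_err = average_error(perturbed_log_marginal)
print(f"exact marginals:     {exact_err:.2e}")
print(f"perturbed marginals: {perturbed_err:.4f}")
```

Because each of the three steps of the perturbed model is off by exactly 0.1 in log space, its average error is $3 \times 0.1^2 = 0.03$, while the exact marginals give an error at the level of floating-point rounding.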
Marginals are order-agnostic, hence MaMs allow any-order generation.
Any-order autoregressive models (AO-ARMs) [1,2,3] also allow any-order marginal inference, by factorizing $p(x)$ into univariate conditionals along an ordering $\sigma$: $p(x) = \prod_{d=1}^{D} p(x_{\sigma(d)} \mid x_{\sigma(<d)})$. Compared with AO-ARMs, directly modeling the marginals has two advantages:
- Significantly faster test-time inference for any marginal, by orders of magnitude: $\mathcal{O}(1)$ vs. $\mathcal{O}(D)$ network evaluations with autoregressive models.
- Scalable training of any-order generative models on high-dimensional problems under energy-based training, a common setting in the physical sciences or in applications with a target reward function. In contrast, training ARMs requires $\mathcal{O}(D)$ evaluations (a sequence of conditionals) to compute the log-likelihood of a single data point, which energy-based training needs.
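The $\mathcal{O}(1)$ vs. $\mathcal{O}(D)$ comparison concerns the number of network evaluations per marginal query. In this sketch, toy_log_conditional and toy_log_marginal are hypothetical stand-ins for an AO-ARM network and a MaM network (both describe independent fair bits, so they agree); the only point is the call count for a query over all D variables.

```python
import math

calls = {"arm": 0, "mam": 0}  # number of (toy) network evaluations

def toy_log_conditional(i, v, context):
    """Hypothetical AO-ARM network: log p(x_i = v | x_context) for fair coins."""
    calls["arm"] += 1
    return math.log(0.5)

def toy_log_marginal(assign):
    """Hypothetical MaM network: log p(x_s) for any subset s in one call."""
    calls["mam"] += 1
    return len(assign) * math.log(0.5)

def arm_log_marginal(assign):
    """AO-ARM marginal: chain one conditional per observed variable."""
    total, context = 0.0, {}
    for i, v in assign.items():
        total += toy_log_conditional(i, v, context)
        context[i] = v
    return total

D = 16
x = {i: i % 2 for i in range(D)}  # a full assignment of all D variables
lp_arm = arm_log_marginal(x)
lp_mam = toy_log_marginal(x)
print(lp_arm, lp_mam, calls)  # equal up to rounding; 16 ARM calls vs. 1 MaM call
```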
git clone https://github.com/PrincetonLIPS/MaM.git
cd MaM
# optional virtual env
python -m venv env
source env/bin/activate
python -m pip install -r requirements.txt
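To train MaMs by maximum likelihood estimation, we fit the marginals by maximizing the expected log-likelihood over the data distribution while enforcing marginalization self-consistency:
$$\max_\theta \; \mathbb{E}_{x \sim p_{\text{data}}} \log p_\theta(x) \quad \text{s.t.} \quad p_\theta(x_{\sigma(<d)})\, p_\theta(x_{\sigma(d)} \mid x_{\sigma(<d)}) = p_\theta(x_{\sigma(\le d)}) \quad \forall\, \sigma, x, d.$$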
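A constrained maximum-likelihood problem like this is commonly relaxed into a penalty objective. The sketch below assumes a squared log-space penalty with weight lam, evaluated on the data points under randomly sampled orderings; the function penalized_nll and the toy Bernoulli models are hypothetical, and the loss actually used in this repository may differ.

```python
import math
import random

def penalized_nll(log_marginal, log_conditional, batch, lam, num_orderings, rng):
    """Monte Carlo estimate of
        -E_data[log p(x)] + lam * E_{x, sigma}[sum_d gap_d^2],
    with gap_d = log p(x_{sigma(<d)}) + log p(x_{sigma(d)} | x_{sigma(<d)})
                 - log p(x_{sigma(<=d)})."""
    D = len(batch[0])
    nll = -sum(log_marginal(dict(enumerate(x))) for x in batch) / len(batch)
    penalty = 0.0
    for x in batch:
        for _ in range(num_orderings):
            sigma = rng.sample(range(D), D)  # a uniformly random ordering
            for d in range(D):
                prefix = {i: x[i] for i in sigma[:d]}
                step = {**prefix, sigma[d]: x[sigma[d]]}
                gap = (log_marginal(prefix)
                       + log_conditional(sigma[d], x[sigma[d]], prefix)
                       - log_marginal(step))
                penalty += gap ** 2
    penalty /= len(batch) * num_orderings
    return nll + lam * penalty, nll, penalty

def bernoulli_log_marginal(q):
    """Hypothetical marginal model: independent Bernoulli(q) variables."""
    return lambda assign: sum(math.log(q if v else 1 - q) for v in assign.values())

def bernoulli_log_conditional(q):
    """Hypothetical conditional model that ignores its context."""
    return lambda i, v, context: math.log(q if v else 1 - q)

batch = [(1, 1, 0), (1, 0, 1), (0, 1, 1), (1, 1, 1)]
rng = random.Random(0)
consistent = penalized_nll(bernoulli_log_marginal(0.7), bernoulli_log_conditional(0.7),
                           batch, lam=1.0, num_orderings=4, rng=rng)
inconsistent = penalized_nll(bernoulli_log_marginal(0.7), bernoulli_log_conditional(0.5),
                             batch, lam=1.0, num_orderings=4, rng=rng)
print(consistent)    # penalty ~0: marginals and conditionals agree
print(inconsistent)  # same NLL, positive penalty
```

Both models assign the same likelihood to the data, so only the penalty term distinguishes the self-consistent pair from the inconsistent one.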
For the most efficient training, the marginals can be learned in two steps:
1. Fit the conditionals
cd ao_arm
python image_main.py # MNIST-Binary dataset
python text_main.py # text8 language modeling
python mol_main.py load_full=True # MOSES molecule string dataset
2. Fit the marginals
cd mam
python image_main.py load_pretrain=True # MNIST-Binary
python text_main.py load_pretrain=True # text8
python mol_main.py load_pretrain=True # MOSES molecule string
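Step 2 can be pictured as distillation: the conditionals from step 1 are frozen, and only the marginal model is trained to reduce the self-consistency error. The sketch below is a toy illustration, not the repository's training loop: the frozen "pretrained" conditionals are hypothetical context-free Bernoulli models (p_frozen), the marginal model has one logit per variable, and plain gradient descent with finite-difference gradients stands in for a PyTorch optimizer.

```python
import itertools
import math

D = 3
p_frozen = [0.8, 0.3, 0.6]  # hypothetical step-1 conditionals p(x_i = 1 | context)

def frozen_log_conditional(i, v, context):
    """Frozen conditional from step 1 (this toy version ignores the context)."""
    return math.log(p_frozen[i] if v else 1 - p_frozen[i])

def log_marginal(logits, assign):
    """Toy marginal model: independent Bernoulli(sigmoid(logit_i)) variables."""
    total = 0.0
    for i, v in assign.items():
        s = 1.0 / (1.0 + math.exp(-logits[i]))
        total += math.log(s if v else 1 - s)
    return total

CASES = [(x, s) for x in itertools.product([0, 1], repeat=D)
         for s in itertools.permutations(range(D))]

def self_consistency_loss(logits):
    """Average over all x and orderings of sum_d gap_d^2 (log space)."""
    loss = 0.0
    for x, sigma in CASES:
        for d in range(D):
            prefix = {i: x[i] for i in sigma[:d]}
            step = {**prefix, sigma[d]: x[sigma[d]]}
            gap = (log_marginal(logits, prefix)
                   + frozen_log_conditional(sigma[d], x[sigma[d]], prefix)
                   - log_marginal(logits, step))
            loss += gap ** 2
    return loss / len(CASES)

def numerical_grad(f, params, h=1e-5):
    """Central finite differences, one coordinate at a time."""
    grad = []
    for k in range(len(params)):
        up, down = list(params), list(params)
        up[k] += h
        down[k] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

logits = [0.0] * D  # only the marginal model's parameters are updated
for _ in range(100):
    grad = numerical_grad(self_consistency_loss, logits)
    logits = [a - 1.0 * g for a, g in zip(logits, grad)]

learned = [1.0 / (1.0 + math.exp(-a)) for a in logits]
print([round(p, 4) for p in learned], self_consistency_loss(logits))  # ~[0.8, 0.3, 0.6], loss ~0
```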
Coming soon: code and model checkpoints for more image datasets, including CIFAR-10 and ImageNet-32.
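In the energy-based training setting, we do not have data samples from the distribution of interest. Instead, we can only evaluate the unnormalized (log) probability mass function $f(x)$, and the goal is to fit $p_\theta(x)$ to the target distribution $p(x) \propto f(x)$.
Training ARMs in this setting is expensive because computing $\log p_\theta(x)$ for each sample requires $\mathcal{O}(D)$ sequential evaluations of the conditionals, whereas a MaM obtains it with a single evaluation of the marginal network.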
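As a rough illustration of why a cheap $\log p_\theta(x)$ matters, consider the reverse KL divergence $\mathrm{KL}(p_\theta \,\|\, p) = \mathbb{E}_{x \sim p_\theta}[\log p_\theta(x) - \log f(x)] + \log Z$ for an unnormalized target $f(x)$ with $p(x) = f(x)/Z$. We assume this common objective for illustration only; see the paper for the exact objective behind ising_eb_main.py. Every Monte Carlo sample needs $\log p_\theta(x)$. The sketch uses a hypothetical 1D Ising chain $f(x) = \exp(J \sum_i s_i s_{i+1})$ with spins $s_i = 2x_i - 1$ and a uniform $p_\theta$, for which the exact value is $(D-1)\log\cosh J$.

```python
import itertools
import math
import random

D, J = 6, 0.5  # hypothetical chain length and coupling

def log_f(x):
    """Unnormalized log-probability of an open 1D Ising chain (spins s = 2x - 1)."""
    s = [2 * v - 1 for v in x]
    return J * sum(s[i] * s[i + 1] for i in range(D - 1))

def model_log_prob(x):
    """Stand-in for ONE marginal-network call returning log p_theta(x).
    Here p_theta is uniform (fair independent bits); an ARM would need D calls."""
    return D * math.log(0.5)

states = list(itertools.product([0, 1], repeat=D))
log_Z = math.log(sum(math.exp(log_f(x)) for x in states))  # brute force, tiny D only

# Exact reverse KL: p_theta is uniform, so the expectation is a plain average.
exact_kl = sum(model_log_prob(x) - log_f(x) for x in states) / len(states) + log_Z

# Monte Carlo estimate: sample x ~ p_theta, one model call per sample.
rng = random.Random(0)
samples = [tuple(rng.randint(0, 1) for _ in range(D)) for _ in range(20000)]
mc_kl = sum(model_log_prob(x) - log_f(x) for x in samples) / len(samples) + log_Z

print(exact_kl, (D - 1) * math.log(math.cosh(J)), mc_kl)
```

In real energy-based training $\log Z$ is unknown, but it is a constant with respect to $\theta$, so it does not affect gradients of the objective.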
cd mam
# ising model energy-based training
python ising_eb_main.py
# molecule property energy-based training with a given reward function
python mol_property_eb_main.py
Please check the paper for technical details and experimental results. Please consider citing our work if you find it helpful:
@article{liu2023mam,
title={Generative Marginalization Models},
author={Liu, Sulin and Ramadge, Peter J and Adams, Ryan P},
journal={arXiv preprint arXiv:2310.12920},
year={2023}
}
The code for training any-order conditionals of autoregressive models (in ao_arm/) is adapted from https://github.com/AndyShih12/mac, using the original any-order masking strategy proposed for training AO-ARMs, without the [mask] token in the output.