Official code for the Optimal Variance Control of the Score Function Gradient Estimator for Importance Weighted Bounds (a.k.a. OVIS: Optimal Variance -- Importance Sampling). Published at NeurIPS 2020.
OVIS is a state-of-the-art gradient estimator for discrete VAEs. This repo provides a user-friendly interface to OVIS, and other gradient estimators. OVIS can easily be imported in your project to train and evaluate discrete VAEs. The implementation is compatible with a wide variety of VAE models, including hierarchical ones. This library allows reproducing all the experiments from the paper.
author = {Li\'{e}vin, Valentin and Dittadi, Andrea and Christensen, Anders and Winther, Ole},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
pages = {16591--16602},
publisher = {Curran Associates, Inc.},
title = {Optimal Variance Control of the Score-Function Gradient Estimator for Importance-Weighted Bounds},
url = {},
volume = {33},
year = {2020}
- Reinforce
- Reinforce + Neural Baseline
- Vimco
- RWS (Reweighted Wake-Sleep)
- TVO (Thermodynamic Variational Objective)
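Several of these estimators build on the score function (Reinforce) estimator and differ mainly in the control variate they subtract to reduce variance. The toy sketch below illustrates that idea on a single Bernoulli variable. It is not the library's implementation: the test function `f`, the helper names and the choice of baseline are all assumptions made for illustration.

```python
import random

def f(z):
    # arbitrary test function of a binary latent variable (illustrative)
    return (z - 0.45) ** 2

def score(z, theta):
    # d/dtheta log Bernoulli(z; theta)
    return z / theta - (1 - z) / (1 - theta)

def reinforce_samples(theta, n, baseline=0.0, seed=0):
    # single-sample score function estimates of d/dtheta E[f(z)]
    rng = random.Random(seed)
    estimates = []
    for _ in range(n):
        z = 1 if rng.random() < theta else 0
        estimates.append((f(z) - baseline) * score(z, theta))
    return estimates

def mean(xs):
    return sum(xs) / len(xs)

def variance(xs):
    m = mean(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

theta = 0.5
true_grad = f(1) - f(0)  # exact gradient of E[f(z)] w.r.t. theta

plain = reinforce_samples(theta, 20000)
# baseline = E[f(z)]: a constant control variate that leaves the mean unchanged
baseline = theta * f(1) + (1 - theta) * f(0)
controlled = reinforce_samples(theta, 20000, baseline=baseline)

print("true gradient      :", true_grad)
print("plain    mean / var:", mean(plain), variance(plain))
print("baseline mean / var:", mean(controlled), variance(controlled))
```

Both variants estimate the same gradient; only the variance differs, which is the quantity the estimators listed above try to control.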
OVIS can easily be imported in your own project to train your own discrete VAE/generative models. You simply need to define your model following the example below. The full example is available in
# install the latest release
pip install git+
# OR install in dev. mode
git clone && pip install -e ovis/
Gradient estimators can be initialized in two lines. Using the estimator, computing the loss for your model is a one-liner.
# init the estimator
from ovis.estimators.config import parse_estimator_id
Estimator, config = parse_estimator_id("ovis-gamma1")
estimator = Estimator(mc=1, iw=16, **config)
# use it to compute the differentiable loss
loss, diagnostics, output = estimator(model, x)
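For intuition about the objective behind such a loss, here is a minimal sketch of the importance weighted bound, assuming `iw` denotes the number K of importance samples and that `log_w` holds the log importance weights log w_k = log p(x, z_k) - log q(z_k | x). The helper `log_mean_exp` and the example values are hypothetical and not part of the library.

```python
import math

def log_mean_exp(log_w):
    # numerically stable log( (1/K) * sum_k exp(log_w[k]) )
    m = max(log_w)
    return m + math.log(sum(math.exp(v - m) for v in log_w) / len(log_w))

# made-up log importance weights for K = 4 samples
log_w = [-90.2, -88.7, -91.5, -89.1]

iw_bound = log_mean_exp(log_w)   # importance weighted bound L_K
elbo = sum(log_w) / len(log_w)   # average of the single-sample bounds

print("L_K :", iw_bound)
print("ELBO:", elbo)
```

By Jensen's inequality the importance weighted bound is never below the average of the single-sample bounds, and it never exceeds the largest log weight.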
OVIS relies on torch.distributions to implement the variational distributions. The library has been tested with normal, Bernoulli and categorical distributions, but it should work with any other distribution that provides a .log_prob() method.
Every model should implement:
forward(self, x:Tensor, reparam:bool=False, **kwargs) -> OUTPUT
sample_from_prior(self, bs: int, **kwargs)-> OUTPUT
Where the output format is defined as OUTPUT=Dict[str, Union[Distribution, List[Tensor], List[Distribution]]]
The output is a dictionary with keys:
- px: distribution modelling p(x|z)
- z: latent samples z, one item for each layer
- pz: prior distribution p(z), one item for each layer
- qz: posterior distribution q(z|x), one item for each layer
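The output convention above can be checked with a small helper. This is a sketch under assumptions: `check_output` is a hypothetical function, not part of OVIS, and it assumes that prior sampling returns no `qz` entry, as in the example model of this README.

```python
from typing import Any, Dict

def check_output(output: Dict[str, Any], from_prior: bool = False) -> int:
    """Validate the output dictionary and return the number of stochastic layers."""
    required = ['px', 'z', 'pz'] + ([] if from_prior else ['qz'])
    missing = [k for k in required if k not in output]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    # every per-layer entry must be a list with one item per layer
    lengths = {k: len(output[k]) for k in required if k != 'px'}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"inconsistent number of layers: {lengths}")
    return lengths['z']

# placeholder strings stand in for tensors and distributions
two_layers = {'px': 'p(x|z)', 'z': ['z1', 'z2'],
              'pz': ['p(z1|z2)', 'p(z2)'], 'qz': ['q(z1|x)', 'q(z2|z1)']}
print(check_output(two_layers))
```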
from torch import nn, Tensor, zeros
from torch.distributions import Bernoulli
from ovis.models import TemplateModel

class SimpleModel(TemplateModel):
    def __init__(self, xdim, zdim):
        # initialize the parent module before registering layers and buffers
        super().__init__()
        self.inference_network = nn.Linear(xdim, zdim)
        self.generative_model = nn.Linear(zdim, xdim)
        self.register_buffer('prior', zeros((1, zdim,)))

    def forward(self, x:Tensor, reparam:bool=False, **kwargs):
        # q(z|x)
        qz = Bernoulli(logits=self.inference_network(x))
        # z ~ q(z|x)
        z = qz.rsample() if reparam else qz.sample()
        # p(z)
        pz = Bernoulli(logits=self.prior)
        # p(x|z)
        px = Bernoulli(logits=self.generative_model(z))
        # store z, pz, qz as lists (useful for hierarchical models)
        return {'px': px, 'z': [z], 'qz': [qz], 'pz': [pz]}

    def sample_from_prior(self, bs: int, **kwargs):
        # expand the prior logits to the requested batch size
        pz = Bernoulli(logits=self.prior.expand(bs, *self.prior.shape[1:]))
        z = pz.sample()
        px = Bernoulli(logits=self.generative_model(z))
        return {'px': px, 'z': [z], 'pz': [pz]}

# generate x ~ Bernoulli(0.5), initialize a simple VAE, forward pass, prior sampling
x = Bernoulli(logits=zeros((1, 10,))).sample()
model = SimpleModel(10, 10)
output = model(x)
output = model.sample_from_prior(1)
The code below shows a simple training loop for training a model.
Notice how parameters
can be used for various types of scheduling (i.e.
from ovis.analysis.gradients import get_gradients_statistics
from booster import Aggregator
agg = Aggregator()
parameters = {'alpha': 0.9, 'beta': 1}
for x in loader:
    global_step += 1
    loss, diagnostics, output = estimator(model, x, backward=False, **parameters)
# update parameters
# epoch summary
summary ='cpu')
# analyse the gradients of the parameters of the inference network
grad_stats, _ = get_gradients_statistics(estimator, model, x, mc_samples=10, key_filter='inference_network')
# log data
summary.log(tensorboard_writer, global_step)
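Since the `parameters` dictionary is passed to the estimator at every step, its values can be recomputed from `global_step`. The sketch below shows one possible schedule. `linear_schedule`, the decay of `alpha` from 0.9 to 0 and the 1000-step horizon are assumptions chosen for illustration, not settings taken from the paper.

```python
def linear_schedule(step, start, end, n_steps):
    # linearly interpolate from `start` to `end` over `n_steps`, then hold `end`
    if n_steps <= 0:
        return end
    t = min(max(step / n_steps, 0.0), 1.0)
    return start + t * (end - start)

for global_step in (0, 250, 500, 1000, 2000):
    # rebuild the parameters before each call to the estimator
    parameters = {'alpha': linear_schedule(global_step, 0.9, 0.0, 1000), 'beta': 1}
    print(global_step, parameters)
```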
conda create -n ovis python=3.7
conda activate ovis
# use the instructions from
conda install pytorch=1.5.1 torchvision cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
# [Optional] Install Latex (used for the figures)
This paper introduces novel results for the score function gradient estimator of the importance weighted variational bound (IWAE). We prove that in the limit of large
All experiments are managed through the script
which implements a multi-threaded queue system based on
and a filelock
protection. See python --help
for more information about the number of
subprocesses and resuming experiments. The script
provides a few utilities to inspect and clean
the experiment database.
allows parsing an experiment directory and producing figures. Usage:
# begins an experiment with 2 processes per GPU (max. 2 GPUs)
python --exp exp_id --processes n_procs_per_gpu --max_gpus 2
# show the experiment status [queued, aborted, failed, running, success]
python --exp exp_id --check
# requeue aborted experiments
python --exp exp_id --requeue --requeue_level 1
# generate plots
python --exp exp_id --metrics train:loss/L_k,train:grads/snr --pivot_metrics train:loss/L_k,train:grads/snr
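The queue-with-status idea described above can be sketched with the standard library. This is not the repository's implementation: `run_queue` is a hypothetical helper, an in-memory `threading.Lock` stands in for the file lock, and only the status names come from this README.

```python
import queue
import threading

def run_queue(jobs, n_workers=2):
    """Run callables concurrently and record a status for each job id."""
    status = {job_id: "queued" for job_id in jobs}
    lock = threading.Lock()  # stand-in for the file lock protecting the database
    todo = queue.Queue()
    for job_id in jobs:
        todo.put(job_id)

    def worker():
        while True:
            try:
                job_id = todo.get_nowait()
            except queue.Empty:
                return  # nothing left to run
            with lock:
                status[job_id] = "running"
            try:
                jobs[job_id]()
                result = "success"
            except Exception:
                result = "failed"
            with lock:
                status[job_id] = result

    threads = [threading.Thread(target=worker) for _ in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return status

def failing():
    raise RuntimeError("simulated crash")

statuses = run_queue({"a": lambda: None, "b": failing, "c": lambda: None})
print(statuses)
```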
Analysis of the gradients for a simple Gaussian model. Figure 1:
# run the experiment
python --exp asymptotic-variance
# produce the figures
python report_asymptotic_variance --exp asymptotic-variance
# access the results
open reports/asymptotic-variance
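The gradient statistics tracked in these experiments (for example `train:grads/variance` and `train:grads/snr`) can be illustrated as follows. This is a sketch under assumptions: the signal-to-noise ratio is taken per parameter as |mean| / std over repeated gradient estimates and then averaged over parameters, which may differ from what `get_gradients_statistics` computes. The function name and the sample values are hypothetical.

```python
import math

def gradient_statistics(grad_samples):
    """grad_samples: list of gradient vectors, one per independent estimate."""
    n, d = len(grad_samples), len(grad_samples[0])
    variances, snrs = [], []
    for i in range(d):
        column = [g[i] for g in grad_samples]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / (n - 1)  # unbiased variance
        variances.append(var)
        snrs.append(abs(mean) / math.sqrt(var) if var > 0 else float('inf'))
    return {'variance': sum(variances) / d, 'snr': sum(snrs) / d}

# two gradient estimates for a model with two parameters (made-up values)
stats = gradient_statistics([[1.0, 0.0], [3.0, 4.0]])
print(stats)
```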
Train a simple Gaussian Mixture model. Figure 2:
# run the experiment
python --exp gaussian-mixture-model
# produce the figures
python --exp=gaussian-mixture-model \
--keys=dataset,estimator,iw \
--metrics=test:gmm/posterior_mse,test:gmm/prior_mse,train:grads/variance,train:grads/snr \
--detailed_metrics=test:gmm/posterior_mse,test:gmm/prior_mse,train:loss/ess,train:grads/variance,train:grads/snr \
--pivot_metrics=min:test:gmm/posterior_mse,min:test:gmm/prior_mse,mean:train:grads/snr \
# access the results
open reports/gaussian-mixture-model
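The metric identifiers passed to the reporting command appear to follow a `split:key` pattern, with an optional aggregation prefix for pivot metrics (`min`, `max`, `mean` and `last` occur in this README). The sketch below shows how such strings could be decomposed and applied to a logged series. `parse_metric`, `pivot` and the `history` values are hypothetical and only illustrate the naming scheme, not the repository's parser.

```python
AGGREGATIONS = {
    'min': min,
    'max': max,
    'mean': lambda xs: sum(xs) / len(xs),
    'last': lambda xs: xs[-1],
}

def parse_metric(spec):
    # "min:test:gmm/posterior_mse" -> ("min", "test", "gmm/posterior_mse")
    # "train:grads/snr"            -> (None, "train", "grads/snr")
    parts = spec.split(':')
    if len(parts) == 3:
        agg, split, key = parts
    elif len(parts) == 2:
        agg, (split, key) = None, parts
    else:
        raise ValueError(f"cannot parse metric: {spec!r}")
    if agg is not None and agg not in AGGREGATIONS:
        raise ValueError(f"unknown aggregation: {agg!r}")
    return agg, split, key

def pivot(spec, history):
    # reduce a logged series to a single value with the requested aggregation
    agg, split, key = parse_metric(spec)
    if agg is None:
        raise ValueError("pivot metrics need an aggregation prefix")
    return AGGREGATIONS[agg](history[f"{split}:{key}"])

history = {'test:gmm/posterior_mse': [0.9, 0.4, 0.5]}  # made-up values
print(parse_metric('train:grads/snr'))
print(pivot('min:test:gmm/posterior_mse', history))
```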
Train a 3-layer Sigmoid Belief Network using the Importance-Weighted Bound (IW) and the Rényi Importance Weighted Bound (IWR). Run all experiments:
# run the experiment
python --exp sigmoid-belief-network
Figure 3 (left, VIMCO + OVIS-IW), 3 seeds:
# gather the data
python --exp=sigmoid-belief-network \
--include=iwbound \
--keys=dataset,estimator,iw \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:loss/kl,train:loss/ess,train:active_units/au,train:grads/snr \
--pivot_metrics=max:test:loss/L_k,max:train:loss/L_k,last:train:loss/kl_q_p,last:train:loss/ess \
# produce the figure
python --figure left
# access the results
open reports/sigmoid-belief-network-inc=iwbound
Figure 3 (right, TVO + OVIS-IWR), 3 seeds:
# gather the data
python --exp=sigmoid-belief-network \
--include=iwrbound \
--keys=dataset,estimator,iw \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:loss/kl,train:loss/ess,train:active_units/au,train:grads/snr \
--pivot_metrics=max:test:loss/L_k,max:train:loss/L_k,last:train:loss/kl_q_p,last:train:loss/ess \
# produce the figure
python --figure right
# access the results
open reports/sigmoid-belief-network-inc=iwrbound
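Among the metrics above, `train:loss/ess` presumably refers to the effective sample size of the importance weights. A minimal sketch, assuming the common definition ESS = (sum_k w_k)^2 / sum_k w_k^2 computed from log weights (the function name is hypothetical and the library may use a different normalization):

```python
import math

def effective_sample_size(log_w):
    # subtract the max for numerical stability; the ratio is scale invariant
    m = max(log_w)
    w = [math.exp(v - m) for v in log_w]
    return sum(w) ** 2 / sum(x * x for x in w)

print(effective_sample_size([0.0, 0.0, 0.0, 0.0]))   # uniform weights
print(effective_sample_size([0.0, -1000.0, -1000.0]))  # one dominant weight
```

With uniform weights the ESS equals the number of particles; when a single weight dominates it collapses towards one.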
Train a 1-layer Gaussian VAE. Figure 4:
# produce the figures
python --exp=gaussian-vae \
--keys=dataset,estimator,iw \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
# access the results
open reports/gaussian-vae
Fitting the Binarized MNIST, Fashion MNIST and Omniglot datasets. The hyperparameters are identical for all experiments, which are run with and without Rényi warmup.
python --exp=sigmoid-belief-network \
--keys=dataset,estimator,iw,warmup \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:loss/kl,train:loss/ess,train:active_units/au,train:grads/snr \
--pivot_metrics=max:test:loss/L_k,max:train:loss/L_k,last:train:loss/kl_q_p,last:train:loss/ess,last:train:active_units/au \
--downsample 50 \
--include tvo,vimco,ovis-gamma1
python --exp=gaussian-vae \
--keys=dataset,estimator,iw,alpha \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:loss/kl,train:loss/ess,train:active_units/au,train:grads/snr \
--pivot_metrics=max:test:loss/L_k,max:train:loss/L_k,last:train:loss/kl_q_p,last:train:loss/ess,last:train:grads/snr \
--downsample 50
In this experiment, we compare the asymptotic OVIS (gamma=1) with the sample-based control OVIS-MC. By contrast with
the previous experiments, the total particle budget remains equal to K
The K
particles are used to estimate the gradient of the generative model,
particles are used to evaluate the score-based estimate of the gradient of the inference network and S
are used to estimate the control variate. In the following plots, the identifier ovis-Sy
indicates that S = yK.
See the experiment .json
file for more details.
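The budget split can be made concrete with a small calculation. This sketch rests on an assumption that is not confirmed by the text above: the particles not spent on the control variate (K - S) are the ones used for the score-based estimate. `split_budget` is a hypothetical helper.

```python
def split_budget(K, y):
    # S = y * K particles for the control variate (rounded to an integer)
    S = int(round(y * K))
    if not 0 <= S <= K:
        raise ValueError("y must lie in [0, 1]")
    # assumption: the remaining particles go to the score-based estimate
    return K - S, S

for y in (0.1, 0.5, 0.9):
    remaining, S = split_budget(16, y)
    print(f"y={y}: S={S}, remaining={remaining}, total={remaining + S}")
```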
# run the experiment
python --exp budget-analysis
# produce the figures
python --exp=budget-analysis \
--keys=dataset,estimator,iw \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=train:loss/L_k,train:loss/kl_q_p \
# access the results
open reports/budget-analysis
# run the experiment
python --exp budget-analysis-sbm
# produce the figures
python --exp=budget-analysis-sbm \
--keys=dataset,estimator,iw \
--metrics=test:loss/L_k,train:loss/L_k,train:loss/kl_q_p,train:grads/snr \
--detailed_metrics=train:loss/L_k,train:loss/kl_q_p \
# access the results
open reports/budget-analysis-sbm
Checking the memory usage of the different estimators given different particle budgets.
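As a rough CPU-side illustration of such a check, peak memory can be measured while varying the number of particles. This is only a stand-in: `fake_estimator` is a hypothetical function that allocates one value per (datapoint, particle) pair, and `tracemalloc` tracks Python allocations rather than the GPU memory a real estimator would use.

```python
import tracemalloc

def peak_memory(fn, *args):
    # peak number of bytes allocated by Python objects while running fn
    tracemalloc.start()
    try:
        fn(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak

def fake_estimator(n_particles, batch_size=64):
    # stand-in workload: one log-weight per (datapoint, particle)
    return [[0.0] * n_particles for _ in range(batch_size)]

peaks = {K: peak_memory(fake_estimator, K) for K in (4, 64, 1024)}
for K, peak in peaks.items():
    print(f"K={K:5d}: peak = {peak} bytes")
```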