Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add a function that constructs samplers #45

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aemcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from . import _version

# Package version string, looked up from the sibling `_version` module.
__version__ = _version.get_versions()["version"]

# Register rewrite databases
# These modules are imported purely for their import-time side effects:
# `aemcmc.conjugates` registers its rewrites in a shared database at module
# level; `aemcmc.gibbs` presumably does the same (not verified here).
import aemcmc.conjugates
import aemcmc.gibbs
123 changes: 123 additions & 0 deletions aemcmc/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Dict, Tuple

from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable

from aemcmc.opt import (
SamplerTracker,
construct_ir_fgraph,
expand_subsumptions,
sampler_rewrites_db,
)


def construct_sampler(
    obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream
) -> Tuple[
    Dict[TensorVariable, TensorVariable],
    Dict[Variable, Variable],
    Dict[TensorVariable, TensorVariable],
]:
    r"""Eagerly construct a sampler for a given set of observed variables and their observations.

    Parameters
    ----------
    obs_rvs_to_values
        A ``dict`` of variables that maps stochastic elements
        (e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
        observed values.
    srng
        The `RandomStream` handed to the `SamplerTracker` feature that is
        attached to the intermediate graph before the sampler rewrites run.

    Returns
    -------
    A tuple with three elements:

    1. a ``dict`` that maps each unobserved random variable (as it appears in
       the caller's graph) to its sampler step,
    2. a ``dict`` of the updates generated by the sampler steps, and
    3. a ``dict`` that maps each unobserved random variable to the variable
       (a clone of its intermediate representation) that stands for its
       initial value inside the sampler steps.

    Raises
    ------
    NotImplementedError
        If no sampler step was discovered for one or more of the unobserved
        random variables.
    """

    fgraph, obs_rvs_to_values, _memo, new_to_old_rvs = construct_ir_fgraph(
        obs_rvs_to_values
    )

    # The tracker is what exposes `fgraph.sampler_mappings`, which the rewrites
    # below populate with the sampler steps they find.
    fgraph.attach_feature(SamplerTracker(srng))

    _ = sampler_rewrites_db.query("+basic").optimize(fgraph)

    random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values)

    discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers

    rvs_to_init_vals = {rv: rv.clone() for rv in random_vars}
    posterior_sample_steps = rvs_to_init_vals.copy()
    # Replace occurrences of observed variables with their observed values
    posterior_sample_steps.update(obs_rvs_to_values)

    # TODO FIXME: Get/extract `Scan`-generated updates
    posterior_updates: Dict[Variable, Variable] = {}

    rvs_without_samplers = set()

    # `random_vars` already holds the unobserved outputs, in output order
    for rv in random_vars:

        rv_steps = discovered_samplers.get(rv)

        if not rv_steps:
            rvs_without_samplers.add(rv)
            continue

        # TODO FIXME: Just choosing one for now, but we should consider them all.
        _step_desc, step, updates = rv_steps.pop()

        # Split the updates into parallel key/value tuples so that they can be
        # carried through the rewrite below as extra graph outputs
        if updates:
            update_keys, update_values = zip(*updates.items())
        else:
            update_keys, update_values = (), ()

        sfgraph = FunctionGraph(
            outputs=(step,) + tuple(update_keys) + tuple(update_values),
            clone=False,
            copy_inputs=False,
            copy_orphans=False,
        )

        # Update the other sampled random variables in this step's graph
        sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True)

        # Expand subsumed `DimShuffle`d inputs to `Elemwise`s
        expand_subsumptions.optimize(sfgraph)

        # Later steps see this variable through its (rewritten) sampler step
        posterior_sample_steps[rv] = sfgraph.outputs[0]

        if updates:
            # Outputs are laid out as `[step, *keys, *values]`
            keys_offset = len(update_keys) + 1
            posterior_updates.update(
                zip(sfgraph.outputs[1:keys_offset], sfgraph.outputs[keys_offset:])
            )

    if rvs_without_samplers:
        # TODO: Assign NUTS to these
        raise NotImplementedError(
            f"Could not find posterior samplers for {rvs_without_samplers}"
        )

    # TODO: Track/handle "auxiliary/augmentation" variables introduced by sample
    # steps?

    return (
        {
            new_to_old_rvs[rv]: step
            for rv, step in posterior_sample_steps.items()
            if rv not in obs_rvs_to_values
        },
        posterior_updates,
        {new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()},
    )
67 changes: 52 additions & 15 deletions aemcmc/conjugates.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.optdb import OptimizationDatabase
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.optdb import LocalGroupDB
from aesara.graph.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV
from etuples import etuple, etuplize
from kanren import eq, lall
from kanren import eq, lall, run
from unification import var

conjugatesdb = OptimizationDatabase()
from aemcmc.opt import sampler_finder_db


def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a binomial observation model.

Expand All @@ -24,15 +26,15 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):

Parameters
----------
observed_val
The observed value.
observed_rv_expr
A tuple that contains expressions that represent the observed variable
and it observed value respectively.
An expression that represents the observed variable.
posterior_expr
An expression that represents the posterior distribution of the latent
variable.

"""

# Beta-binomial observation model
alpha_lv, beta_lv = var(), var()
p_rng_lv = var()
Expand All @@ -44,12 +46,10 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
n_lv = var()
Y_et = etuple(etuplize(at.random.binomial), var(), var(), var(), n_lv, p_et)

y_lv = var() # observation

# Posterior distribution for p
new_alpha_et = etuple(etuplize(at.add), alpha_lv, y_lv)
new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
new_beta_et = etuple(
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), y_lv
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), observed_val
)
p_posterior_et = etuple(
etuplize(at.random.beta),
Expand All @@ -61,10 +61,47 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
)

return lall(
eq(observed_rv_expr[0], Y_et),
eq(observed_rv_expr[1], y_lv),
eq(observed_rv_expr, Y_et),
eq(posterior_expr, p_posterior_et),
)


conjugatesdb.register("beta_binomial", KanrenRelationSub(beta_binomial_conjugateo))
@local_optimizer([BinomialRV])
def local_beta_binomial_posterior(fgraph, node):
    """Record a beta posterior sampler for the prior of a binomial variable.

    The second output of `node` is matched against the beta-binomial
    conjugacy relation `beta_binomial_conjugateo`, with the variable itself
    standing in for the observed value. When the relation has a solution, the
    evaluated posterior is appended to ``fgraph.sampler_mappings`` under the
    variable obtained from the last argument of the binomial expression, and
    the node is marked as seen so that it is not processed again.

    Parameters
    ----------
    fgraph
        The graph being rewritten. Nothing is recorded unless it carries a
        ``sampler_mappings`` attribute.
    node
        The `BinomialRV` node under consideration.

    Returns
    -------
    ``None`` when nothing was recorded; otherwise the outputs of the node that
    owns the matched variable.

    """
    tracker = getattr(fgraph, "sampler_mappings", None)

    binomial_rv = node.outputs[1]
    seen_key = ("local_beta_binomial_posterior", binomial_rv)

    # Nothing to do without a tracker, or when this variable was handled before
    if tracker is None:
        return None  # pragma: no cover
    if seen_key in tracker.rvs_seen:
        return None  # pragma: no cover

    posterior_lv = var()
    binomial_et = etuplize(binomial_rv)

    # Only the first solution of the (lazily evaluated) relation is needed
    conjugacy_goal = beta_binomial_conjugateo(binomial_rv, binomial_et, posterior_lv)
    posterior_et = next(run(None, posterior_lv, conjugacy_goal), None)

    if posterior_et is None:
        return None  # pragma: no cover

    prior_rv = binomial_et[-1].evaled_obj
    posterior_rv = eval_if_etuple(posterior_et)

    samplers_for_prior = tracker.rvs_to_samplers.setdefault(prior_rv, [])
    samplers_for_prior.append(("local_beta_binomial_posterior", posterior_rv, None))
    tracker.rvs_seen.add(seen_key)

    return binomial_rv.owner.outputs


# Database that collects the conjugate-posterior local rewrites defined in
# this module.
conjugates_db = LocalGroupDB(apply_all_opts=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")

# Make the "basic" conjugate rewrites available through the shared
# `sampler_finder_db` under the "conjugates" key.
# NOTE(review): the wrapped rewriter is labelled `name="gibbs"`, which looks
# like a copy/paste leftover -- confirm whether it should read "conjugates".
sampler_finder_db.register(
    "conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
)
52 changes: 28 additions & 24 deletions aemcmc/dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,21 @@ def rng_fn(cls, rng, h, z, size=None):


def multivariate_normal_rue2005(rng, b, Q):
"""
Sample from a multivariate normal distribution of the form N(Qinv * b, Qinv).
r"""Sample from a multivariate normal distribution.

More specifically, this function draws a sample from the following distribution:

.. math::

\operatorname{N}\left( Q^{-1} b, Q^{-1} \right)

We use the algorithm described in [1]. This algorithm is suitable for when
the number of regression coefficients is significantly less than the number
of data points.
It uses the algorithm described in [1]_, which is suitable for cases in
which the number of regression coefficients is significantly less than the
number of data points.

References
----------
..[1] Rue, H. and Held, L. (2005), Gaussian Markov Random Fields, Boca
.. [1] Rue, H. and Held, L. (2005), Gaussian Markov Random Fields, Boca
Raton: Chapman & Hall/CRC.
"""
if _is_sparse_variable(Q):
Expand All @@ -48,6 +53,7 @@ def multivariate_normal_rue2005(rng, b, Q):
w = at.slinalg.solve_triangular(L, b, lower=True)
u = at.slinalg.solve_triangular(L.T, w, lower=False)
z = rng.standard_normal(size=L.shape[0])
z.owner.outputs[0].name = "z_rng"
v = at.slinalg.solve_triangular(L.T, z, lower=False)
return u + v

Expand All @@ -59,44 +65,41 @@ def multivariate_normal_cong2017(
phi: TensorVariable,
t: TensorVariable,
) -> TensorVariable:
r"""
Sample from a multivariate normal distribution with a structured mean and covariance.
r"""Sample from a multivariate normal distribution with a structured mean and covariance.

As described in Example 4 [page 17] of [1], The covariance of this normal
As described in Example 4 [page 17] of [1]_, the covariance of this normal
distribution should be decomposable into a sum of a positive-definite matrix
and a low-rank symmetric matrix such that:

.. math::

\mathcal{N}(\mathbf{\Lambda}^{-1}\mathbf{\Phi}^T\mathbf{\Omega t}, \mathbf{\Lambda}^{-1})
\operatorname{N}\left(\Lambda^{-1} \Phi^{\top} \Omega t, \Lambda^{-1}\right)

where

.. math::

\begin{align*}
\mathbf{\Lambda} = (\mathbf{A}+\mathbf{\Phi}^T\mathbf{\Omega \Phi})
\end{align*}
\Lambda = A + \Phi^{\top} \Omega \Phi

and :math:`\mathbf{A}` is the positive-definite part and
:math:`\mathbf{\Phi}^T\mathbf{\Omega \Phi}` is the eigen-factorization of
and :math:`A` is the positive-definite part and
:math:`\Phi^{\top} \Omega \Phi` is the eigen-factorization of
the low-rank part of the "structured" covariance.

Parameters
----------
rng: TensorVariable
rng
The random number generating object to be used during sampling.
A: TensorVariable
A
The entries of the diagonal elements of the positive-definite part of
the structured covariance.
omega: TensorVariable
omega
The elements of the diagonal matrix in the eigen-decomposition of the
low-rank part of the structured covariancec of the multivariate normal
low-rank part of the structured covariance of the multivariate normal
distribution.
phi: TensorVariable
phi
A matrix containing the eigenvectors of the eigen-decomposition of the
low-rank part of the structured covariance of the normal distribution.
t: TensorVariable
t
A 1D array whose length is the number of eigenvalues of the low-rank
part of the structured covariance.

Expand All @@ -105,7 +108,7 @@ def multivariate_normal_cong2017(
This algorithm is suitable for high-dimensional regression problems and the
runtime scales linearly with the number of regression coefficients. This
implementation assumes that `A` and `omega` are diagonal matrices and
the parameters `A` and ``omega`` are expected to be vectors that contain
the parameters `A` and `omega` are expected to be vectors that contain
diagonal entries of the respective matrices they represent.

Note that the algorithm described in [2]_ is a special case when `omega` is
Expand All @@ -123,16 +126,17 @@ def multivariate_normal_cong2017(

References
----------
..[1] Cong, Yulai; Chen, Bo; Zhou, Mingyuan. Fast Simulation of Hyperplane-
.. [1] Cong, Yulai; Chen, Bo; Zhou, Mingyuan. Fast Simulation of Hyperplane-
Truncated Multivariate Normal Distributions. Bayesian Anal. 12 (2017),
no. 4, 1017--1037. doi:10.1214/17-BA1052.
..[2] Bhattacharya, A., Chakraborty, A., and Mallick, B. K. (2016).
.. [2] Bhattacharya, A., Chakraborty, A., and Mallick, B. K. (2016).
“Fast sampling with Gaussian scale mixture priors in high-dimensional
regression.” Biometrika, 103(4): 985.033
"""
A_inv = 1 / A
a_rows = A.shape[0]
z = rng.standard_normal(size=a_rows + omega.shape[0])
z.owner.outputs[0].name = "z_rng"
y1 = at.sqrt(A_inv) * z[:a_rows]
y2 = (1 / at.sqrt(omega)) * z[a_rows:]
Ainv_phi = A_inv[:, None] * phi.T
Expand Down
Loading