Skip to content

Commit

Permalink
Add SubsumingElemwise and associated rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 13, 2022
1 parent e9b99f1 commit d57903a
Show file tree
Hide file tree
Showing 4 changed files with 312 additions and 5 deletions.
4 changes: 2 additions & 2 deletions aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from kanren import eq, lall, run
from unification import var

from aemcmc.opt import sampler_rewrites_db
from aemcmc.opt import sampler_finder_db


def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
Expand Down Expand Up @@ -102,6 +102,6 @@ def local_beta_binomial_posterior(fgraph, node):
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")

sampler_rewrites_db.register(
sampler_finder_db.register(
"conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
)
4 changes: 2 additions & 2 deletions aemcmc/gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
multivariate_normal_rue2005,
polyagamma,
)
from aemcmc.opt import sampler_finder, sampler_rewrites_db
from aemcmc.opt import sampler_finder, sampler_finder_db

gibbs_db = LocalGroupDB(apply_all_opts=True)
gibbs_db.name = "gibbs_db"
Expand Down Expand Up @@ -779,6 +779,6 @@ def bernoulli_horseshoe_step(
return outputs, updates


sampler_rewrites_db.register(
sampler_finder_db.register(
"gibbs_db", in2out(gibbs_db.query("+basic"), name="gibbs"), "basic"
)
200 changes: 199 additions & 1 deletion aemcmc/opt.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import sys
from collections.abc import Mapping
from functools import wraps
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union

from aeppl.opt import PreserveRVMappings
from aesara.compile.builders import OpFromGraph, inline_ofg_expansion
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply, Variable, io_toposort
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import local_optimizer
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.optdb import SequenceDB
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable
from cons.core import _car
from unification.core import _unify

SamplerFunctionReturnType = Optional[
Iterable[Tuple[Variable, Variable, Union[Dict[Variable, Variable]]]]
Expand All @@ -33,6 +39,16 @@
# Database of the rewrites applied when constructing samplers
sampler_rewrites_db = SequenceDB()
sampler_rewrites_db.name = "sampler_rewrites_db"

# Database that collects the individual sampler-finding rewrites
sampler_finder_db = SequenceDB()
sampler_finder_db.name = "sampler_finder_db"

# The sampler finders form a single step of `sampler_rewrites_db`, registered
# at position 0.
# NOTE(review): presumably this places them ahead of every other step in the
# sequence -- confirm against `SequenceDB`'s ordering rules.
sampler_rewrites_db.register("sampler_finders", sampler_finder_db, "basic", position=0)


def construct_ir_fgraph(
obs_rvs_to_values: Dict[Variable, Variable]
Expand Down Expand Up @@ -157,3 +173,185 @@ def sampler_finder(
return sampler_finder

return decorator


class SubsumingElemwise(OpFromGraph, Elemwise):
    r"""An `OpFromGraph` that stands in for an `Elemwise` with `DimShuffle`\ed arguments.

    The `Op` that produced the first inner-graph output is stored as
    `elemwise_op`, and its `scalar_op`, `nfunc_spec`, `inplace_pattern` and
    `destroy_map` attributes are mirrored on the instance.  Instances compare
    equal to, and hash like, that wrapped `Op`.

    This `Op` is not meant to be computed with: `perform` always raises.
    """

    def __init__(self, inputs, outputs, *args, **kwargs):
        """Wrap the graph going from `inputs` to `outputs`.

        Parameters
        ----------
        inputs
            The inputs of the inner graph.
        outputs
            The outputs of the inner graph.  The `Op` that owns the first
            output is the one this instance mimics.
        *args, **kwargs
            Forwarded unchanged to `OpFromGraph.__init__`.
        """
        # TODO: Mock the `Elemwise` interface just enough for our purposes
        wrapped_op = outputs[0].owner.op
        self.elemwise_op = wrapped_op
        for attr_name in ("scalar_op", "nfunc_spec", "inplace_pattern", "destroy_map"):
            setattr(self, attr_name, getattr(wrapped_op, attr_name))

        OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs)

    def make_node(self, *inputs):
        """Build an `Apply` node that excludes the shared-variable inputs."""
        full_node = super().make_node(*inputs)

        # Remove shared variable inputs.
        # N.B. We aren't going to compute with this `Op`, so they're not needed
        n_explicit = len(full_node.inputs) - len(self.shared_inputs)
        explicit_inputs = full_node.inputs[:n_explicit]
        fresh_outputs = [out.clone() for out in full_node.outputs]

        return Apply(self, explicit_inputs, fresh_outputs)

    def perform(self, *args, **kwargs):
        """Always raise: the `Op` should have been expanded before computing."""
        raise NotImplementedError(
            "This `OpFromGraph` should have been in-line expanded."
        )

    def __repr__(self):
        # Rendered as, e.g., ``SubsumingElemwise{<scalar op>}``
        class_name = type(self).__name__
        return f"{class_name}{{{self.scalar_op}}}"

    def __str__(self):
        return repr(self)

    def __eq__(self, other):
        # TODO: How are we going to make this work as expected?
        if other is self:
            return True

        if not isinstance(other, Elemwise):
            return super().__eq__(other)

        # Any `Elemwise` is compared against the wrapped `Op`
        return self.elemwise_op == other

    def __hash__(self):
        # Hash like the wrapped `Op`, in line with `__eq__`
        return hash(self.elemwise_op)


def _unify_SubsumingElemwise(u: Elemwise, v: SubsumingElemwise, s: Mapping):
    """Unify an `Elemwise` with the `Op` wrapped by a `SubsumingElemwise`.

    Parameters
    ----------
    u
        The plain `Elemwise`.
    v
        The `SubsumingElemwise`; only its `elemwise_op` takes part.
    s
        The state handed through to `_unify`.

    Yields
    ------
    The result of calling `_unify` on `u` and ``v.elemwise_op``.
    """
    wrapped_op = v.elemwise_op
    yield _unify(u, wrapped_op, s)


# `Elemwise` against `SubsumingElemwise`: the helper takes them in this order
_unify.add(
    (Elemwise, SubsumingElemwise, Mapping),
    lambda elemwise, subsuming, state: _unify_SubsumingElemwise(
        elemwise, subsuming, state
    ),
)
# The same pair in the opposite order: swap the terms before delegating
_unify.add(
    (SubsumingElemwise, Elemwise, Mapping),
    lambda subsuming, elemwise, state: _unify_SubsumingElemwise(
        elemwise, subsuming, state
    ),
)
# Two `SubsumingElemwise`s: unify the `Op`s they wrap (second term's first)
_unify.add(
    (SubsumingElemwise, SubsumingElemwise, Mapping),
    lambda first, second, state: _unify(
        second.elemwise_op, first.elemwise_op, state
    ),
)


def car_SubsumingElemwise(x):
    """Return the type of the `Op` stored in ``x.elemwise_op``."""
    wrapped_op = x.elemwise_op
    return type(wrapped_op)


# Register `car_SubsumingElemwise` as the `_car` implementation used for
# `SubsumingElemwise` instances.
_car.add((SubsumingElemwise,), car_SubsumingElemwise)


@local_optimizer([Elemwise])
def local_elemwise_dimshuffle_subsume(fgraph, node):
    r"""Fold the `DimShuffle`\s on an `Elemwise`'s inputs into a single `Op`.

    The replacement rule is

    .. math::

        \frac{
            \operatorname{Elemwise}_{o}\left(
                \operatorname{DimShuffle}_{z_i}(x_i), \dots
            \right)
        }{
            \operatorname{OpFromGraph}_{\operatorname{Elemwise}_{o}\left(
                \operatorname{DimShuffle}_{z_i}(y_i), \dots
            \right)}\left(
                x_i, \dots
            \right)
        }

    where :math:`o` is a scalar `Op`, :math:`z_i` are the `DimShuffle` settings
    for the inputs at index :math:`i`.

    N.B. A side condition requiring each :math:`x_i` to be a `RandomVariable`
    is *not* enforced at present.

    Parameters
    ----------
    fgraph
        Not referenced by this rewrite.
    node
        The `Elemwise` node to rewrite.

    Returns
    -------
    The outputs of a new `SubsumingElemwise` node, or ``None`` when none of
    the inputs had a `DimShuffle` that could be subsumed.

    """
    out_ndim = node.outputs[0].type.ndim

    # The arguments given to the re-built `Elemwise`
    elemwise_args = []
    # The inputs of the new `SubsumingElemwise` (and of its inner graph)
    op_inputs = []
    any_subsumed = False

    for inp in node.inputs:
        ds_node = inp.owner

        if not (ds_node and isinstance(ds_node.op, DimShuffle)):
            # Not a `DimShuffle` result, so the input is passed through as-is
            elemwise_args.append(inp)
            op_inputs.append(inp)
            continue

        # TODO FIXME: Only do this when the `DimShuffle`s are adding
        # broadcastable dimensions.  If they're doing more
        # (e.g. transposing), separate the broadcasting from everything
        # else.
        ds_order = ds_node.op.new_order
        ds_arg = ds_node.inputs[0]

        n_added = out_ndim - ds_arg.type.ndim

        # The `DimShuffle`ing added by `Elemwise`
        leading = ds_order[:n_added]
        # The remaining `DimShuffle`ing that was added by something else
        trailing = ds_order[n_added:]

        # Only consider broadcast dimensions added on the left as
        # having come from `Elemwise.make_node`
        if not (leading and all(d == "x" for d in leading)):
            # In this case, the necessary broadcast elements were most
            # likely not added by `Elemwise.make_node` (e.g. broadcasts are
            # interspersed with transposes, or there are none at all), so
            # we don't want to mess with them.
            # TODO: We could still subsume some of these `DimShuffle`s,
            # though
            elemwise_args.append(inp)
            op_inputs.append(inp)
            continue

        any_subsumed = True

        if trailing and trailing != tuple(range(len(trailing))):
            # The remaining `DimShuffle`ing is substantial, so we need to
            # apply it separately: it stays outside of the new `Op`, and
            # only the left broadcasting is re-applied inside of it
            reordered = ds_arg.dimshuffle(trailing)
            rebroadcast = reordered.dimshuffle(
                leading + tuple(range(reordered.type.ndim))
            )

            elemwise_args.append(rebroadcast)
            op_inputs.append(reordered)
        else:
            # The whole `DimShuffle` moves inside of the new `Op`
            elemwise_args.append(inp)
            op_inputs.append(ds_arg)

    if not any_subsumed:
        return None  # pragma: no cover

    assert len(elemwise_args) == len(node.inputs)

    inner_outputs = node.op.make_node(*elemwise_args).outputs
    subsuming_op = SubsumingElemwise(op_inputs, inner_outputs, inline=True)

    new_out = subsuming_op(*op_inputs)

    return new_out.owner.outputs


# Wrap `DimShuffle`d `Elemwise` inputs in `SubsumingElemwise` `Op`s as part of
# the IR rewrites
sampler_ir_db.register(
    "elemwise_dimshuffle_subsume", in2out(local_elemwise_dimshuffle_subsume), "basic",
    position=-10,
)

# This step undoes `elemwise_dimshuffle_subsume`
sampler_rewrites_db.register(
    "inline_ofg_expansion", in2out(inline_ofg_expansion), "basic",
    position=sys.maxsize,
)
109 changes: 109 additions & 0 deletions tests/test_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import aesara.tensor as at
import numpy as np
from aesara.graph.basic import equal_computations
from aesara.tensor.elemwise import DimShuffle, Elemwise
from cons import car, cdr
from etuples import etuple, etuplize
from unification import unify

from aemcmc.opt import SubsumingElemwise, local_elemwise_dimshuffle_subsume


def test_SubsumingElemwise_basics():
    """Check that a `SubsumingElemwise` can pass for the `Elemwise` it wraps."""
    a = at.vector("a")
    b = at.scalar("b")
    x = a * b

    mul_op = x.owner.op
    assert isinstance(mul_op, Elemwise)
    # The scalar operand reaches the `Elemwise` through a `DimShuffle`
    assert isinstance(x.owner.inputs[1].owner.op, DimShuffle)

    ee_mul_op = SubsumingElemwise([a, b], [x])

    # Equality has to hold in both directions
    assert ee_mul_op == mul_op
    assert mul_op == ee_mul_op

    # The `Op`s themselves unify
    assert unify(at.mul, ee_mul_op) is not False

    assert car(ee_mul_op) == car(mul_op)
    assert cdr(ee_mul_op) == cdr(mul_op)

    # ...and so do their `etuple` forms
    assert unify(etuplize(at.mul), etuplize(ee_mul_op)) is not False

    ee_et = etuplize(ee_mul_op(a, b))
    x_et = etuple(etuplize(at.mul), a, b)
    assert unify(ee_et, x_et) is not False

    # TODO: Consider making this possible
    # s = unify(ee_mul_op(a, b), x)
    # assert s is not False


def test_local_elemwise_dimshuffle_subsume_basic():
    """Subsume the broadcasting `DimShuffle` applied to a scalar random variable."""
    srng = at.random.RandomStream(2398)

    a = at.vector("a")
    b = srng.normal(0, 1, name="b")
    x = a * b

    elemwise_node = x.owner
    assert isinstance(elemwise_node.op, Elemwise)
    # `b` enters the `Elemwise` through a `DimShuffle`
    assert isinstance(elemwise_node.inputs[1].owner.op, DimShuffle)

    (res,) = local_elemwise_dimshuffle_subsume.transform(None, elemwise_node)

    new_node = res.owner
    new_op = new_node.op
    assert isinstance(new_op, SubsumingElemwise)
    # The inner graph computes the same thing as the original expression
    assert equal_computations(
        [new_op.inner_outputs[0]], [x], new_op.inner_inputs[:2], [a, b]
    )
    # The un-`DimShuffle`d variables are the new node's inputs
    assert new_node.inputs == [a, b]


def test_local_elemwise_dimshuffle_subsume_transpose():
    """Make sure that `local_elemwise_dimshuffle_subsume` is applied selectively."""
    srng = at.random.RandomStream(2398)

    # Case 1: an explicit transpose on one operand
    a = at.vector("a")
    # This transpose shouldn't be subsumed, but the one applied to `a` by
    # `Elemwise.make_node` should
    b = srng.normal(at.arange(4).reshape((2, 2)), 1, name="b").T
    x = a * b

    elemwise_node = x.owner
    assert isinstance(elemwise_node.op, Elemwise)
    assert isinstance(elemwise_node.inputs[1].owner.op, DimShuffle)

    (res,) = local_elemwise_dimshuffle_subsume.transform(None, elemwise_node)

    new_node = res.owner
    new_op = new_node.op
    assert isinstance(new_op, SubsumingElemwise)
    assert equal_computations(
        [new_op.inner_outputs[0]], [x], new_op.inner_inputs[:2], [a, b]
    )
    assert new_node.inputs == [a, b]

    # Case 2: a single `DimShuffle` that both broadcasts and transposes
    a = at.tensor(np.float64, shape=(None, None, None), name="a")
    # Again, the transpose part shouldn't be subsumed, but the added broadcast
    # dimension should
    b = srng.normal(at.arange(4).reshape((2, 2)), 1, name="b")
    b_broadcast_transposed = b.dimshuffle(("x", 1, 0))
    x = a * b_broadcast_transposed

    elemwise_node = x.owner
    assert isinstance(elemwise_node.op, Elemwise)
    assert isinstance(elemwise_node.inputs[1].owner.op, DimShuffle)

    (res,) = local_elemwise_dimshuffle_subsume.transform(None, elemwise_node)

    new_node = res.owner
    assert isinstance(new_node.op, SubsumingElemwise)
    assert new_node.inputs[0] == a
    # The input corresponding to `b`/`b_ds` should be equivalent to `b.T`
    transposed_input = new_node.inputs[1]
    assert isinstance(transposed_input.owner.op, DimShuffle)
    assert equal_computations([b.T], [transposed_input])

0 comments on commit d57903a

Please sign in to comment.