Allow CustomDist with inferred logp in Mixture #6746

Merged: 7 commits, Jun 5, 2023
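In short, a `CustomDist` whose logp is inferred from a generative `dist` function can now be used as a `Mixture` component. A minimal, hedged sketch of the intended usage (the `shifted_exponential` helper, parameter values, and observed data below are made up for illustration and are not taken from this PR):

```python
import pymc as pm

def shifted_exponential(lam, shift, size=None):
    # Generative graph only; PyMC infers the logp from it
    return pm.Exponential.dist(lam, size=size) + shift

with pm.Model():
    w = pm.Dirichlet("w", a=[1.0, 1.0])
    # CustomDist with an inferred logp (no explicit `logp` argument)
    comp1 = pm.CustomDist.dist(2.0, 1.0, dist=shifted_exponential)
    comp2 = pm.Normal.dist(mu=5.0, sigma=1.0)
    # After this PR, such a component can participate in a Mixture
    mix = pm.Mixture("mix", w=w, comp_dists=[comp1, comp2], observed=[1.2, 4.8, 5.5])
```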
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -105,6 +105,7 @@ jobs:
tests/logprob/test_abstract.py
tests/logprob/test_basic.py
tests/logprob/test_binary.py
tests/logprob/test_checks.py
tests/logprob/test_censoring.py
tests/logprob/test_composite_logprob.py
tests/logprob/test_cumsum.py
6 changes: 0 additions & 6 deletions pymc/data.py
@@ -36,7 +36,6 @@

import pymc as pm

from pymc.logprob.abstract import _get_measurable_outputs
from pymc.pytensorf import convert_observed_data

__all__ = [
@@ -135,11 +134,6 @@ def make_node(self, rng, *args, **kwargs):
return super().make_node(rng, *args, **kwargs)


@_get_measurable_outputs.register(MinibatchIndexRV)
def minibatch_index_rv_measuarable_outputs(op, node):
return []


minibatch_index = MinibatchIndexRV()


3 changes: 0 additions & 3 deletions pymc/distributions/bound.py
@@ -26,7 +26,6 @@
from pymc.distributions.shape_utils import to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.logprob.basic import logp
from pymc.logprob.utils import ignore_logprob
from pymc.model import modelcontext
from pymc.pytensorf import floatX, intX
from pymc.util import check_dist_not_registered
@@ -202,7 +201,6 @@ def __new__(
raise ValueError("Given dims do not exist in model coordinates.")

lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
dist = ignore_logprob(dist)

if isinstance(dist.owner.op, Continuous):
res = _ContinuousBounded(
@@ -236,7 +234,6 @@ def dist(
):
cls._argument_checks(dist, **kwargs)
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
dist = ignore_logprob(dist)
if isinstance(dist.owner.op, Continuous):
res = _ContinuousBounded.dist(
[dist, lower, upper],
23 changes: 1 addition & 22 deletions pymc/distributions/distribution.py
@@ -32,7 +32,6 @@
from pytensor.graph.utils import MetaType
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.var import TensorVariable
from typing_extensions import TypeAlias
@@ -49,13 +48,7 @@
shape_from_dims,
)
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.abstract import (
MeasurableVariable,
_get_measurable_outputs,
_icdf,
_logcdf,
_logprob,
)
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
@@ -401,20 +394,6 @@ def dist(
MeasurableVariable.register(SymbolicRandomVariable)


@_get_measurable_outputs.register(SymbolicRandomVariable)
def _get_measurable_outputs_symbolic_random_variable(op, node):
# This tells PyMC that any non RandomType outputs are measurable

# Assume that if there is one default_output, that's the only one that is measurable
# In the rare case this is not what one wants, a specialized _get_measuarable_outputs
# can dispatch for a subclassed Op
if op.default_output is not None:
return [node.default_output()]

# Otherwise assume that any outputs that are not of RandomType are measurable
return [out for out in node.outputs if not isinstance(out.type, RandomType)]


@node_rewriter([SymbolicRandomVariable])
def inline_symbolic_random_variable(fgraph, node):
"""
12 changes: 4 additions & 8 deletions pymc/distributions/mixture.py
@@ -35,9 +35,9 @@
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
from pymc.distributions.transforms import _default_transform
from pymc.distributions.truncated import Truncated
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob, _logprob_helper
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.transforms import IntervalTransform
from pymc.logprob.utils import ignore_logprob
from pymc.pytensorf import floatX
from pymc.util import check_dist_not_registered
from pymc.vartypes import continuous_types, discrete_types
@@ -267,10 +267,6 @@ def rv_op(cls, weights, *components, size=None):

assert weights_ndim_batch == 0

# Component RVs terms are accounted by the Mixture logprob, so they can be
# safely ignored in the logprob graph
components = [ignore_logprob(component) for component in components]

# Create a OpFromGraph that encapsulates the random generating process
# Create dummy input variables with the same type as the ones provided
weights_ = weights.type()
@@ -350,10 +346,10 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
if len(components) == 1:
# Need to broadcast value across mixture axis
mix_axis = -components[0].owner.op.ndim_supp - 1
components_logp = _logprob_helper(components[0], pt.expand_dims(value, mix_axis))
components_logp = logp(components[0], pt.expand_dims(value, mix_axis))
else:
components_logp = pt.stack(
[_logprob_helper(component, value) for component in components],
[logp(component, value) for component in components],
axis=-1,
)

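For intuition, the branch above computes each component's logp at the (broadcast) value and stacks them along a component axis; the marginal mixture logp is then a weighted log-sum-exp over that axis. Switching from `_logprob_helper` to `logp` lets those per-component terms come from inferred logp graphs as well. A rough standalone sketch of the math in NumPy/SciPy, not the actual PyMC internals (`mixture_logp` is a hypothetical helper):

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp

def mixture_logp(value, weights, component_logp_fns):
    # One logp callable per component, all evaluated at the same value
    comp_logps = np.stack([f(value) for f in component_logp_fns], axis=-1)
    # Weighted log-sum-exp over the component axis
    return logsumexp(np.log(weights) + comp_logps, axis=-1)

# Example: a 50/50 mixture of N(0, 1) and N(5, 1), evaluated at 0.0
print(mixture_logp(0.0, [0.5, 0.5], [stats.norm(0, 1).logpdf, stats.norm(5, 1).logpdf]))
```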
6 changes: 1 addition & 5 deletions pymc/distributions/multivariate.py
@@ -67,7 +67,6 @@
)
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
from pymc.logprob.abstract import _logprob
from pymc.logprob.utils import ignore_logprob
from pymc.math import kron_diag, kron_dot
from pymc.pytensorf import floatX, intX
from pymc.util import check_dist_not_registered
@@ -1191,9 +1190,6 @@ def dist(cls, n, eta, sd_dist, **kwargs):
raise TypeError("sd_dist must be a scalar or vector distribution variable")

check_dist_not_registered(sd_dist)
# sd_dist is part of the generative graph, but should be completely ignored
# by the logp graph, since the LKJ logp explicitly includes these terms.
sd_dist = ignore_logprob(sd_dist)
return super().dist([n, eta, sd_dist], **kwargs)

@classmethod
@@ -2527,7 +2523,7 @@ def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
@classmethod
def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
shape = to_tuple(size) + tuple(support_shape)
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
normal_dist = pm.Normal.dist(sigma=sigma, shape=shape)

if n_zerosum_axes > normal_dist.ndim:
raise ValueError("Shape of distribution is too small for the number of zerosum axes")
19 changes: 2 additions & 17 deletions pymc/distributions/timeseries.py
@@ -43,7 +43,6 @@
from pymc.exceptions import NotConstantValueError
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import logp
from pymc.logprob.utils import ignore_logprob, reconsider_logprob
from pymc.pytensorf import constant_fold, floatX, intX
from pymc.util import check_dist_not_registered

@@ -111,11 +110,6 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> pt.TensorVariable:
if init_dist in ancestors([innovation_dist]) or innovation_dist in ancestors([init_dist]):
raise ValueError("init_dist and innovation_dist must be completely independent")

# PyMC should not be concerned that these variables don't have values, as they will be
# accounted for in the logp of RandomWalk
init_dist = ignore_logprob(init_dist)
innovation_dist = ignore_logprob(innovation_dist)

steps = cls.get_steps(
innovation_dist=innovation_dist,
steps=steps,
@@ -235,14 +229,12 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):


@_logprob.register(RandomWalkRV)
def random_walk_logp(op, values, init_dist, innovation_dist, steps, **kwargs):
def random_walk_logp(op, values, *inputs, **kwargs):
# Although we can derive the logprob of random walks, it does not collapse
# what we consider the core dimension of steps. We do it manually here.
(value,) = values
# Recreate RV and obtain inner graph
rv_node = op.make_node(
reconsider_logprob(init_dist), reconsider_logprob(innovation_dist), steps
)
rv_node = op.make_node(*inputs)
rv = clone_replace(
op.inner_outputs, replace={u: v for u, v in zip(op.inner_inputs, rv_node.inputs)}
)[op.default_output]
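As the comment above notes, the random-walk logp has to collapse the steps ("core") dimension manually. For intuition only: a Gaussian random walk's logp is the init-dist logp of the first entry plus the innovation logps of the increments, summed over steps. A hedged standalone sketch (hypothetical helper, not the PyMC implementation):

```python
import numpy as np
from scipy import stats

def gaussian_random_walk_logp(value, init_mu=0.0, init_sigma=1.0, sigma=1.0):
    # logp of the starting point under the init distribution
    init_logp = stats.norm(init_mu, init_sigma).logpdf(value[..., 0])
    # logp of the increments under the innovation distribution, summed over steps
    innov_logp = stats.norm(0.0, sigma).logpdf(np.diff(value, axis=-1)).sum(axis=-1)
    return init_logp + innov_logp

print(gaussian_random_walk_logp(np.array([0.1, 0.3, 0.2, 0.6])))
```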
@@ -571,9 +563,6 @@ def dist(
)
init_dist = Normal.dist(0, 100, shape=(*sigma.shape, ar_order))

# We can ignore init_dist, as it will be accounted for in the logp term
init_dist = ignore_logprob(init_dist)

return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs)

@classmethod
@@ -789,8 +778,6 @@ def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs):
initial_vol = pt.as_tensor_variable(initial_vol)

init_dist = Normal.dist(0, initial_vol)
# We can ignore init_dist, as it will be accounted for in the logp term
init_dist = ignore_logprob(init_dist)

return super().dist([omega, alpha_1, beta_1, initial_vol, init_dist, steps], **kwargs)

@@ -973,8 +960,6 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
UserWarning,
)
init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape)
# We can ignore init_dist, as it will be accounted for in the logp term
init_dist = ignore_logprob(init_dist)

return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)

118 changes: 10 additions & 108 deletions pymc/logprob/abstract.py
@@ -36,11 +36,9 @@

import abc

from copy import copy
from functools import singledispatch
from typing import Callable, List, Sequence, Tuple
from typing import Sequence, Tuple

from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
@@ -69,10 +67,15 @@ def _logprob_helper(rv, *values, **kwargs):
"""Helper that calls `_logprob` dispatcher."""
logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)

for rv in values:
if rv.name:
logprob.name = f"{rv.name}_logprob"
break
name = rv.name
if (not name) and (len(values) == 1):
name = values[0].name
if name:
if isinstance(logprob, (list, tuple)):
for i, term in enumerate(logprob):
term.name = f"{name}_logprob.{i}"
else:
logprob.name = f"{name}_logprob"

return logprob
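The renaming logic above amounts to: prefer the RV's own name, fall back to the single value variable's name, and index the suffix when the logprob is a list. A tiny illustrative sketch of just the name-picking rule (hypothetical helper, not part of the PR):

```python
# Illustrative only: mirrors the fallback used above for naming logprob terms.
def pick_logprob_name(rv_name, value_names):
    name = rv_name
    if not name and len(value_names) == 1:
        name = value_names[0]
    return f"{name}_logprob" if name else None

assert pick_logprob_name("x", ["x_value"]) == "x_logprob"  # RV name wins
assert pick_logprob_name(None, ["obs"]) == "obs_logprob"   # fall back to the single value
assert pick_logprob_name(None, ["a", "b"]) is None         # ambiguous: leave unnamed
```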

@@ -135,107 +138,6 @@ class MeasurableVariable(abc.ABC):
MeasurableVariable.register(RandomVariable)


class UnmeasurableMeta(MetaType):
def __new__(cls, name, bases, dict):
if "id_obj" not in dict:
dict["id_obj"] = None

return super().__new__(cls, name, bases, dict)

def __eq__(self, other):
if isinstance(other, UnmeasurableMeta):
return hash(self.id_obj) == hash(other.id_obj)
return False

def __hash__(self):
return hash(self.id_obj)


class UnmeasurableVariable(metaclass=UnmeasurableMeta):
"""
id_obj is an attribute, i.e. tuple of length two, of the unmeasurable class object.
e.g. id_obj = (NormalRV, noop_measurable_outputs_fn)
"""


def get_measurable_outputs(op: Op, node: Apply) -> List[Variable]:
"""Return only the outputs that are measurable."""
if isinstance(op, MeasurableVariable):
return _get_measurable_outputs(op, node)
else:
return []


@singledispatch
def _get_measurable_outputs(op, node):
return node.outputs


@_get_measurable_outputs.register(RandomVariable)
def _get_measurable_outputs_RandomVariable(op, node):
return node.outputs[1:]


def noop_measurable_outputs_fn(*args, **kwargs):
return []


def assign_custom_measurable_outputs(
node: Apply,
measurable_outputs_fn: Callable = noop_measurable_outputs_fn,
type_prefix: str = "Unmeasurable",
) -> Apply:
"""Assign a custom ``_get_measurable_outputs`` dispatch function to a measurable variable instance.

The node is cloned and a custom `Op` that's a copy of the original node's
`Op` is created. That custom `Op` replaces the old `Op` in the cloned
node, and then a custom dispatch implementation is created for the clone
`Op` in `_get_measurable_outputs`.

If `measurable_outputs_fn` isn't specified, a no-op is used; the result is
a clone of `node` that will effectively be ignored by
`factorized_joint_logprob`.

Parameters
----------
node
The node to recreate with a new cloned `Op`.
measurable_outputs_fn
The function that will be assigned to the new cloned `Op` in the
`_get_measurable_outputs` dispatcher.
The default is a no-op function (i.e. no measurable outputs)
type_prefix
The prefix used for the new type's name.
The default is ``"Unmeasurable"``, which matches the default
``"measurable_outputs_fn"``.
"""

new_node = node.clone()
op_type = type(new_node.op)

if op_type in _get_measurable_outputs.registry.keys() and isinstance(op_type, UnmeasurableMeta):
if _get_measurable_outputs.registry[op_type] != measurable_outputs_fn:
raise ValueError(
f"The type {op_type.__name__} with hash value {hash(op_type)} "
"has already been dispatched a measurable outputs function."
)
return node

new_op_dict = op_type.__dict__.copy()
new_op_dict["id_obj"] = (new_node.op, measurable_outputs_fn)
new_op_dict.setdefault("original_op_type", op_type)

new_op_type = type(
f"{type_prefix}{op_type.__name__}", (op_type, UnmeasurableVariable), new_op_dict
)
new_node.op = copy(new_node.op)
new_node.op.__class__ = new_op_type

_get_measurable_outputs.register(new_op_type)(measurable_outputs_fn)

return new_node


class MeasurableElemwise(Elemwise):
"""Base class for Measurable Elemwise variables"""
