Incorporate RV meta information in intermediate MeasurableVariables #6360

Open
ricardoV94 opened this issue Nov 30, 2022 · 3 comments

@ricardoV94
Member

Description

This pertains to the logprob submodule. During logprob derivation of an expression like

import numpy as np
import pymc as pm

x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1)  # Censored normal

pm.logp(x, np.zeros((2, 5)))

We create a MeasurableClip that replaces x once we identify that the logprob can be derived as a simple censored pdf. This MeasurableClip, however, does not retain any of the meta-information about the type of RV that it encapsulates (ndim_supp, dtype, support axis).

class MeasurableClip(MeasurableElemwise):
    """A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""

    valid_scalar_types = (Clip,)


measurable_clip = MeasurableClip(scalar_clip)
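For reference, the "simple censored pdf" mentioned above can be sketched in plain Python. This is a minimal illustration of the textbook censored-normal log-density for a single scalar value, not PyMC's actual implementation; the function names are hypothetical and only the standard library is assumed.

```python
import math


def normal_logpdf(value, mu, sigma):
    # Log-density of Normal(mu, sigma) at `value`.
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def normal_logcdf(value, mu, sigma):
    # log P(X <= value), written via the complementary error function.
    return math.log(0.5 * math.erfc(-(value - mu) / (sigma * math.sqrt(2))))


def normal_logsf(value, mu, sigma):
    # log P(X > value), the log survival function.
    return math.log(0.5 * math.erfc((value - mu) / (sigma * math.sqrt(2))))


def censored_normal_logp(value, mu, sigma, lower, upper):
    # Hypothetical helper: logp of clip(Normal(mu, sigma), lower, upper).
    if value < lower or value > upper:
        return -math.inf  # clipped values can never fall outside the bounds
    if value == lower:
        return normal_logcdf(lower, mu, sigma)  # point mass at the lower bound
    if value == upper:
        return normal_logsf(upper, mu, sigma)  # point mass at the upper bound
    return normal_logpdf(value, mu, sigma)  # unchanged density in the interior
```

The point masses at the bounds are why the clipped variable is neither purely continuous nor purely discrete.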

If we wanted to further compose the graph, we would run into issues whenever some operation needs to know this meta-information:

x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1)  # Censored normal
x = x.T

pm.logp(x, np.zeros((5, 2)))  # NotImplementedError: PyMC could not infer logp of input variable.

This happens because to infer the logprob of a transposed (dimshuffled) variable, we need to know the original support dimensionality and support axis (which is always the rightmost for pure distributions):

pymc/pymc/logprob/tensor.py

Lines 285 to 298 in a0d6ba0

# We can only apply this rewrite directly to `RandomVariable`s, as those are
# the only `Op`s for which we always know the support axis. Other measurable
# variables can have arbitrary support axes (e.g., if they contain separate
# `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s
# should still be supported as long as the `DimShuffle`s can be merged/
# lifted towards the base RandomVariable.
# TODO: If we include the support axis as meta information in each
# intermediate MeasurableVariable, we can lift this restriction.
if not (
    base_var.owner
    and isinstance(base_var.owner.op, RandomVariable)
    and base_var not in rv_map_feature.rv_values
):
    return None  # pragma: no cover

If we propagated that information to the MeasurableClip (ndim_supp=0, support_axis=None, dtype="mixed"), the DimShuffle rewrite could be safely used and we could derive the logp for the second example. This is also useful for other rewrites...
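A rough sketch of what carrying this meta-information through a DimShuffle might look like is given here. `MeasurableMeta` and `dimshuffle_meta` are hypothetical names invented for illustration, not part of PyMC; the sketch assumes support axes are stored as non-negative axis indices and that an empty tuple stands in for `support_axis=None`.

```python
from dataclasses import dataclass, replace
from typing import Tuple, Union


@dataclass(frozen=True)
class MeasurableMeta:
    # Hypothetical container for the meta-information discussed above.
    ndim_supp: int               # number of support dimensions
    supp_axes: Tuple[int, ...]   # non-negative indices of the support axes
    measure_type: str            # "discrete", "continuous", or "mixed"


def dimshuffle_meta(meta: MeasurableMeta, new_order: Tuple[Union[int, str], ...]) -> MeasurableMeta:
    # `new_order` follows the DimShuffle convention: each entry is either an
    # input axis index or "x" for a newly inserted broadcastable axis.
    new_axes = []
    for axis in meta.supp_axes:
        if axis not in new_order:
            raise ValueError(f"support axis {axis} would be dropped")
        # The support axis moves to wherever it appears in the output order.
        new_axes.append(new_order.index(axis))
    return replace(meta, supp_axes=tuple(new_axes))


# Clipped univariate normal: no support axes, so a transpose changes nothing.
clip_meta = MeasurableMeta(ndim_supp=0, supp_axes=(), measure_type="mixed")
transposed_clip_meta = dimshuffle_meta(clip_meta, (1, 0))

# A vector-valued variable of shape (2, 5) with support on the last axis:
# after a transpose the support axis becomes axis 0.
mv_meta = MeasurableMeta(ndim_supp=1, supp_axes=(1,), measure_type="continuous")
transposed_mv_meta = dimshuffle_meta(mv_meta, (1, 0))
```

With this kind of record available on every intermediate measurable variable, a rewrite would no longer need to see a pure `RandomVariable` to know where the support axes are.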

More context in aesara-devs/aeppl#183

@Dhruvanshu-Joshi
Member

Hey @ricardoV94. I would like to work on this issue. From what I understand, we will need to pass meta-information to all Measurable ops, like MeasurableClip, MeasurableRound, MeasurableDimShuffle, etc., using an init function. If you think my understanding and approach are in the right direction, I'll proceed with implementing this.

@ricardoV94
Member Author

ricardoV94 commented Apr 20, 2023

Yes. And the meta-info we want (at least for now) is: ndim_supp, supp_axes, and something new like a type, which can be discrete, continuous, or mixed.
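One possible way to model the proposed type is shown here, purely as an illustration. `MeasureType` and the two helper functions are hypothetical names, not existing PyMC API; the rules encoded are only that clipping a continuous variable adds point masses at the bounds, and that rounding always yields integer-valued outcomes.

```python
from enum import Enum


class MeasureType(Enum):
    # Hypothetical enumeration of the three kinds mentioned above.
    DISCRETE = "discrete"
    CONTINUOUS = "continuous"
    MIXED = "mixed"


def measure_type_after_clip(base: MeasureType) -> MeasureType:
    # Clipping a continuous variable creates point masses at the bounds, so
    # the result is mixed. Discrete and mixed inputs keep their type.
    if base is MeasureType.CONTINUOUS:
        return MeasureType.MIXED
    return base


def measure_type_after_round(base: MeasureType) -> MeasureType:
    # Rounding maps every outcome to an integer, so the result is discrete
    # regardless of the input type.
    return MeasureType.DISCRETE
```

An init function on each Measurable op could then compute its own type from the type of its input.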

@ricardoV94
Member Author

ricardoV94 commented Jul 5, 2023

We will probably want to further specify the measure_type for multivariate distributions, to distinguish what TensorFlow Probability calls FullSpace (for which #6797 is correct) from what is not: https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/tangent_spaces

In general, supporting the change of "constrained" multivariate variables (like Dirichlet or ZeroSumNormal) would require something like https://proceedings.mlr.press/v130/radul21a.html, which is what tfp does.

More context in the discussion surrounding #6808
