-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add support for symbolic initval using a singledispatch approach #4912
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,13 +19,15 @@ | |
import warnings | ||
|
||
from abc import ABCMeta | ||
from functools import singledispatch | ||
from typing import Optional | ||
|
||
import aesara | ||
import aesara.tensor as at | ||
|
||
from aesara.tensor.random.op import RandomVariable | ||
from aesara.tensor.random.var import RandomStateSharedVariable | ||
from aesara.tensor.var import TensorVariable | ||
|
||
from pymc3.aesaraf import change_rv_size | ||
from pymc3.distributions import _logcdf, _logp | ||
|
@@ -107,6 +109,13 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs): | |
value_var = rvs_to_values.get(var, var) | ||
return class_logcdf(value_var, *dist_params, **kwargs) | ||
|
||
class_initval = clsdict.get("get_moment") | ||
if class_initval: | ||
|
||
@_get_moment.register(rv_type) | ||
def get_moment(op, rv, size, *rv_inputs): | ||
return class_initval(rv, size, *rv_inputs) | ||
|
||
# Register the Aesara `RandomVariable` type as a subclass of this | ||
# `Distribution` type. | ||
new_cls.register(rv_type) | ||
|
@@ -328,6 +337,24 @@ def dist( | |
return rv_out | ||
|
||
|
||
@singledispatch | ||
def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable: | ||
"""Fallback method for creating an initial value for a random variable. | ||
|
||
Parameters are the same as for the `.dist()` method. | ||
""" | ||
return None | ||
|
||
|
||
def get_moment(rv: TensorVariable) -> TensorVariable: | ||
"""Fallback method for creating an initial value for a random variable. | ||
|
||
Parameters are the same as for the `.dist()` method. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this was merged already, but this part of the docstrings is wrong There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, we should fix that then. CC @kc611 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes I missed that, those docstrings were supposed to be removed. I'm not sure what (docstring) will go in it's place though. Maybe I should just remove them for now ? We can add a proper explanation when we give the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, then just remove them for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did it in #4979 |
||
""" | ||
size = rv.owner.inputs[1] | ||
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:]) | ||
|
||
|
||
class NoDistribution(Distribution): | ||
def __init__( | ||
self, | ||
|
Uh oh!
There was an error while loading. Please reload this page.