Skip to content

Commit

Permalink
Tests for Meta_info addedand rebased
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 2, 2023
1 parent af249d1 commit 7d2fc53
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 41 deletions.
45 changes: 6 additions & 39 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,42 +338,6 @@ def find_measurable_index_mixture(fgraph, node):
return [new_mixture_rv]


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

old_mixture_rv = node.default_output()
idx, *components = node.inputs

if rv_map_feature.request_measurable(components) != components:
return None

ndim_supp, supp_axes, measure_type = get_measurable_meta_info(idx.owner.op)

mix_op = MixtureRV(
indices_end_idx=2,
out_dtype=old_mixture_rv.dtype,
out_broadcastable=old_mixture_rv.broadcastable,
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=measure_type,
)
new_mixture_rv = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1])
).default_output()

if pytensor.config.compute_test_value != "off":
if not hasattr(old_mixture_rv.tag, "test_value"):
compute_test_value(node)

new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value

return [new_mixture_rv]


@_logprob.register(MixtureRV)
def logprob_MixtureRV(
op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
Expand Down Expand Up @@ -455,9 +419,6 @@ class MeasurableSwitchMixture(MeasurableElemwise):
valid_scalar_types = (Switch,)


measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
Expand All @@ -482,6 +443,12 @@ def find_measurable_switch_mixture(fgraph, node):
if rv_map_feature.request_measurable(components) != components:
return None

ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[0].owner.op)

measurable_switch_mixture = MeasurableSwitchMixture(
scalar_switch, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)

return [measurable_switch_mixture(switch_cond, *components)]


Expand Down
2 changes: 1 addition & 1 deletion tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
as_index_constant,
)

from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.abstract import MeasurableVariable, get_measurable_meta_info
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices
from pymc.logprob.rewriting import construct_ir_fgraph
Expand Down
2 changes: 1 addition & 1 deletion tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from pytensor.scan import scan

from pymc.distributions.transforms import _default_transform, log, logodds
from pymc.logprob.abstract import MeasurableVariable, _logprob
from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measurable_meta_info
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
from pymc.logprob.transforms import (
ArccoshTransform,
Expand Down

0 comments on commit 7d2fc53

Please # to comment.