From 7d2fc53e0f15700ca5df2e92f0d08c1fff09001b Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 2 Jul 2023 19:39:29 +0530 Subject: [PATCH] Tests for Meta_info addedand rebased --- pymc/logprob/mixture.py | 45 +++++--------------------------- tests/logprob/test_mixture.py | 2 +- tests/logprob/test_transforms.py | 2 +- 3 files changed, 8 insertions(+), 41 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 5d09ce6a27a..ddfae83d73d 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -338,42 +338,6 @@ def find_measurable_index_mixture(fgraph, node): return [new_mixture_rv] -@node_rewriter([switch]) -def find_measurable_switch_mixture(fgraph, node): - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - old_mixture_rv = node.default_output() - idx, *components = node.inputs - - if rv_map_feature.request_measurable(components) != components: - return None - - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(idx.owner.op) - - mix_op = MixtureRV( - indices_end_idx=2, - out_dtype=old_mixture_rv.dtype, - out_broadcastable=old_mixture_rv.broadcastable, - ndim_supp=ndim_supp, - supp_axes=supp_axes, - measure_type=measure_type, - ) - new_mixture_rv = mix_op.make_node( - *([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1]) - ).default_output() - - if pytensor.config.compute_test_value != "off": - if not hasattr(old_mixture_rv.tag, "test_value"): - compute_test_value(node) - - new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value - - return [new_mixture_rv] - - @_logprob.register(MixtureRV) def logprob_MixtureRV( op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs @@ -455,9 +419,6 @@ class MeasurableSwitchMixture(MeasurableElemwise): valid_scalar_types = (Switch,) -measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch) - - @node_rewriter([switch]) def find_measurable_switch_mixture(fgraph, node): rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) @@ -482,6 +443,12 @@ def find_measurable_switch_mixture(fgraph, node): if rv_map_feature.request_measurable(components) != components: return None + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[0].owner.op) + + measurable_switch_mixture = MeasurableSwitchMixture( + scalar_switch, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) + return [measurable_switch_mixture(switch_cond, *components)] diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 4fd5512014e..45518f6d8fc 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,7 +52,7 @@ as_index_constant, ) -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableVariable, get_measurable_meta_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index dbc9b3c2f30..17278684fdd 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -48,7 +48,7 @@ from pytensor.scan import scan from pymc.distributions.transforms import _default_transform, log, logodds -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measurable_meta_info from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp from pymc.logprob.transforms import ( ArccoshTransform,