Skip to content

Commit

Permalink
Logprob tests passed
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jun 10, 2023
1 parent 6a9db0e commit 0ef5aa6
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def __init__(self, scalar_op, *args, **kwargs):
super().__init__(scalar_op, *args, **kwargs)


def get_measurable_meta_info(base_op: Op) -> Tuple[int, Tuple[int], MeasureType]:
def get_measurable_meta_info(base_op: MeasurableVariable) -> Tuple[int, Tuple[int], MeasureType]:
if not isinstance(base_op, MeasurableVariable):
raise TypeError("base_op must be a RandomVariable or MeasurableVariable")

Expand Down
16 changes: 13 additions & 3 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def find_measurable_comparisons(
ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var.owner.op)

# ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(base_var.owner.op)
compared_op = MeasurableComparison(node_scalar_op, ndim_supp, supp_axes, measure_type)
compared_op = MeasurableComparison(
scalar_op=node_scalar_op,
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=measure_type,
)
compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
compared_rv.name = node.outputs[0].name
return [compared_rv]
Expand Down Expand Up @@ -159,7 +164,7 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[List[

base_var = node.inputs[0]
if isinstance(base_var.owner.op, MeasurableVariable):
ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(base_var.owner.op, base_var)
ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var.owner.op)

if not (
base_var.owner
Expand All @@ -177,7 +182,12 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[List[
node_scalar_op = node.op.scalar_op

# ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(base_var.owner.op)
bitwise_op = MeasurableBitwise(node_scalar_op, ndim_supp, supp_axis, d_type)
bitwise_op = MeasurableBitwise(
scalar_op=node_scalar_op,
ndim_supp=ndim_supp,
supp_axes=supp_axis,
measure_type=measure_type,
)
bitwise_rv = bitwise_op.make_node(unmeasurable_base_var).default_output()
bitwise_rv.name = node.outputs[0].name
return [bitwise_rv]
Expand Down
8 changes: 3 additions & 5 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[Me
# Make base_var unmeasurable
ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var.owner.op)
measurable_clip = MeasurableClip(
scalar_clip, ndim_supp=ndim_supp, support_axis=supp_axes, d_type=measure_type
scalar_clip, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)
unmeasurable_base_var = ignore_logprob(base_var)
clipped_rv_node = measurable_clip.make_node(unmeasurable_base_var, lower_bound, upper_bound)
Expand Down Expand Up @@ -201,11 +201,9 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[Lis
# Make base_var unmeasurable
unmeasurable_base_var = ignore_logprob(base_var)

ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(
base_var.owner.op, base_var.owner.op
)
ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_var.owner.op)
rounded_op = MeasurableRound(
node.op.scalar_op, ndim_supp=ndim_supp, support_axis=supp_axis, d_type=d_type
node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=d_type
)
rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output()
rounded_rv.name = rounded_var.name
Expand Down
14 changes: 10 additions & 4 deletions pymc/logprob/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[List[MeasurableSpec
):
return None # pragma: no cover

ndim_supp, supp_axes, measuer_type = get_measurable_meta_info(base_rv.owner.op)
new_op = MeasurableSpecifyShape(ndim_supp, supp_axes, measuer_type)
ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv.owner.op)
new_op = MeasurableSpecifyShape(
ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)
# Make base_var unmeasurable
unmeasurable_base_rv = ignore_logprob(base_rv)
new_rv = new_op.make_node(unmeasurable_base_rv, *shape).default_output()
Expand Down Expand Up @@ -141,9 +143,13 @@ def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRai
return None # pragma: no cover

op = node.op
ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(base_rv.owner.op, base_rv)
ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_rv.owner.op)
new_op = MeasurableCheckAndRaise(
exc_type=op.exc_type, msg=op.msg, ndim_supp=ndim_supp, support_axis=supp_axis, d_type=d_type
exc_type=op.exc_type,
msg=op.msg,
ndim_supp=ndim_supp,
supp_axes=supp_axis,
measure_type=d_type,
)
# Make base_var unmeasurable
unmeasurable_base_rv = ignore_logprob(base_rv)
Expand Down
4 changes: 2 additions & 2 deletions pymc/logprob/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv.owner.op)
new_op = MeasurableCumsum(
ndim_supp=ndim_supp,
support_axis=supp_axes,
d_type=measure_type,
supp_axes=supp_axes,
measure_type=measure_type,
axis=node.op.axis or 0,
mode="add",
)
Expand Down
33 changes: 24 additions & 9 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class MixtureRV(MeasurableVariable, Op):
__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")

def __init__(self, *args, indices_end_idx, out_dtype, out_broadcastable, **kwargs):
super().__init__()
# super().__init__(*args, **kwargs)
self.indices_end_idx = indices_end_idx
self.out_dtype = out_dtype
self.out_broadcastable = out_broadcastable
Expand Down Expand Up @@ -331,11 +331,16 @@ def mixture_replace(fgraph, node):
new_comp_rv = ignore_logprob(component_rv)
new_mixture_rvs.append(new_comp_rv)

ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0].owner.op)

# Replace this sub-graph with a `MixtureRV`
mix_op = MixtureRV(
1 + len(mixing_indices),
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=measure_type,
indices_end_idx=1 + len(mixing_indices),
out_dtype=old_mixture_rv.dtype,
out_broadcastable=old_mixture_rv.broadcastable,
)
new_node = mix_op.make_node(*([join_axis] + mixing_indices + new_mixture_rvs))

Expand Down Expand Up @@ -380,10 +385,15 @@ def switch_mixture_replace(fgraph, node):
new_comp_rv = ignore_logprob(component_rv)
mixture_rvs.append(new_comp_rv)

ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0].owner.op)

mix_op = MixtureRV(
2,
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
indices_end_idx=2,
out_dtype=old_mixture_rv.dtype,
out_broadcastable=old_mixture_rv.broadcastable,
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=measure_type,
)
new_node = mix_op.make_node(*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs))

Expand Down Expand Up @@ -519,10 +529,15 @@ def find_measurable_ifelse_mixture(fgraph, node):
unmeasurable_base_rvs = ignore_logprob_multiple_vars(base_rvs, rv_map_feature.rv_values)

# TODO: Assert all base_rvs have the same meta-info types!
ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_rvs[0].owner.op)
ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rvs[0].owner.op)

return (
MeasurableIfElse(ndim_supp, supp_axis, d_type, n_outs=node.op.n_outs)
MeasurableIfElse(
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=measure_type,
n_outs=node.op.n_outs,
)
.make_node(if_var, *unmeasurable_base_rvs)
.outputs
)
Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def find_measurable_scans(fgraph, node):
curr_scanargs.inner_outputs,
curr_scanargs.info,
ndim_supp=all_ndim_supp,
support_axis=all_supp_axes,
supp_axes=all_supp_axes,
measure_type=all_measure_type,
mode=node.op.mode,
)
Expand Down
18 changes: 9 additions & 9 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,13 @@ def find_measurable_stacks(
ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_vars[0].owner.op)

if is_join:
measurable_stack = MeasurableJoin(axis, ndim_supp, supp_axes, measure_type)(
axis, *unmeasurable_base_vars
)
measurable_stack = MeasurableJoin(
axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)(axis, *unmeasurable_base_vars)
else:
measurable_stack = MeasurableMakeVector(node.op.dtype, ndim_supp, supp_axes, measure_type)(
*unmeasurable_base_vars
)
measurable_stack = MeasurableMakeVector(
node.op.dtype, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)(*unmeasurable_base_vars)

measurable_stack.name = stack_out.name

Expand Down Expand Up @@ -317,14 +317,14 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
# Make base_vars unmeasurable
base_var = ignore_logprob(base_var)

ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(base_var.owner.op, base_var)
ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_var.owner.op)

measurable_dimshuffle = MeasurableDimShuffle(
node.op.input_broadcastable,
node.op.new_order,
ndim_supp=ndim_supp,
support_axis=supp_axis,
d_type=d_type,
supp_axes=supp_axis,
measure_type=d_type,
)(base_var)
measurable_dimshuffle.name = node.outputs[0].name

Expand Down
4 changes: 2 additions & 2 deletions tests/logprob/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ def test_measurable_elemwise():
d_type = "mixed"

with pytest.raises(TypeError, match=re.escape("scalar_op exp is not valid")):
MeasurableElemwise(exp, ndim_supp=ndim_supp, support_axis=support_axis, d_type=d_type)
MeasurableElemwise(exp, ndim_supp=ndim_supp, supp_axes=support_axis, measure_type=d_type)

class TestMeasurableElemwise(MeasurableElemwise):
valid_scalar_types = (Exp,)

measurable_exp_op = TestMeasurableElemwise(
ndim_supp=ndim_supp, support_axis=support_axis, d_type=d_type, scalar_op=exp
ndim_supp=ndim_supp, supp_axes=support_axis, measure_type=d_type, scalar_op=exp
)
measurable_exp = measurable_exp_op(0.0)
assert isinstance(measurable_exp.owner.op, MeasurableVariable)
Expand Down

0 comments on commit 0ef5aa6

Please # to comment.