-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Adding meta-information for MeasurableOps #6754
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6754 +/- ##
==========================================
+ Coverage 92.02% 92.03% +0.01%
==========================================
Files 95 95
Lines 16261 16302 +41
==========================================
+ Hits 14964 15004 +40
- Misses 1297 1298 +1
|
Though all the individual tests in tests/logprob run successfully (except the broken def get_default_measurable_metainfo(base_op: MeasurableVariable, base_dtype) -> Tuple[Any, Union[Tuple[Any, ...], Any], Union[MeasureType, Any]]:
if not isinstance(base_op, MeasurableVariable):
raise TypeError("base_op must be a RandomVariable or MeasurableVariable")
ndim_supp = base_op.ndim_supp
supp_axes = getattr(base_op, "supp_axes", None)
if supp_axes is None:
supp_axes = tuple(range(-base_op.ndim_supp, 0))
measure_type = getattr(base_op, "measure_type", None)
if measure_type is None:
measure_type = (
MeasureType.Discrete if base_dtype.dtype.startswith("int") else MeasureType.Continuous
)
return ndim_supp, supp_axes, measure_type However, we still face |
@Dhruvanshu-Joshi I pushed a commit that uses multiple inheritance to incorporate the metainfo in the MeasurableVariable subclasses. This avoids having to re-define the I haven't tested at all, and each rewrite should be checked manually to make sure we are doing the right thing. Special attention should be given to Ops with multiple measurable outputs (Scan, IfElse), as we need to preserve the meta-info for each output. After checking and cleaning up the code, a good test case would be to remove the current limitation on Dimshuffles of non-pure RVs as mentioned in #6360. This should be done in a separate commit! :) Feel free to ask any questions about the code I pushed (you can leave comments directly on the changed lines here on Github) |
@ricardoV94 I have made some changes so that all the logprob tests pass. |
Also I am working on solving the merge conflicts by referring the PR #6746 and will update if doing so solves the |
That means we just need to pass the Op instead in those cases, no? |
0ef5aa6
to
31a8deb
Compare
Are we not doing that by passing Also, for the error in the return type of |
78b7158
to
7d2fc53
Compare
7d2fc53
to
118be0f
Compare
What is this PR about?
This PR aims to solve issue #6360 and is a cotinuation of the PR #6685 by incorporating RV meta information in intermediate MeasurableVariables. The Measurable ops covered are MeasurableComparison, MeasurableClip, MeasurableRound, MeasurableSpecifyShape, MeasurableCheckAndRaise, MeasurableIfElse, MeasurableScan, MeasurableMakeVector, MeasurableJoin, MeasurableDimShuffle, MeasurableTransforms and DiracDelta.
I understand that this PR does not align with the commendable changes made in PR #6746. However, I just want a review from the maintainers on the changes made here and if it agrees with what they had in mind. I'll make all the necessary changes required for this PR to align with #6746 once it is merged.
Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance
📚 Documentation preview 📚: https://pymc--6754.org.readthedocs.build/en/6754/