-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow for batched alpha
in StickBreakingWeights
#6042
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## main #6042 +/- ##
==========================================
- Coverage 89.27% 87.44% -1.83%
==========================================
Files 72 72
Lines 12890 12946 +56
==========================================
- Hits 11507 11321 -186
- Misses 1383 1625 +242
|
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com>
ab3d6e2
to
7fde39a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good start. I found some issues about the interpretation of size in conjugation with batched alphas
alpha
to take batched data for StickBreakingWeights
alpha
in StickBreakingWeights
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @purna135! :)
pymc/tests/test_distributions.py
Outdated
def test_stickbreakingweights_logp(self, value, alpha, K, logp): | ||
with Model() as model: | ||
def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it acceptable to combine test_stickbreakingweights_logp
and test_stickbreakingweights_vectorized
in a single test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If test_stickbreakingweights_vectorized
fails, it would point to shapes not being handled properly to form batches. So, lets keep them separate to have better isolation of tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now I separate the test for batched alpha
I need some assistance calculating |
Left a comment above about the fixture. Also don't forget @larryshamalama remark above that we should test the pymc/pymc/tests/test_distributions_moments.py Lines 1171 to 1174 in 7af102d
Should be enough to test with a vector of two alphas, maybe one of those that is already tested for single alpha (reusing the same k) and the other being an extreme value like alpha=1 or alpha=0 (if that's valid), which might have a very simple moment. |
Yes, I got the test for pymc/pymc/tests/test_distributions_moments.py Lines 1147 to 1151 in 7af102d
|
You can check what the moment is for two distinct single alphas, and it should be the same for a batched alpha that has those two values. |
Ok got it now, do I need to create a separate test for batched alpha as we did in |
Nope, you can just add it as an extra condition in the existing tests. Moments is less sensitive than logp so we can keep it bundled together |
Looks complete to me. The failing test is unrelated. Just asked if @larryshamalama could leave a review as well. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great @purna135! Just one question for myself 😅
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5), | ||
], | ||
) | ||
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is stickbreakingweights_logpdf
passed as an argument here via the fixture sharing decorator of the function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's how fixtures work
What is this PR about?
Addressing #5383
This enables
StickBreakingWeight
'salpha
to accept batched data (>2D), make theinfer_shape
work with batched data, and fix therng_fn
by broadcasting alpha to K.Checklist
Major / Breaking Changes
Bugfixes / New features
StickBreakingWeights
now supports batchedalpha
parametersDocs / Maintenance