
Allow for batched alpha in StickBreakingWeights #6042

Merged

ricardoV94 merged 7 commits into pymc-devs:main from purna135:generalize_StickBreaking on Aug 31, 2022

Conversation

purna135
Member

@purna135 purna135 commented Aug 9, 2022

What is this PR about?
Addressing #5383
This enables StickBreakingWeights' alpha to accept batched data (>2D), makes infer_shape work with batched data, and fixes rng_fn by broadcasting alpha to K.
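
A minimal usage sketch of what this enables; the shapes and values below are illustrative assumptions, not taken from the PR's own tests:

import numpy as np
import pymc as pm

# Illustrative batch of three concentration parameters.
alpha = np.array([0.5, 1.0, 2.0])

with pm.Model():
    # K breakpoints give K + 1 weights; a batched alpha adds a leading batch dimension.
    w = pm.StickBreakingWeights("w", alpha=alpha, K=5)

draws = pm.draw(w)
print(draws.shape)         # expected (3, 6): one row of K + 1 weights per alpha
print(draws.sum(axis=-1))  # each row should sum to ~1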

Checklist

Major / Breaking Changes

  • ...

Bugfixes / New features

  • StickBreakingWeights now supports batched alpha parameters

Docs / Maintenance

  • ...

@codecov

codecov bot commented Aug 9, 2022

Codecov Report

Merging #6042 (2a5df64) into main (ad16bf4) will decrease coverage by 1.82%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6042      +/-   ##
==========================================
- Coverage   89.27%   87.44%   -1.83%     
==========================================
  Files          72       72              
  Lines       12890    12946      +56     
==========================================
- Hits        11507    11321     -186     
- Misses       1383     1625     +242     
Impacted Files                              Coverage Δ
pymc/distributions/multivariate.py          91.25% <100.00%> (-0.75%) ⬇️
pymc/distributions/timeseries.py            43.36% <0.00%> (-35.28%) ⬇️
pymc/model_graph.py                         65.66% <0.00%> (-29.80%) ⬇️
pymc/model.py                               76.14% <0.00%> (-12.06%) ⬇️
pymc/step_methods/hmc/quadpotential.py      73.76% <0.00%> (-6.94%) ⬇️
pymc/util.py                                75.29% <0.00%> (-2.36%) ⬇️
pymc/distributions/discrete.py              97.65% <0.00%> (-1.57%) ⬇️
pymc/step_methods/hmc/base_hmc.py           89.76% <0.00%> (-0.79%) ⬇️
pymc/gp/gp.py                               92.73% <0.00%> (-0.45%) ⬇️
... and 9 more

Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com>
@purna135 purna135 force-pushed the generalize_StickBreaking branch from ab3d6e2 to 7fde39a on August 9, 2022 at 20:22
Member

@ricardoV94 ricardoV94 left a comment

Good start. I found some issues with the interpretation of size in conjunction with batched alphas.
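
For illustration, a sketch of the shape question being raised here; the exact interaction of size with a batched alpha was part of this review, so the shapes below are assumptions rather than settled behaviour:

import numpy as np
import pymc as pm

# Assumption: alpha of shape (3,) implies a batch shape of (3,), and K=5 gives
# a support dimension of K + 1 = 6.
alpha = np.array([0.5, 1.0, 2.0])
w = pm.StickBreakingWeights.dist(alpha=alpha, K=5, size=(2, 3))

# If size is read as the full batch shape (broadcast against alpha's batch shape),
# a single draw would have shape (2, 3, 6).
print(pm.draw(w).shape)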

@ricardoV94 ricardoV94 changed the title allow alpha to take batched data for StickBreakingWeights Allow for batched alpha in StickBreakingWeights Aug 10, 2022
Member

@larryshamalama larryshamalama left a comment

Great work @purna135! :)

@larryshamalama
Member

larryshamalama commented Aug 16, 2022

Would the moment and logp methods in the distribution class need appropriate broadcasting for vector-valued alphas? I am just reading over this PR and might have missed previous discussions

@ricardoV94
Member

Would the moment and logp methods in the distribution class need appropriate broadcasting for vector-valued alphas? I am just reading over this PR and might have missed previous discussions

We have tests for batched alpha, but not moment (we should)

Comment on lines 2296 to 2300
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
    with Model() as model:
def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf):
Member Author

Is it acceptable to combine test_stickbreakingweights_logp and test_stickbreakingweights_vectorized in a single test?

Member

If test_stickbreakingweights_vectorized fails, it would point to shapes not being handled properly to form batches. So, let's keep them separate for better test isolation.

Member Author

Now I've separated the test for batched alpha.

@purna135
Member Author

We have tests for batched alpha, but not moment (we should)

I need some assistance calculating expected in test_stickbreakingweights_moment.

@ricardoV94
Member

ricardoV94 commented Aug 25, 2022

Left a comment above about the fixture. Also don't forget @larryshamalama's remark above that we should test that the moment function works for batched alpha as well. The existing tests are here:

def test_stickbreakingweights_moment(alpha, K, size, expected):
    with Model() as model:
        StickBreakingWeights("x", alpha=alpha, K=K, size=size)
    assert_moment_is_expected(model, expected)

It should be enough to test with a vector of two alphas: maybe one that is already tested for a single alpha (reusing the same K), and the other an extreme value like alpha=1 or alpha=0 (if that's valid), which might have a very simple moment.

@purna135
Member Author

Yes, I got the test for moment but I am not sure how the expected is calculated here.
Is there any equation to determine the expected?

@pytest.mark.parametrize(
    "alpha, K, size, expected",
    [
        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),

@ricardoV94
Member

Yes, I got the test for moment but I am not sure how the expected is calculated here. Is there any equation to determine the expected?

@pytest.mark.parametrize(
    "alpha, K, size, expected",
    [
        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),

You can check what the moment is for two distinct single alphas, and it should be the same for a batched alpha that has those two values.
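
As an illustration of that equation, a small helper (the name sbw_moment is made up here) that reproduces the parametrized cases quoted above and generalizes them to any single alpha:

import numpy as np

def sbw_moment(alpha, K):
    # Pattern read off the parametrized cases above: entry k (for k < K) is
    # (alpha / (alpha + 1))**k / (alpha + 1), and the final entry is the
    # remaining stick, (alpha / (alpha + 1))**K.
    frac = alpha / (alpha + 1)
    return np.append(frac ** np.arange(K) / (alpha + 1), frac**K)

# Reproduces the first case above:
assert np.allclose(
    sbw_moment(3, 11),
    np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11),
)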

@purna135
Member Author

OK, got it now. Do I need to create a separate test for batched alpha, as we did in TestStickBreakingWeights_1D_alpha?

@ricardoV94
Member

OK, got it now. Do I need to create a separate test for batched alpha, as we did in TestStickBreakingWeights_1D_alpha?

Nope, you can just add it as an extra condition in the existing tests. The moment is less sensitive than logp, so we can keep it bundled together.
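
For illustration, that extra condition might look like the sketch below: the existing parametrization plus one batched-alpha case whose expected stacks the single-alpha moments along a leading axis. The batched entry is a suggestion rather than the merged code, and assert_moment_is_expected is the helper already used by the existing test module:

import numpy as np
import pytest
from pymc import Model, StickBreakingWeights

@pytest.mark.parametrize(
    "alpha, K, size, expected",
    [
        (3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
        (5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),
        # Batched alpha: one row of moments per alpha, reusing the single-alpha pattern.
        (
            np.array([3, 5]),
            19,
            None,
            np.stack(
                [
                    np.append((3 / 4) ** np.arange(19) * 1 / 4, (3 / 4) ** 19),
                    np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19),
                ]
            ),
        ),
    ],
)
def test_stickbreakingweights_moment(alpha, K, size, expected):
    with Model() as model:
        StickBreakingWeights("x", alpha=alpha, K=K, size=size)
    # assert_moment_is_expected is defined in the existing moment test module.
    assert_moment_is_expected(model, expected)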

@ricardoV94
Member

Looks complete to me. The failing test is unrelated. Just asked if @larryshamalama could leave a review as well.

Member

@larryshamalama larryshamalama left a comment

Looks great @purna135! Just one question for myself 😅

        (np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
    ],
)
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
Member

Is stickbreakingweights_logpdf passed as an argument here via the fixture sharing decorator of the function?

Member

Yes, that's how fixtures work
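
A minimal sketch of the mechanism being described, with a trivial stand-in for the real fixture: pytest matches a test's argument name against known fixtures and injects the fixture's return value, with no extra decorator on the test itself:

import pytest

@pytest.fixture
def stickbreakingweights_logpdf():
    # Stand-in for the compiled logpdf fixture; here it just returns a dummy callable.
    def logpdf(value, alpha, k):
        return 0.0
    return logpdf

def test_uses_fixture(stickbreakingweights_logpdf):
    # pytest sees the argument name, finds the matching fixture above, and passes
    # in its return value, so the test can call it directly.
    assert stickbreakingweights_logpdf(value=None, alpha=1.0, k=3) == 0.0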

@ricardoV94 ricardoV94 merged commit 0b191ad into pymc-devs:main Aug 31, 2022