From 1c3954038364512c6a05d362c2a9e2f6ca819e42 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 24 Feb 2025 17:20:45 +0530 Subject: [PATCH 1/5] raise when sampling a multinomial --- pymc/distributions/multivariate.py | 6 ++++++ tests/distributions/test_distribution.py | 9 ++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 32f9e30f06..d9099e4492 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -619,6 +619,12 @@ def dist(cls, n, p, *args, **kwargs): return super().dist([n, p], *args, **kwargs) def support_point(rv, size, n, p): + observed = getattr(rv.tag, "observed", None) + if observed is None: + raise ValueError( + "Latent Multinomial variables are not supported for sampling. " + "Use a Categorical variable instead." + ) n = pt.shape_padright(n) mean = n * p mode = pt.round(mean) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index df97905073..74f65546db 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -82,7 +82,14 @@ def test_issue_4499(self): with pm.Model(check_bounds=False) as m: x = pm.DiracDelta("x", 1, size=10) npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10) - + def test_issue_7548(self): + #Test for bug in Multinomial, it should raise when trying to sample a Multinomial variable + with pm.Model() as model: + p = [0.3, 0.4, 0.3] + n = 10 + x = pm.Multinomial("x", n=n, p=p) + with pytest.raises(ValueError, match="Latent Multinomial variables are not supported"): + pm.sample(draws=100, chains=1) def test_all_distributions_have_support_points(): import pymc.distributions as dist_module From 762150823b31e2d5c6601c55d3f8d7562a98d146 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 24 Feb 2025 17:27:08 +0530 Subject: [PATCH 2/5] minor change --- tests/distributions/test_distribution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 74f65546db..3e1ca20e29 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -82,8 +82,9 @@ def test_issue_4499(self): with pm.Model(check_bounds=False) as m: x = pm.DiracDelta("x", 1, size=10) npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10) + def test_issue_7548(self): - #Test for bug in Multinomial, it should raise when trying to sample a Multinomial variable + # Test for bug in Multinomial, it should raise when trying to sample a Multinomial variable with pm.Model() as model: p = [0.3, 0.4, 0.3] n = 10 @@ -91,6 +92,7 @@ def test_issue_7548(self): with pytest.raises(ValueError, match="Latent Multinomial variables are not supported"): pm.sample(draws=100, chains=1) + def test_all_distributions_have_support_points(): import pymc.distributions as dist_module From c3a1fe5a1d7db3ff1a080af3387b8274eba1d238 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 24 Feb 2025 20:47:34 +0530 Subject: [PATCH 3/5] added cannot_sample_rv --- pymc/distributions/multivariate.py | 6 ------ pymc/sampling/mcmc.py | 9 +++++++++ pymc/step_methods/cannot_sample.py | 21 +++++++++++++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 pymc/step_methods/cannot_sample.py diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index d9099e4492..32f9e30f06 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -619,12 +619,6 @@ def dist(cls, n, p, *args, **kwargs): return super().dist([n, p], *args, **kwargs) def support_point(rv, size, n, p): - observed = getattr(rv.tag, "observed", None) - if observed is None: - raise ValueError( - "Latent Multinomial variables are not supported for sampling. " - "Use a Categorical variable instead." - ) n = pt.shape_padright(n) mean = n * p mode = pt.round(mean) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..2e34ab6c0c 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -51,6 +51,7 @@ from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.backends.zarr import ZarrChain, ZarrTrace from pymc.blocking import DictToArrayBijection +from pymc.distributions.multivariate import Multinomial from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain from pymc.model import Model, modelcontext @@ -63,6 +64,7 @@ ) from pymc.step_methods import NUTS, CompoundStep from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared +from pymc.step_methods.cannot_sample import CannotSampleRV from pymc.step_methods.hmc import quadpotential from pymc.util import ( ProgressBarManager, @@ -144,6 +146,13 @@ def instantiate_steppers( if initial_point is None: initial_point = model.initial_point() + for rv in model.free_RVs: + if isinstance(rv.owner.op, Multinomial) and getattr(rv.tag, "observed", None) is None: + for step_class in list(selected_steps.keys()): + if rv in selected_steps[step_class]: + selected_steps[step_class].remove(rv) + selected_steps.setdefault(CannotSampleRV, []).append(rv) + for step_class, vars in selected_steps.items(): if vars: name = getattr(step_class, "name") diff --git a/pymc/step_methods/cannot_sample.py b/pymc/step_methods/cannot_sample.py new file mode 100644 index 0000000000..0e92912c4c --- /dev/null +++ b/pymc/step_methods/cannot_sample.py @@ -0,0 +1,21 @@ +from pymc.step_methods.arraystep import ArrayStep + +class CannotSampleRV(ArrayStep): + """ + A step method that raises an error when sampling a latent Multinomial variable. + """ + name = "cannot_sample_rv" + def __init__(self, vars, **kwargs): + # Remove keys that ArrayStep.__init__ does not accept. + kwargs.pop("model", None) + kwargs.pop("initial_point", None) + kwargs.pop("compile_kwargs", None) + self.vars = vars + super().__init__(vars=vars,fs=[], **kwargs) + + def astep(self, q0): + # This method is required by the abstract base class. + raise ValueError( + "Latent Multinomial variables are not supported" + ) + From 3d58e0bd82c6593a50c440d2d788da2d3deb4ae1 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 24 Feb 2025 20:48:28 +0530 Subject: [PATCH 4/5] formatting --- pymc/step_methods/cannot_sample.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/pymc/step_methods/cannot_sample.py b/pymc/step_methods/cannot_sample.py index 0e92912c4c..5a4a8e4d90 100644 --- a/pymc/step_methods/cannot_sample.py +++ b/pymc/step_methods/cannot_sample.py @@ -1,21 +1,34 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pymc.step_methods.arraystep import ArrayStep + class CannotSampleRV(ArrayStep): """ A step method that raises an error when sampling a latent Multinomial variable. """ + name = "cannot_sample_rv" + def __init__(self, vars, **kwargs): # Remove keys that ArrayStep.__init__ does not accept. kwargs.pop("model", None) kwargs.pop("initial_point", None) kwargs.pop("compile_kwargs", None) self.vars = vars - super().__init__(vars=vars,fs=[], **kwargs) + super().__init__(vars=vars, fs=[], **kwargs) def astep(self, q0): # This method is required by the abstract base class. - raise ValueError( - "Latent Multinomial variables are not supported" - ) - + raise ValueError("Latent Multinomial variables are not supported") From c27f28609a9b0d15522b11fea68f71f89800c991 Mon Sep 17 00:00:00 2001 From: rishab Date: Mon, 24 Feb 2025 20:53:35 +0530 Subject: [PATCH 5/5] minor formmatting --- pymc/step_methods/cannot_sample.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc/step_methods/cannot_sample.py b/pymc/step_methods/cannot_sample.py index 5a4a8e4d90..8fe086ee72 100644 --- a/pymc/step_methods/cannot_sample.py +++ b/pymc/step_methods/cannot_sample.py @@ -15,9 +15,7 @@ class CannotSampleRV(ArrayStep): - """ - A step method that raises an error when sampling a latent Multinomial variable. - """ + """A step method that raises an error when sampling a latent Multinomial variable.""" name = "cannot_sample_rv"