diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..2e34ab6c0c 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -51,6 +51,7 @@ from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.backends.zarr import ZarrChain, ZarrTrace from pymc.blocking import DictToArrayBijection +from pymc.distributions.multivariate import Multinomial from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain from pymc.model import Model, modelcontext @@ -63,6 +64,7 @@ ) from pymc.step_methods import NUTS, CompoundStep from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared +from pymc.step_methods.cannot_sample import CannotSampleRV from pymc.step_methods.hmc import quadpotential from pymc.util import ( ProgressBarManager, @@ -144,6 +146,13 @@ def instantiate_steppers( if initial_point is None: initial_point = model.initial_point() + for rv in model.free_RVs: + if isinstance(rv.owner.op, Multinomial) and getattr(rv.tag, "observed", None) is None: + for step_class in list(selected_steps.keys()): + if rv in selected_steps[step_class]: + selected_steps[step_class].remove(rv) + selected_steps.setdefault(CannotSampleRV, []).append(rv) + for step_class, vars in selected_steps.items(): if vars: name = getattr(step_class, "name") diff --git a/pymc/step_methods/cannot_sample.py b/pymc/step_methods/cannot_sample.py new file mode 100644 index 0000000000..8fe086ee72 --- /dev/null +++ b/pymc/step_methods/cannot_sample.py @@ -0,0 +1,32 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pymc.step_methods.arraystep import ArrayStep + + +class CannotSampleRV(ArrayStep): + """A step method that raises an error when sampling a latent Multinomial variable.""" + + name = "cannot_sample_rv" + + def __init__(self, vars, **kwargs): + # Remove keys that ArrayStep.__init__ does not accept. + kwargs.pop("model", None) + kwargs.pop("initial_point", None) + kwargs.pop("compile_kwargs", None) + self.vars = vars + super().__init__(vars=vars, fs=[], **kwargs) + + def astep(self, q0): + # This method is required by the abstract base class. + raise ValueError("Latent Multinomial variables are not supported") diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index df97905073..3e1ca20e29 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -83,6 +83,15 @@ def test_issue_4499(self): x = pm.DiracDelta("x", 1, size=10) npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10) + def test_issue_7548(self): + # Test for bug in Multinomial, it should raise when trying to sample a Multinomial variable + with pm.Model() as model: + p = [0.3, 0.4, 0.3] + n = 10 + x = pm.Multinomial("x", n=n, p=p) + with pytest.raises(ValueError, match="Latent Multinomial variables are not supported"): + pm.sample(draws=100, chains=1) + def test_all_distributions_have_support_points(): import pymc.distributions as dist_module