diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 993acc0df4..cdbfc5c32b 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -30,11 +30,9 @@
 )

 import numpy as np
-import pytensor

 from pymc.backends.report import SamplerReport
 from pymc.model import modelcontext
-from pymc.pytensorf import compile
 from pymc.util import get_var_name

 logger = logging.getLogger(__name__)
@@ -171,10 +169,14 @@ def __init__(
         if fn is None:
             # borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
-            fn = compile(
-                inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
-                outputs=[pytensor.Out(v, borrow=True) for v in vars],
+            fn = model.compile_fn(
+                inputs=model.value_vars,
+                outs=vars,
                 on_unused_input="ignore",
+                random_seed=False,
+                borrow_inputs=True,
+                borrow_outputs=True,
+                point_fn=False,
             )
             fn.trust_input = True

diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index c276a5c496..581d6269f3 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -28,9 +28,8 @@
 from pymc.pytensorf import (
     SeedSequenceSeed,
     compile,
-    find_rng_nodes,
     replace_rng_nodes,
-    reseed_rngs,
+    seed_compiled_function,
     toposort_replace,
 )
 from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
@@ -167,7 +166,12 @@
     # Replace original rng shared variables so that we don't mess with them
     # when calling the final seeded function
     initial_values = replace_rng_nodes(initial_values)
-    func = compile(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
+    func = compile(
+        inputs=[],
+        outputs=initial_values,
+        mode=pytensor.compile.mode.FAST_COMPILE,
+        random_seed=False,
+    )

     varnames = []
     for var in model.free_RVs:
@@ -179,11 +183,9 @@
         varnames.append(name)

     def make_seeded_function(func):
-        rngs = find_rng_nodes(func.maker.fgraph.outputs)
-
         @functools.wraps(func)
         def inner(seed, *args, **kwargs):
-            reseed_rngs(rngs, seed)
+            seed_compiled_function(func, seed)
             values = func(*args, **kwargs)
             return dict(zip(varnames, values))

diff --git a/pymc/model/core.py b/pymc/model/core.py
index 469001e804..622f44066e 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -58,11 +58,10 @@
     SeedSequenceSeed,
     compile,
     convert_observed_data,
-    gradient,
-    hessian,
     inputvars,
     join_nonshared_inputs,
     rewrite_pregrad,
+    seed_compiled_function,
 )
 from pymc.util import (
     UNSET,
@@ -73,6 +72,8 @@
     get_transformed_name,
     get_value_vars_from_user_vars,
     get_var_name,
+    invalidates_memoize,
+    memoize,
     treedict,
     treelist,
 )
@@ -455,7 +456,8 @@ def __init__(
     ):
         self.name = self._validate_name(name)
         self.check_bounds = check_bounds
-        self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context
+        self.parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context
+        self.isroot = self.parent is None

         if coords_mutable is not None:
             warnings.warn(
@@ -514,10 +516,6 @@ def get_context(
             raise TypeError("No model on context stack")
         return model

-    @property
-    def parent(self):
-        return self._parent
-
     @property
     def root(self):
         model = self
@@ -525,10 +523,7 @@ def root(self):
             model = model.parent
         return model

-    @property
-    def isroot(self):
-        return self.parent is None
-
+    @memoize
     def logp_dlogp_function(
         self,
         grad_vars=None,
@@ -574,7 +569,6 @@
             grad_vars,
             extra_vars_and_values,
             model=self,
-            initial_point=initial_point,
             ravel_inputs=ravel_inputs,
             **kwargs,
         )
@@ -641,6 +635,7 @@ def compile_d2logp(
             **compile_kwargs,
         )

+    @memoize
     def logp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
@@ -720,18 +715,21 @@ def logp(
         logp_scalar.name = logp_scalar_name
         return logp_scalar

+    @memoize
     def dlogp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
-    ) -> Variable:
+        ravel_outputs: bool = True,
+        return_logp: bool = False,
+    ) -> list[Variable] | Variable | tuple[Variable, list[Variable] | Variable]:
         """Gradient of the models log-probability w.r.t. ``vars``.

         Parameters
         ----------
-        vars : list of random variables or potential terms, optional
-            Compute the gradient with respect to those variables. If None, use all
-            free and observed random variables, as well as potential terms in model.
+        vars : list of random variables, optional
+            Compute the gradient with respect to those variables.
+            If None, consider all continuous free variables.
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
@@ -740,7 +738,7 @@ def dlogp(
             dlogp graph
         """
         if vars is None:
-            value_vars = None
+            value_vars = self.continuous_value_vars
         else:
             if not isinstance(vars, list | tuple):
                 vars = [vars]
@@ -757,21 +755,27 @@ def dlogp(
         cost = self.logp(jacobian=jacobian)
         cost = rewrite_pregrad(cost)
-        return gradient(cost, value_vars)
-
+        gradient = pt.grad(cost, value_vars)
+        if ravel_outputs:
+            gradient = pt.concatenate([g.reshape(-1) for g in gradient], axis=0)
+        if return_logp:
+            return cost, gradient
+        return gradient
+
+    @memoize
     def d2logp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
-        negate_output=True,
+        negate_output: bool | None = None,
     ) -> Variable:
-        """Hessian of the models log-probability w.r.t. ``vars``.
+        """Hessian of the model's log-probability w.r.t. the flattened vector of ``vars``.

         Parameters
         ----------
-        vars : list of random variables or potential terms, optional
-            Compute the gradient with respect to those variables. If None, use all
-            free and observed random variables, as well as potential terms in model.
+        vars : list of random variables, optional
+            Compute the hessian with respect to those variables.
+            If None, consider all continuous free variables.
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
@@ -780,7 +784,7 @@
             d²logp graph
         """
         if vars is None:
-            value_vars = None
+            value_vars = self.continuous_value_vars
         else:
             if not isinstance(vars, list | tuple):
                 vars = [vars]
@@ -795,9 +799,26 @@
                     f"Requested variable {var} not found among the model variables"
                 )

-        cost = self.logp(jacobian=jacobian)
-        cost = rewrite_pregrad(cost)
-        return hessian(cost, value_vars, negate_output=negate_output)
+        grad = self.dlogp(
+            vars=[self.values_to_rvs[value] for value in value_vars],
+            jacobian=jacobian,
+            ravel_outputs=True,
+        )
+        hess = pytensor.gradient.jacobian(grad, value_vars, vectorize=True)
+        if negate_output is not None:
+            if negate_output:
+                warnings.warn(
+                    "negate_output is deprecated and will fail in a future release. To comply with the API change, set it to None and negate the result manually",
+                    FutureWarning,
+                )
+                hess = -hess
+            else:
+                warnings.warn(
+                    "negate_output is deprecated and will fail in a future release. To comply with the API change, set it to None. The result is not negated by default.",
+                    FutureWarning,
+                )
+
+        return hess

     @property
     def datalogp(self) -> Variable:
@@ -812,6 +833,9 @@ def varlogp(self) -> Variable:
     @property
     def varlogp_nojac(self) -> Variable:
         """PyTensor scalar of log-probability of the unobserved random variables (excluding deterministic) without jacobian term."""
+        warnings.warn(
+            "varlogp_nojac is deprecated, use `model.logp(vars=self.free_RVs, jacobian=False)`", FutureWarning
+        )
         return self.logp(vars=self.free_RVs, jacobian=False)

     @property
@@ -824,11 +848,7 @@ def potentiallogp(self) -> Variable:
         """PyTensor scalar of log-probability of the Potential terms."""
         # Convert random variables in Potential expression into their log-likelihood
         # inputs and apply their transforms, if any
-        potentials = self.replace_rvs_by_values(self.potentials)
-        if potentials:
-            return pt.sum([pt.sum(factor) for factor in potentials])
-        else:
-            return pt.constant(0.0)
+        return self.logp(vars=self.potentials)

     @property
     def value_vars(self):
@@ -903,6 +923,10 @@ def dim_lengths(self) -> dict[str, TensorVariable]:
         return self._dim_lengths

     def shape_from_dims(self, dims):
+        warnings.warn(
+            "model.shape_from_dims is deprecated and will be removed in a future release",
+            FutureWarning,
+        )
         shape = []
         if len(set(dims)) != len(dims):
             raise ValueError("Can not contain the same dimension name twice.")
@@ -917,6 +941,7 @@ def shape_from_dims(self, dims):
             shape.extend(np.shape(self.coords[dim]))
         return tuple(shape)

+    @invalidates_memoize
     def add_coord(
         self,
         name: str,
@@ -977,6 +1002,7 @@ def add_coord(
         self._dim_lengths[name] = length
         self._coords[name] = values

+    @invalidates_memoize
     def add_coords(
         self,
         coords: dict[str, Sequence | None],
@@ -991,6 +1017,7 @@ def add_coords(
         for name, values in coords.items():
             self.add_coord(name, values, length=lengths.get(name, None))

+    @invalidates_memoize
     def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = None):
         """Update a mutable dimension.

@@ -1026,6 +1053,10 @@ def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = No
             dim_length.set_value(new_length)
         return

+    @memoize
+    def _make_initial_point(self):
+        return make_initial_point_fn(model=self, return_transformed=True)
+
     def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.ndarray]:
         """Compute the initial point of the model.

@@ -1039,9 +1070,10 @@ def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.nd
         ip : dict of {str : array_like}
             Maps names of transformed variables to numeric initial values in the transformed space.
         """
-        fn = make_initial_point_fn(model=self, return_transformed=True)
+        fn = self._make_initial_point()
         return Point(fn(random_seed), model=self)

+    @invalidates_memoize
     def set_initval(self, rv_var, initval):
         """Set an initial value (strategy) for a random variable."""
         if initval is not None and not isinstance(initval, Variable | str):
@@ -1050,6 +1082,7 @@

         self.rvs_to_initial_values[rv_var] = initval

+    @invalidates_memoize
     def set_data(
         self,
         name: str,
@@ -1185,6 +1218,7 @@

         shared_object.set_value(values)

+    @invalidates_memoize
     def register_rv(
         self,
         rv_var: RandomVariable,
@@ -1263,6 +1297,7 @@

         return rv_var

+    @invalidates_memoize
     def make_obs_var(
         self,
         rv_var: TensorVariable,
@@ -1362,6 +1397,7 @@

         return rv_var

+    @invalidates_memoize
     def create_value_var(
         self,
         rv_var: TensorVariable,
@@ -1444,11 +1480,13 @@

         return value_var

+    @invalidates_memoize
     def register_data_var(self, data, dims=None):
         """Register a data variable with the model."""
         self.data_vars.append(data)
         self.add_named_variable(data, dims=dims)

+    @invalidates_memoize
     def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
         """Add a random graph variable to the named variables of the model.

@@ -1566,6 +1604,7 @@ def copy(self):

         return clone_model(self)

+    @memoize
     def replace_rvs_by_values(
         self,
         graphs: Sequence[TensorVariable],
@@ -1612,6 +1651,40 @@
         **kwargs,
     ) -> Function: ...

+    @memoize
+    def _compile_fn(
+        self,
+        outs: Variable | Sequence[Variable],
+        *,
+        inputs: Sequence[Variable] | None = None,
+        mode=None,
+        borrow_inputs=False,
+        borrow_outputs=False,
+        **kwargs,
+    ) -> PointFunc | Function:
+        if inputs is None:
+            inputs = inputvars(outs)
+
+        if borrow_inputs:
+            inputs = [pytensor.In(inp, borrow=True) for inp in inputs]
+
+        if borrow_outputs:
+            if isinstance(outs, list | tuple):
+                outs = [pytensor.Out(o, borrow=True) for o in outs]
+            else:
+                outs = pytensor.Out(outs, borrow=True)
+
+        with self:
+            return compile(
+                inputs,
+                outs,
+                allow_input_downcast=True,
+                accept_inplace=True,
+                mode=mode,
+                random_seed=False,
+                **kwargs,
+            )
+
     def compile_fn(
         self,
         outs: Variable | Sequence[Variable],
@@ -1619,6 +1692,8 @@ def compile_fn(
         inputs: Sequence[Variable] | None = None,
         mode=None,
         point_fn: bool = True,
+        borrow_inputs: bool = False,
+        borrow_outputs: bool = False,
         **kwargs,
     ) -> PointFunc | Function:
         """Compiles a PyTensor function.
@@ -1641,21 +1716,19 @@
         -------
         Compiled PyTensor function
         """
-        if inputs is None:
-            inputs = inputvars(outs)
-
-        with self:
-            fn = compile(
-                inputs,
-                outs,
-                allow_input_downcast=True,
-                accept_inplace=True,
-                mode=mode,
-                **kwargs,
-            )
-
+        random_seed = kwargs.pop("random_seed", None)
+        fn = self._compile_fn(
+            outs,
+            inputs=inputs,
+            mode=mode,
+            borrow_inputs=borrow_inputs,
+            borrow_outputs=borrow_outputs,
+            **kwargs,
+        )
+        seed_compiled_function(fn, random_seed)
         if point_fn:
-            return PointFunc(fn)
+            fn = PointFunc(fn)
         return fn

     def profile(
@@ -1695,29 +1768,13 @@
         )
         if point is None:
             point = self.initial_point()
+        point_values = point.values()

         for _ in range(n):
-            f(**point)
+            f(*point_values)

         return f.profile

-    def update_start_vals(self, a: dict[str, np.ndarray], b: dict[str, np.ndarray]):
-        r"""Update point `a` with `b`, without overwriting existing keys.
-
-        Values specified for transformed variables in `a` will be recomputed
-        conditional on the values of `b` and stored in `b`.
-
-        Parameters
-        ----------
-        a : dict
-
-        b : dict
-        """
-        raise FutureWarning(
-            "The `Model.update_start_vals` method was removed."
-            " To change initial values you may set the items of `Model.initial_values` directly."
-        )
-
     def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]:
         """Evaluate shapes of untransformed AND transformed free variables.

@@ -1795,7 +1852,7 @@ def check_start_vals(self, start, **kwargs):
                 "You can call `model.debug()` for more details."
             )

-    def point_logps(self, point=None, round_vals=2, **kwargs):
+    def point_logps(self, point=None, round_vals=2):
         """Compute the log probability of `point` for all random variables in the model.

         Parameters
@@ -1817,12 +1874,12 @@ def point_logps(self, point=None, round_vals=2):
             point = self.initial_point()

         factors = self.basic_RVs + self.potentials
-        factor_logps_fn = [pt.sum(factor) for factor in self.logp(factors, sum=False)]
+        factor_logps_fn = self.compile_logp(factors, sum=False)
         return {
-            factor.name: np.round(np.asarray(factor_logp), round_vals)
+            factor.name: np.round(np.asarray(factor_logp.sum()), round_vals)
             for factor, factor_logp in zip(
                 factors,
-                self.compile_fn(factor_logps_fn, **kwargs)(point),
+                factor_logps_fn(point),
             )
         }

@@ -2167,6 +2224,10 @@ def compile_fn(
     -------
     Compiled PyTensor function
     """
+    warnings.warn(
+        "compile_fn is deprecated. Use `model.compile_fn` or `pytensorf.compile` instead.",
+        FutureWarning,
+    )
     model = modelcontext(model)
     return model.compile_fn(
         outs,

diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 78eb3f7bbc..432561054e 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -36,6 +36,7 @@
     walk,
 )
 from pytensor.graph.fg import FunctionGraph, Output
+from pytensor.link.jax import JAXLinker
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
@@ -50,7 +51,7 @@
 from pytensor.tensor.variable import TensorVariable

 from pymc.exceptions import NotConstantValueError
-from pymc.util import makeiter
+from pymc.util import _get_seeds_per_chain, makeiter
 from pymc.vartypes import continuous_types, isgenerator, typefilter

 PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable
@@ -59,7 +60,6 @@
 __all__ = [
     "CallableTensor",
     "compile",
-    "compile_pymc",
     "cont_inputs",
     "convert_data",
     "convert_observed_data",
@@ -424,10 +424,11 @@ def make_shared_replacements(point, vars, model):
     -------
     Dict of variable -> new shared variable
     """
-    othervars = set(model.value_vars) - set(vars)
+    vars_set = set(vars)
     return {
         var: pytensor.shared(point[var.name], var.name + "_shared", shape=var.type.shape)
-        for var in othervars
+        for var in model.value_vars
+        if var not in vars_set
     }

@@ -686,6 +687,27 @@ def reseed_rngs(
         rng.set_value(np.random.Generator(bit_generator), borrow=True)

+def seed_compiled_function(function, seed: SeedSequenceSeed):
+    rng_variables = [
+        inp
+        for inp in function.maker.fgraph.inputs
+        if isinstance(inp, RandomGeneratorSharedVariable)
+    ]
+    if rng_variables:
+        if isinstance(function.maker.linker, JAXLinker):
+            import jax
+
+            (int_seed,) = _get_seeds_per_chain(seed, 1)
+            rng_values = jax.random.split(jax.random.key(int_seed), len(rng_variables))
+        else:
+            rng_values = [
+                np.random.Generator(np.random.PCG64(sub_seed))
+                for sub_seed in np.random.SeedSequence(seed).spawn(len(rng_variables))
+            ]
+        for rng_variable, rng_value in
zip(rng_variables, rng_values): + rng_variable.set_value(rng_value, borrow=True) + + def collect_default_updates_inner_fgraph(node: Apply) -> dict[Variable, Variable]: """Collect default updates from node with inner fgraph.""" op = node.op @@ -877,7 +899,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable: def compile( inputs, outputs, - random_seed: SeedSequenceSeed = None, + random_seed: SeedSequenceSeed | bool = None, mode=None, **kwargs, ) -> Function: @@ -926,8 +948,9 @@ def compile( # We always reseed random variables as this provides RNGs with no chances of collision if rng_updates: - rngs = cast(list[SharedVariable], list(rng_updates)) - reseed_rngs(rngs, random_seed) + if random_seed is not False: + rngs = cast(list[SharedVariable], list(rng_updates)) + reseed_rngs(rngs, random_seed) # If called inside a model context, see if check_bounds flag is set to False try: @@ -954,14 +977,6 @@ def compile( return pytensor_function -def compile_pymc(*args, **kwargs): - warnings.warn( - "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC", - FutureWarning, - ) - return compile(*args, **kwargs) - - def constant_fold( xs: Sequence[TensorVariable], raise_not_constant: bool = True ) -> tuple[np.ndarray | Variable, ...]: diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index b1f9c39895..c82dfa94f6 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -30,9 +30,7 @@ import xarray from arviz import InferenceData -from pytensor import tensor as pt from pytensor.graph.basic import ( - Apply, Constant, Variable, ancestors, @@ -60,7 +58,6 @@ _get_seeds_per_chain, default_progress_theme, get_default_varnames, - point_wrapper, ) __all__ = ( @@ -113,9 +110,9 @@ def compile_forward_sampling_function( outputs: list[Variable], vars_in_trace: list[Variable], basic_rvs: list[Variable] | None = None, - givens_dict: dict[Variable, Any] | None = None, constant_data: dict[str, np.ndarray] | None = None, constant_coords: set[str] | None = None, + model=None, **kwargs, ) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]: """Compile a function to draw samples, conditioned on the values of some variables. @@ -131,7 +128,6 @@ def compile_forward_sampling_function( - Variables in the outputs list - ``SharedVariable`` instances that are not ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list - - Variables that are keys in the ``givens_dict`` - Variables that have volatile inputs Concretely, this function can be used to compile a function to sample from the @@ -142,15 +138,6 @@ def compile_forward_sampling_function( ignored and new values will be computed (in the case of deterministics and potentials) or sampled (in the case of random variables). - This function also enables a way to impute values for any variable in the computational - graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used - to set the ``givens`` argument of the pytensor function compilation. This will essentially - replace a node in the computational graph with any other expression that has the same - type as the desired node. 
Passing variables in the givens_dict is considered an intervention - that might lead to different variable values from those that could have been seen during - inference, as such, **any variable that is passed in the ``givens_dict`` will be considered - volatile**. - Parameters ---------- outputs : List[pytensor.graph.basic.Variable] @@ -163,10 +150,6 @@ def compile_forward_sampling_function( be considered as random variable instances. This includes variables that have a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or Censored distributions. - givens_dict : Optional[Dict[pytensor.graph.basic.Variable, Any]] - A dictionary that maps tensor variables to the values that should be used to replace them - in the compiled function. The types of the key and value should match or an error will be - raised during compilation. constant_data : Optional[Dict[str, numpy.ndarray]] A dictionary that maps the names of ``Data`` instances to their corresponding values at inference time. If a model was created with ``Data``, these @@ -195,9 +178,6 @@ def compile_forward_sampling_function( Set of all basic_rvs that were considered volatile and will be resampled when the function is evaluated """ - if givens_dict is None: - givens_dict = {} - if basic_rvs is None: basic_rvs = [] @@ -226,7 +206,6 @@ def shared_value_matches(var): for node in nodes: if ( node in fg.outputs - or node in givens_dict or ( # SharedVariables, except RandomState/Generators isinstance(node, SharedVariable) and not isinstance(node, RandomGeneratorSharedVariable) @@ -263,20 +242,15 @@ def expand(node): # the entire graph list(walk(fg.outputs, expand)) - # Populate the givens list - givens = [ - ( - node, - value - if isinstance(value, Variable | Apply) - else pt.constant(value, dtype=getattr(node, "dtype", None), name=node.name), - ) - for node, value in givens_dict.items() - ] + if model is None: + fn = compile(inputs, fg.outputs, on_unused_input="ignore", **kwargs) + else: + # Go through model to cache function + fn = model.compile_fn(fg.outputs, inputs=inputs, on_unused_input="ignore", **kwargs) return ( - compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs), - set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled + fn, + set(basic_rvs) & volatile_nodes, # Basic RVs that will be resampled ) @@ -467,8 +441,6 @@ def sample_prior_predictive( vars_to_sample, vars_in_trace=[], basic_rvs=model.basic_RVs, - givens_dict=None, - random_seed=random_seed, **compile_kwargs, ) @@ -901,17 +873,16 @@ def sample_posterior_predictive( compile_kwargs.setdefault("allow_input_downcast", True) compile_kwargs.setdefault("accept_inplace", True) - _sampler_fn, volatile_basic_rvs = compile_forward_sampling_function( + sampler_fn, volatile_basic_rvs = compile_forward_sampling_function( outputs=vars_to_sample, vars_in_trace=vars_in_trace, basic_rvs=model.basic_RVs, - givens_dict=None, - random_seed=random_seed, constant_data=constant_data, constant_coords=constant_coords, **compile_kwargs, + random_seed=random_seed, + on_unused_input="ignore", ) - sampler_fn = point_wrapper(_sampler_fn) # All model variables have a name, but mypy does not know this _log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] ppc_trace_t = _DefaultTrace(samples) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..6e1bf1fcc0 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -109,7 +109,6 @@ 
def instantiate_steppers( selected_steps: Mapping[type[BlockedStep], list[Any]], *, step_kwargs: dict[str, dict] | None = None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, ) -> Step | list[Step]: """Instantiate steppers assigned to the model variables. @@ -141,9 +140,6 @@ def instantiate_steppers( used_keys = set() if selected_steps: - if initial_point is None: - initial_point = model.initial_point() - for step_class, vars in selected_steps.items(): if vars: name = getattr(step_class, "name") @@ -152,7 +148,6 @@ def instantiate_steppers( step = step_class( vars=vars, model=model, - initial_point=initial_point, compile_kwargs=compile_kwargs, **kwargs, ) @@ -769,20 +764,8 @@ def joined_blas_limiter(): rngs = get_random_generator(random_seed).spawn(chains) random_seed_list = [rng.integers(2**30) for rng in rngs] - if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace): - warnings.warn( - "Tuning samples will be included in the returned `MultiTrace` object, which can lead to" - " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n" - "`pm.sample(..., return_inferencedata=True)`", - UserWarning, - stacklevel=2, - ) - # small trace warning - if draws == 0: - msg = "Tuning was enabled throughout the whole trace." - _log.warning(msg) - elif draws < 100: + if 0 < draws < 100: msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate." _log.warning(msg) @@ -858,7 +841,6 @@ def joined_blas_limiter(): steps=provided_steps, selected_steps=selected_steps, step_kwargs=kwargs, - initial_point=initial_points[0], compile_kwargs=compile_kwargs, ) if isinstance(step, list): @@ -1097,31 +1079,6 @@ def _sample_return( return mtrace -def _check_start_shape(model, start: PointType): - """Check that the prior evaluations and initial points have identical shapes. - - Parameters - ---------- - model : pm.Model - The current model on context. - start : dict - The complete dictionary mapping (transformed) variable names to numeric initial values. 
- """ - e = "" - try: - actual_shapes = model.eval_rv_shapes() - except NotImplementedError as ex: - warnings.warn(f"Unable to validate shapes: {ex.args[0]}", UserWarning) - return - for name, sval in start.items(): - ashape = actual_shapes.get(name) - sshape = np.shape(sval) - if ashape != tuple(sshape): - e += f"\nExpected shape {ashape} for var '{name}', got: {sshape}" - if e != "": - raise ValueError(f"Bad shape in start point:{e}") - - def _sample_many( *, draws: int, @@ -1595,12 +1552,13 @@ def init_nuts( pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), ] - logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs) - logp_dlogp_func.trust_input = True + logp_dlogp_func = model.logp_dlogp_function( + ravel_inputs=True, trust_input=True, **compile_kwargs + ) def model_logp_fn(ip: PointType) -> np.ndarray: q, _ = DictToArrayBijection.map(ip) - return logp_dlogp_func([q], extra_vars={})[0] + return logp_dlogp_func(q)[0] initial_points = _init_jitter( model, @@ -1726,14 +1684,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: else: raise ValueError(f"Unknown initializer: {init}.") - step = pm.NUTS( - potential=potential, - model=model, - rng=random_seed_list[0], - initial_point=initial_points[0], - logp_dlogp_func=logp_dlogp_func, - **kwargs, - ) + step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs) # Filter deterministics from initial_points value_var_names = [var.name for var in model.value_vars] diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 0c20e09a47..9367fa78de 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -179,26 +179,22 @@ def __init__( model=None, blocked: bool = True, dtype=None, - logp_dlogp_func=None, rng: RandomGenerator = None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, **pytensor_kwargs, ): model = modelcontext(model) - if logp_dlogp_func is None: - if compile_kwargs is None: - compile_kwargs = {} - logp_dlogp_func = model.logp_dlogp_function( - vars, - dtype=dtype, - ravel_inputs=True, - initial_point=initial_point, - **compile_kwargs, - **pytensor_kwargs, - ) - logp_dlogp_func.trust_input = True + if compile_kwargs is None: + compile_kwargs = {} + logp_dlogp_func = model.logp_dlogp_function( + vars, + dtype=dtype, + ravel_inputs=True, + trust_input=True, + **compile_kwargs, + **pytensor_kwargs, + ) self._logp_dlogp_func = logp_dlogp_func diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index e8c96e8c4b..9d0a171d4f 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -22,7 +22,7 @@ import numpy as np -from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType +from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType from pymc.exceptions import SamplingError from pymc.model import Point, modelcontext from pymc.pytensorf import floatX @@ -98,7 +98,6 @@ def __init__( adapt_step_size=True, step_rand=None, rng=None, - initial_point: PointType | None = None, **pytensor_kwargs, ): """Set up Hamiltonian samplers with common structures. 
@@ -144,7 +143,6 @@ def __init__( model=self._model, dtype=dtype, rng=rng, - initial_point=initial_point, **pytensor_kwargs, ) @@ -152,9 +150,7 @@ def __init__( self.Emax = Emax self.iter_count = 0 - if initial_point is None: - initial_point = self._model.initial_point() - + initial_point = self._model.initial_point() nuts_vars = [initial_point[v.name] for v in vars] size = sum(v.size for v in nuts_vars) diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..5d78a6708e 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -30,7 +30,6 @@ import pymc as pm from pymc.blocking import DictToArrayBijection, RaveledVars -from pymc.initial_point import PointType from pymc.pytensorf import ( CallableTensor, compile, @@ -163,7 +162,6 @@ def __init__( model=None, mode=None, rng=None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = False, ): @@ -194,8 +192,7 @@ def __init__( :py:func:`pymc.util.get_random_generator` for more information. """ model = pm.modelcontext(model) - if initial_point is None: - initial_point = model.initial_point() + initial_point = model.initial_point() if vars is None: vars = model.value_vars @@ -466,7 +463,6 @@ def __init__( tune_interval=100, model=None, rng=None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = True, ): @@ -591,7 +587,6 @@ def __init__( transit_p=0.8, model=None, rng=None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = True, ): @@ -605,8 +600,7 @@ def __init__( vars = get_value_vars_from_user_vars(vars, model) - if initial_point is None: - initial_point = model.initial_point() + initial_point = model.initial_point() self.dim = sum(initial_point[v.name].size for v in vars) if order == "random": @@ -713,7 +707,6 @@ def __init__( order="random", model=None, rng: RandomGenerator = None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = True, ): @@ -721,8 +714,7 @@ def __init__( vars = get_value_vars_from_user_vars(vars, model) - if initial_point is None: - initial_point = model.initial_point() + initial_point = model.initial_point() dimcats: list[tuple[int, int]] = [] # The above variable is a list of pairs (aggregate dimension, number @@ -948,13 +940,11 @@ def __init__( model=None, mode=None, rng=None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = True, ): model = pm.modelcontext(model) - if initial_point is None: - initial_point = model.initial_point() + initial_point = model.initial_point() initial_values_size = sum(initial_point[n.name].size for n in model.value_vars) if vars is None: @@ -1118,15 +1108,13 @@ def __init__( tune_interval=100, tune_drop_fraction: float = 0.9, model=None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, mode=None, rng=None, blocked: bool = True, ): model = pm.modelcontext(model) - if initial_point is None: - initial_point = model.initial_point() + initial_point = model.initial_point() initial_values_size = sum(initial_point[n.name].size for n in model.value_vars) if vars is None: diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf4..e27e1f87a2 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -21,7 +21,6 @@ from rich.table import Column from pymc.blocking import RaveledVars, StatsType -from pymc.initial_point import PointType from 
pymc.model import modelcontext from pymc.pytensorf import compile, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared @@ -88,7 +87,6 @@ def __init__( model=None, iter_limit=np.inf, rng=None, - initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = False, # Could be true since tuning is independent across dims? ): @@ -103,8 +101,7 @@ def __init__( else: vars = get_value_vars_from_user_vars(vars, model) - if initial_point is None: - initial_point = model.initial_point() + initial_point = model.initial_point() shared = make_shared_replacements(initial_point, vars, model) [logp], raveled_inp = join_nonshared_inputs( diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..fe3791f3a8 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -399,7 +399,7 @@ def __setstate__(self, state): self.__dict__.update(state) -def locally_cachedmethod(f): +def memoize(f): from collections import defaultdict def self_cache_fn(f_name): @@ -411,6 +411,16 @@ def cf(self): return cachedmethod(self_cache_fn(f.__name__), key=hash_key)(f) +def invalidates_memoize(f): + @functools.wraps(f) + def wrapper_fn(self, *args, **kwargs): + if cache := getattr(self, "_cache", None): + cache.clear() + return f(self, *args, **kwargs) + + return wrapper_fn + + def check_dist_not_registered(dist, model=None): """Check that a dist is not registered in the model already.""" from pymc.model import modelcontext diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index deedfc8d9f..ce6446b110 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -82,8 +82,8 @@ RandomState, WithMemoization, _get_seeds_per_chain, - locally_cachedmethod, makeiter, + memoize, ) from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling from pymc.variational.updates import adagrad_window @@ -150,12 +150,12 @@ def node_property(f): def wrapper(fn): ff = append_name(f)(fn) f_ = pytensor.config.change_flags(compute_test_value="off")(ff) - return property(locally_cachedmethod(f_)) + return property(memoize(f_)) return wrapper else: f_ = pytensor.config.change_flags(compute_test_value="off")(f) - return property(locally_cachedmethod(f_)) + return property(memoize(f_)) @pytensor.config.change_flags(compute_test_value="ignore") diff --git a/pymc/variational/stein.py b/pymc/variational/stein.py index 0534bb6fa4..f0ad019e18 100644 --- a/pymc/variational/stein.py +++ b/pymc/variational/stein.py @@ -17,7 +17,7 @@ from pytensor.graph.replace import graph_replace from pymc.pytensorf import floatX -from pymc.util import WithMemoization, locally_cachedmethod +from pymc.util import WithMemoization, memoize from pymc.variational.opvi import node_property from pymc.variational.test_functions import rbf @@ -93,6 +93,6 @@ def logp_norm(self): ) return sized_symbolic_logp / self.approx.symbolic_normalizing_constant - @locally_cachedmethod + @memoize def _kernel(self): return self._kernel_f(self.input_joint_matrix) diff --git a/tests/model/test_core.py b/tests/model/test_core.py index b26a9d96b7..0c03f15d74 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -1855,3 +1855,21 @@ def test_guassian_process_copy_failure(self, copy_method) -> None: match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. 
See issue: https://github.com/pymc-devs/pymc/issues/6883",
         ):
             copy_method(gaussian_process_model)
+
+
+def test_memoization():
+    with pm.Model() as m:
+        x = pm.Normal("x")
+        y = pm.Normal("y")
+
+    res1 = m.logp()
+    res2 = m.logp()
+    res3 = m.logp(sum=False)
+
+    res4 = m.logp()
+    assert res1 is res2
+    assert res1 is not res3
+    assert res1 is res4
+
+    m.invalidate_cache()

diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 090b76130b..be69f21346 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -212,22 +212,6 @@ def test_reset_tuning(self):
         assert step.potential._n_samples == tune
         assert step.step_adapt._count == tune + 1

-    @pytest.mark.parametrize(
-        "start, error",
-        [
-            ({"x": 1}, ValueError),
-            ({"x": [1, 2, 3]}, ValueError),
-            ({"x": np.array([[1, 1], [1, 1]])}, ValueError),
-        ],
-    )
-    def test_sample_start_bad_shape(self, start, error):
-        with pytest.raises(error):
-            pm.sampling.mcmc._check_start_shape(self.model, start)
-
-    @pytest.mark.parametrize("start", [{"x": np.array([1, 1])}, {"x": [10, 10]}, {"x": [-10, -10]}])
-    def test_sample_start_good_shape(self, start):
-        pm.sampling.mcmc._check_start_shape(self.model, start)
-
     def test_sample_callback(self):
         callback = mock.Mock()
         test_cores = [1, 2]

diff --git a/tests/test_util.py b/tests/test_util.py
index 98cc168f0e..f405ea7cde 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -30,7 +30,7 @@
     get_value_vars_from_user_vars,
     hash_key,
     hashable,
-    locally_cachedmethod,
+    memoize,
 )

@@ -138,7 +138,7 @@ def some_func(x):
     assert some_func(b1) != some_func(b2)

     class TestClass:
-        @locally_cachedmethod
+        @memoize
         def some_method(self, x):
             return x
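
Note on the caching layer: `memoize` stores results in a per-instance `_cache` dict keyed by method name, and any method decorated with `invalidates_memoize` clears that dict, so graph-mutating operations (adding coords, registering variables, setting data) drop every cached graph and compiled function. Below is a standalone sketch of that interaction; the per-instance `_cache` layout is inferred from how `invalidates_memoize` clears it, `hashkey` stands in for `pymc.util.hash_key`, and the `Counter` class is purely illustrative.

    import functools
    from collections import defaultdict

    from cachetools import cachedmethod
    from cachetools.keys import hashkey


    def memoize(f):
        def self_cache_fn(f_name):
            def cf(self):
                # One cache dict per instance and per method name, created lazily.
                return self.__dict__.setdefault("_cache", defaultdict(dict))[f_name]

            return cf

        return cachedmethod(self_cache_fn(f.__name__), key=hashkey)(f)


    def invalidates_memoize(f):
        @functools.wraps(f)
        def wrapper_fn(self, *args, **kwargs):
            # A mutating method drops every memoized result on this instance.
            if cache := getattr(self, "_cache", None):
                cache.clear()
            return f(self, *args, **kwargs)

        return wrapper_fn


    class Counter:
        def __init__(self):
            self.offset = 0
            self.calls = 0

        @memoize
        def expensive(self, x):
            self.calls += 1
            return x + self.offset

        @invalidates_memoize
        def set_offset(self, offset):
            self.offset = offset


    c = Counter()
    assert c.expensive(1) == 1 and c.expensive(1) == 1 and c.calls == 1  # second call hits the cache
    c.set_offset(10)  # clears the per-instance cache
    assert c.expensive(1) == 11 and c.calls == 2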
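
The same split between compiling and seeding applies to cached functions: `compile(..., random_seed=False)` leaves the RNG inputs untouched, and `seed_compiled_function` reseeds them afterwards, which is what `Model.compile_fn` does with a memoized function. A minimal sketch of that flow, assuming `seed_compiled_function` behaves as defined in pymc/pytensorf.py above and that the compiled graph has at least one random generator input:

    import numpy as np
    import pymc as pm
    from pymc.pytensorf import compile, seed_compiled_function

    with pm.Model():
        x = pm.Normal("x")
        # random_seed=False skips reseeding at compile time, so the cached
        # function can be reseeded explicitly per use instead.
        draw_x = compile([], [x], random_seed=False)

    seed_compiled_function(draw_x, 123)
    first = draw_x()
    seed_compiled_function(draw_x, 123)
    second = draw_x()
    assert np.array_equal(first, second)  # same seed -> same draw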