From a8769770e5c4aa1ecccff44e24b2ef11abf827d2 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 5 Mar 2025 14:23:35 +0100 Subject: [PATCH] Do not monkey-patch Ipython pretty representation on model variables --- pymc/distributions/distribution.py | 9 ----- pymc/model/core.py | 46 +++++++---------------- pymc/printing.py | 60 +++++++----------------------- tests/test_printing.py | 27 ++++++++------ 4 files changed, 42 insertions(+), 100 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 5ec5df4671..8cb2922798 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextvars -import functools import re import sys -import types import warnings from abc import ABCMeta @@ -53,7 +51,6 @@ from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob from pymc.logprob.basic import logp from pymc.logprob.rewriting import logprob_rewrites_db -from pymc.printing import str_for_dist from pymc.pytensorf import ( collect_default_updates_inner_fgraph, constant_fold, @@ -506,12 +503,6 @@ def __new__( default_transform=default_transform, initval=initval, ) - - # add in pretty-printing support - rv_out.str_repr = types.MethodType(str_for_dist, rv_out) - rv_out._repr_latex_ = types.MethodType( - functools.partial(str_for_dist, formatting="latex"), rv_out - ) return rv_out @classmethod diff --git a/pymc/model/core.py b/pymc/model/core.py index 5a7b2651cf..684e039b6d 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -16,7 +16,6 @@ import functools import sys import threading -import types import warnings from collections.abc import Iterable, Sequence @@ -496,13 +495,6 @@ def __init__( for name, values in coords_mutable.items(): self.add_coord(name, values, mutable=True) - from pymc.printing import str_for_model - - self.str_repr = types.MethodType(str_for_model, self) - self._repr_latex_ = types.MethodType( - functools.partial(str_for_model, formatting="latex"), self - ) - @classmethod def get_context( cls, error_if_none: bool = True, allow_block_model_access: bool = False @@ -2026,6 +2018,19 @@ def to_graphviz( dpi=dpi, ) + def _repr_pretty_(self, p, cycle): + from pymc.printing import str_for_model + + output = str_for_model(self) + # Find newlines and replace them with p.break_() + # (see IPython.lib.pretty._repr_pprint) + lines = output.splitlines() + with p.group(): + for idx, output_line in enumerate(lines): + if idx: + p.break_() + p.text(output_line) + class BlockModelAccess(Model): """Can be used to prevent user access to Model contexts.""" @@ -2252,19 +2257,6 @@ def Deterministic(name, var, model=None, dims=None): var = var.copy(model.name_for(name)) model.deterministics.append(var) model.add_named_variable(var, dims) - - from pymc.printing import str_for_potential_or_deterministic - - var.str_repr = types.MethodType( - functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var - ) - var._repr_latex_ = types.MethodType( - functools.partial( - str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex" - ), - var, - ) - return var @@ -2377,16 +2369,4 @@ def normal_logp(value, mu, sigma): model.potentials.append(var) model.add_named_variable(var, dims) - from pymc.printing import str_for_potential_or_deterministic - - var.str_repr = types.MethodType( - functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var - ) - var._repr_latex_ = types.MethodType( - functools.partial( - str_for_potential_or_deterministic, dist_name="Potential", formatting="latex" - ), - var, - ) - return var diff --git a/pymc/printing.py b/pymc/printing.py index c4376d306e..93638eaa8c 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -35,7 +35,10 @@ def str_for_dist( - dist: TensorVariable, formatting: str = "plain", include_params: bool = True + dist: TensorVariable, + formatting: str = "plain", + include_params: bool = True, + model: Model | None = None, ) -> str: """Make a human-readable string representation of a Distribution in a model. @@ -47,12 +50,12 @@ def str_for_dist( dist.owner.op, "extended_signature", None ): dist_args = [ - _str_for_input_var(x, formatting=formatting) + _str_for_input_var(x, formatting=formatting, model=model) for x in dist.owner.op.dist_params(dist.owner) ] else: dist_args = [ - _str_for_input_var(x, formatting=formatting) + _str_for_input_var(x, formatting=formatting, model=model) for x in dist.owner.inputs if not isinstance(x.type, RandomType | NoneTypeT) ] @@ -106,7 +109,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool including parameter values. """ # Wrap functions to avoid confusing typecheckers - sfd = partial(str_for_dist, formatting=formatting, include_params=include_params) + sfd = partial(str_for_dist, formatting=formatting, include_params=include_params, model=model) sfp = partial( str_for_potential_or_deterministic, formatting=formatting, include_params=include_params ) @@ -169,18 +172,14 @@ def str_for_potential_or_deterministic( return rf"{print_name} ~ {dist_name}" -def _str_for_input_var(var: Variable, formatting: str) -> str: +def _str_for_input_var(var: Variable, formatting: str, model: Model | None = None) -> str: # Avoid circular import from pymc.distributions.distribution import SymbolicRandomVariable def _is_potential_or_deterministic(var: Variable) -> bool: - if not hasattr(var, "str_repr"): - return False - try: - return var.str_repr.__func__.func is str_for_potential_or_deterministic - except AttributeError: - # in case other code overrides str_repr, fallback + if model is None: return False + return var in model.deterministics or var in model.potentials if isinstance(var, Constant | SharedVariable): return _str_for_constant(var, formatting) @@ -190,18 +189,18 @@ def _is_potential_or_deterministic(var: Variable) -> bool: # show the names for RandomVariables, Deterministics, and Potentials, rather # than the full expression assert isinstance(var, TensorVariable) - return _str_for_input_rv(var, formatting) + return _str_for_input_rv(var, formatting, model=model) elif isinstance(var.owner.op, DimShuffle): - return _str_for_input_var(var.owner.inputs[0], formatting) + return _str_for_input_var(var.owner.inputs[0], formatting, model=model) else: return _str_for_expression(var, formatting) -def _str_for_input_rv(var: TensorVariable, formatting: str) -> str: +def _str_for_input_rv(var: TensorVariable, formatting: str, model: Model | None = None) -> str: _str = ( var.name if var.name is not None - else str_for_dist(var, formatting=formatting, include_params=True) + else str_for_dist(var, formatting=formatting, include_params=True, model=model) ) if "latex" in formatting: return _latex_text_format(_latex_escape(_str.strip("$"))) @@ -277,37 +276,6 @@ def _latex_escape(text: str) -> str: return text.replace("$", r"\$") -def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): - """Handy plug-in method to instruct IPython-like REPLs to use our str_repr above.""" - # we know that our str_repr does not recurse, so we can ignore cycle - try: - if not hasattr(obj, "str_repr"): - raise AttributeError - output = obj.str_repr() - # Find newlines and replace them with p.break_() - # (see IPython.lib.pretty._repr_pprint) - lines = output.splitlines() - with p.group(): - for idx, output_line in enumerate(lines): - if idx: - p.break_() - p.text(output_line) - except AttributeError: - # the default fallback option (no str_repr method) - IPython.lib.pretty._repr_pprint(obj, p, cycle) - - -try: - # register our custom pretty printer in ipython shells - import IPython - - IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) - IPython.lib.pretty.for_type(Model, _default_repr_pretty) -except (ModuleNotFoundError, AttributeError): - # no ipython shell - pass - - def _format_underscore(variable: str) -> str: """Escapes all unescaped underscores in the variable name for LaTeX representation.""" return re.sub(r"(?