Skip to content

Model with dims #7820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

import io
import typing
import urllib.request
import warnings

from collections.abc import Sequence
from copy import copy
from typing import cast
from typing import Union, cast

import numpy as np
import pandas as pd
Expand All @@ -33,12 +34,13 @@
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import convert_data
from pymc.exceptions import ShapeError
from pymc.pytensorf import convert_data, rvs_in_graph
from pymc.vartypes import isgenerator

if typing.TYPE_CHECKING:
from pymc.model.core import Model

__all__ = [
"ConstantData",
"Data",
Expand Down Expand Up @@ -200,7 +202,7 @@ def determine_coords(

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
raise pm.exceptions.ShapeError(
raise ShapeError(
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
actual=value.shape,
expected=value.ndim,
Expand Down Expand Up @@ -286,6 +288,7 @@ def Data(
coords: dict[str, Sequence | np.ndarray] | None = None,
infer_dims_and_coords=False,
mutable: bool | None = None,
model: Union["Model", None] = None,
**kwargs,
) -> SharedVariable | TensorConstant:
"""Create a data container that registers a data variable with the model.
Expand Down Expand Up @@ -350,15 +353,18 @@ def Data(
... model.set_data("data", data_vals)
... idatas.append(pm.sample())
"""
from pymc.model.core import modelcontext

if coords is None:
coords = {}

if isinstance(value, list):
value = np.array(value)

# Add data container to the named variables of the model.
model = pm.Model.get_context(error_if_none=False)
if model is None:
try:
model = modelcontext(model)
except TypeError:
raise TypeError(
"No model on context stack, which is needed to instantiate a data container. "
"Add variable inside a 'with model:' block."
Expand Down Expand Up @@ -390,7 +396,7 @@ def Data(
if isinstance(dims, str):
dims = (dims,)
if not (dims is None or len(dims) == x.ndim):
raise pm.exceptions.ShapeError(
raise ShapeError(
"Length of `dims` must match the dimensions of the dataset.",
actual=len(dims),
expected=x.ndim,
Expand Down
64 changes: 64 additions & 0 deletions pymc/dims/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def __init__():
"""Make PyMC aware of the xtensor functionality.

This should be done eagerly once development matures.
"""
import datetime
import warnings

from pytensor.compile import optdb

from pymc.initial_point import initial_point_rewrites_db
from pymc.logprob.abstract import MeasurableOp
from pymc.logprob.rewriting import logprob_rewrites_db

# Filter PyTensor xtensor warning, we emit our own warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import pytensor.xtensor

from pytensor.xtensor.vectorization import XRV

# Make PyMC aware of xtensor functionality
MeasurableOp.register(XRV)
lower_xtensor_query = optdb.query("+lower_xtensor")
logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)

# TODO: Better model of probability of bugs
day_of_conception = datetime.date(2025, 6, 17)
day_of_last_bug = datetime.date(2025, 6, 17)
today = datetime.date.today()
days_with_bugs = (day_of_last_bug - day_of_conception).days
Copy link
Member

@twiecki twiecki Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wtf 😆

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has two purposes: distract reviewers so they don't focus on the critical changes, and prove that OSS libraries can't be fun.

days_without_bugs = (today - day_of_last_bug).days
p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
if p > 0.05:
warnings.warn(
f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
"Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues",
UserWarning,
stacklevel=2,
)


# Run the registration helper once at import time, then delete it so the
# name `__init__` does not remain in the module namespace.
__init__()
del __init__

from pymc.dims import math
from pymc.dims.distributions import *
from pymc.dims.model import Data, with_dims
180 changes: 180 additions & 0 deletions pymc/dims/distribution_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Sequence
from itertools import chain

from pytensor.tensor.elemwise import DimShuffle
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.type import XTensorVariable

from pymc import modelcontext
from pymc.dims.model import with_dims
from pymc.distributions import transforms
from pymc.distributions.distribution import _support_point, support_point
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims
from pymc.util import UNSET


@_support_point.register(DimShuffle)
def dimshuffle_support_point(ds_op, _, rv):
    """Return the support point of ``rv`` with the DimShuffle ``ds_op`` applied to it.

    Support point is implemented for DimShuffle because DimDistribution can
    register a transposed version of a variable.
    """
    inner_support_point = support_point(rv)
    return ds_op(inner_support_point)


class DimDistribution:
    """Base class for PyMC distributions that wrap pytensor.xtensor.random operations.

    Variables created through this class carry named dims and follow
    xarray-like semantics.
    """

    # Callable invoked by `dist` with the converted parameters plus
    # `extra_dims` and `core_dims` to build the random variable.
    xrv_op: Callable
    # Transform picked up by `__new__` for unobserved variables when the
    # `default_transform` argument is left as UNSET.
    default_transform: Callable | None

    @staticmethod
    def _as_xtensor(x):
        """Convert ``x`` to an xtensor variable.

        ``as_xtensor`` is tried first; if it raises ``TypeError`` the
        conversion falls back to ``with_dims``.

        Raises
        ------
        ValueError
            If ``with_dims`` raises ``ValueError`` for ``x``.
        """
        try:
            return as_xtensor(x)
        except TypeError:
            try:
                return with_dims(x)
            except ValueError:
                message = (
                    f"Variable {x} must have dims associated with it.\n"
                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters.\n"
                    "Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
                )
                raise ValueError(message)

    def __new__(
        cls,
        name: str,
        *dist_params,
        dims: DimsWithEllipsis | None = None,
        initval=None,
        observed=None,
        total_size=None,
        transform=UNSET,
        default_transform=UNSET,
        model=None,
        **kwargs,
    ) -> XTensorVariable:
        """Create a random variable, register it in the model and return it as an xtensor.

        Parameters
        ----------
        name
            Name under which the variable is registered. Must be a string.
        *dist_params
            Positional arguments forwarded to ``cls.dist``.
        dims
            Dims of the output. Must either name every output dim or contain
            an ellipsis; the output is transposed to this order.
        initval, total_size, transform
            Forwarded unchanged to ``model.register_rv``.
        observed
            Observed data. It is converted with ``_as_xtensor``; its dims are
            added to the requested dims and it is transposed to match the
            random variable before registration.
        default_transform
            Forwarded to ``model.register_rv``. When left as UNSET and there
            is no observed data, ``cls.default_transform`` is used instead.
        model
            Model to register the variable in; resolved through ``modelcontext``.
        **kwargs
            Extra keyword arguments forwarded to ``cls.dist``.

        Raises
        ------
        TypeError
            If no model can be resolved or ``name`` is not a string.
        ValueError
            If a requested dim is not among the model's dim lengths, or if
            ``dims`` has no ellipsis and does not match the output dims.
        """
        try:
            model = modelcontext(model)
        except TypeError:
            raise TypeError(
                "No model on context stack, which is needed to instantiate distributions. "
                "Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
            )

        if not isinstance(name, str):
            raise TypeError(f"Name needs to be a string but got: {name}")

        # Map every explicitly requested dim (ellipsis excluded) to its length in the model.
        dims_dict = {}
        if dims is not None:
            dims = convert_dims(dims)
            try:
                for dim in dims:
                    if dim is not Ellipsis:
                        dims_dict[dim] = model.dim_lengths[dim]
            except KeyError:
                raise ValueError(
                    f"Not all dims {dims} are part of the model coords. "
                    f"Add them at initialization time or use `model.add_coord` before defining the distribution."
                )

        if observed is not None:
            observed = cls._as_xtensor(observed)

            # Propagate observed dims to dims_dict
            for observed_dim in observed.type.dims:
                if observed_dim in dims_dict:
                    continue
                dims_dict[observed_dim] = model.dim_lengths[observed_dim]

        rv = cls.dist(*dist_params, dims_dict=dims_dict, **kwargs)

        # User provided dims must specify all dims or use ellipsis
        if dims is not None:
            has_ellipsis = ... in dims
            if not has_ellipsis and set(dims) != set(rv.type.dims):
                raise ValueError(
                    f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
                    "Use ellipsis to specify all other dimensions."
                )
            # Use provided dims to transpose the output to the desired order
            rv = rv.transpose(*dims)

        rv_dims = rv.type.dims
        if observed is not None:
            # Align observed dims with those of the RV
            observed = observed.transpose(*rv_dims).values
        elif default_transform is UNSET:
            default_transform = cls.default_transform

        registered = model.register_rv(
            rv.values,
            name=name,
            observed=observed,
            total_size=total_size,
            dims=rv_dims,
            transform=transform,
            default_transform=default_transform,
            initval=initval,
        )

        # Wrap the registered variable again so callers get it back with its dims.
        return as_xtensor(registered, dims=rv_dims)

    @classmethod
    def dist(
        cls,
        dist_params,
        *,
        dims_dict: dict[str, int] | None = None,
        core_dims: str | Sequence[str] | None = None,
        **kwargs,
    ) -> XTensorVariable:
        """Build the random variable without registering it in a model.

        Parameters
        ----------
        dist_params
            Iterable of distribution parameters; each is converted with
            ``_as_xtensor``.
        dims_dict
            Mapping from dim name to length. Entries whose dim already
            appears on one of the parameters are dropped; the rest are passed
            to ``xrv_op`` as ``extra_dims``. ``None`` passes ``extra_dims=None``.
        core_dims
            Forwarded unchanged to ``xrv_op``.
        **kwargs
            Forwarded to ``xrv_op``. ``size``, ``shape`` and ``dims`` are rejected.

        Raises
        ------
        TypeError
            If ``size``, ``shape`` or ``dims`` is present in ``kwargs``.
        """
        rejected = next((key for key in ("size", "shape", "dims") if key in kwargs), None)
        if rejected is not None:
            raise TypeError(f"DimDistribution does not accept {rejected} argument.")

        # XRV requires only extra_dims, not dims
        xparams = [cls._as_xtensor(param) for param in dist_params]

        extra_dims = None
        if dims_dict is not None:
            # Collect every dim already provided by one of the parameters.
            implied_dims = set()
            for xparam in xparams:
                implied_dims.update(xparam.type.dims)

            extra_dims = {}
            for dim, length in dims_dict.items():
                if dim not in implied_dims:
                    extra_dims[dim] = length

        return cls.xrv_op(*xparams, extra_dims=extra_dims, core_dims=core_dims, **kwargs)


class ContinuousDimDistribution(DimDistribution):
    """Base class for real-valued dim distributions.

    Sets ``default_transform`` to ``None``.
    """

    default_transform = None


class PositiveContinuousDimDistribution(DimDistribution):
    """Base class for positive continuous dim distributions.

    Sets ``default_transform`` to ``transforms.log``.
    """

    default_transform = transforms.log


class UnitDimDistribution(DimDistribution):
    """Base class for unit-valued dim distributions.

    Sets ``default_transform`` to ``transforms.logodds``.
    """

    default_transform = transforms.logodds
Loading