Model with dims #7820
Draft: ricardoV94 wants to merge 1 commit into pymc-devs:main from ricardoV94:model_with_dims
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2025 - present The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def __init__(): | ||
"""Make PyMC aware of the xtensor functionality. | ||
|
||
This should be done eagerly once development matures. | ||
""" | ||
import datetime | ||
import warnings | ||
|
||
from pytensor.compile import optdb | ||
|
||
from pymc.initial_point import initial_point_rewrites_db | ||
from pymc.logprob.abstract import MeasurableOp | ||
from pymc.logprob.rewriting import logprob_rewrites_db | ||
|
||
# Filter the PyTensor xtensor warning; we emit our own warning | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore", UserWarning) | ||
import pytensor.xtensor | ||
|
||
from pytensor.xtensor.vectorization import XRV | ||
|
||
# Make PyMC aware of xtensor functionality | ||
MeasurableOp.register(XRV) | ||
lower_xtensor_query = optdb.query("+lower_xtensor") | ||
logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1) | ||
initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1) | ||
|
||
# TODO: Better model of probability of bugs | ||
day_of_conception = datetime.date(2025, 6, 17) | ||
day_of_last_bug = datetime.date(2025, 6, 17) | ||
today = datetime.date.today() | ||
days_with_bugs = (day_of_last_bug - day_of_conception).days | ||
days_without_bugs = (today - day_of_last_bug).days | ||
p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10)) | ||
if p > 0.05: | ||
warnings.warn( | ||
f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n" | ||
"Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
|
||
|
||
__init__() | ||
del __init__ | ||
|
||
from pymc.dims import math | ||
from pymc.dims.distributions import * | ||
from pymc.dims.model import Data, with_dims |
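The warning heuristic at the end of `__init__` can be isolated into a small pure function to make its behavior easier to inspect. The following is a minimal sketch, not code from this PR: the function name `bug_probability` and its `pseudo_days` parameter are hypothetical, and the dates are passed in explicitly instead of being read from `datetime.date.today()`.

```python
import datetime


def bug_probability(day_of_conception, day_of_last_bug, today, pseudo_days=10):
    """Return the heuristic p used to decide whether to warn (hypothetical helper)."""
    days_with_bugs = (day_of_last_bug - day_of_conception).days
    days_without_bugs = (today - day_of_last_bug).days
    # Same expression as in the diff, with the constant 10 exposed as `pseudo_days`
    return 1 - days_without_bugs / (days_without_bugs + days_with_bugs + pseudo_days)


# Both reference dates are 2025-06-17 in the diff
start = datetime.date(2025, 6, 17)
p_day_zero = bug_probability(start, start, start)
p_day_ten = bug_probability(start, start, start + datetime.timedelta(days=10))
p_day_990 = bug_probability(start, start, start + datetime.timedelta(days=990))
```

With both reference dates equal, p starts at 1 and decays as bug-free days accumulate, so the `p > 0.05` check stops triggering the warning roughly 190 days after the last recorded bug.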
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Copyright 2025 - present The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections.abc import Callable, Sequence | ||
from itertools import chain | ||
|
||
from pytensor.tensor.elemwise import DimShuffle | ||
from pytensor.xtensor import as_xtensor | ||
from pytensor.xtensor.type import XTensorVariable | ||
|
||
from pymc import modelcontext | ||
from pymc.dims.model import with_dims | ||
from pymc.distributions import transforms | ||
from pymc.distributions.distribution import _support_point, support_point | ||
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims | ||
from pymc.util import UNSET | ||
|
||
|
||
@_support_point.register(DimShuffle) | ||
def dimshuffle_support_point(ds_op, _, rv): | ||
# We implement support point for DimShuffle because | ||
# DimDistribution can register a transposed version of a variable. | ||
|
||
return ds_op(support_point(rv)) | ||
|
||
|
||
class DimDistribution: | ||
"""Base class for PyMC distributions that wrap pytensor.xtensor.random operations and follow xarray-like semantics.""" | ||
|
||
xrv_op: Callable | ||
default_transform: Callable | None | ||
|
||
@staticmethod | ||
def _as_xtensor(x): | ||
try: | ||
return as_xtensor(x) | ||
except TypeError: | ||
try: | ||
return with_dims(x) | ||
except ValueError: | ||
raise ValueError( | ||
f"Variable {x} must have dims associated with it.\n" | ||
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters.\n" | ||
"Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims." | ||
) | ||
|
||
def __new__( | ||
cls, | ||
name: str, | ||
*dist_params, | ||
dims: DimsWithEllipsis | None = None, | ||
initval=None, | ||
observed=None, | ||
total_size=None, | ||
transform=UNSET, | ||
default_transform=UNSET, | ||
model=None, | ||
**kwargs, | ||
) -> XTensorVariable: | ||
try: | ||
model = modelcontext(model) | ||
except TypeError: | ||
raise TypeError( | ||
"No model on context stack, which is needed to instantiate distributions. " | ||
"Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution." | ||
) | ||
|
||
if not isinstance(name, str): | ||
raise TypeError(f"Name needs to be a string but got: {name}") | ||
|
||
if dims is None: | ||
dims_dict = {} | ||
else: | ||
dims = convert_dims(dims) | ||
try: | ||
dims_dict = {dim: model.dim_lengths[dim] for dim in dims if dim is not Ellipsis} | ||
except KeyError: | ||
raise ValueError( | ||
f"Not all dims {dims} are part of the model coords. " | ||
"Add them at initialization time or use `model.add_coord` before defining the distribution." | ||
) | ||
|
||
if observed is not None: | ||
observed = cls._as_xtensor(observed) | ||
|
||
# Propagate observed dims to dims_dict | ||
for observed_dim in observed.type.dims: | ||
if observed_dim not in dims_dict: | ||
dims_dict[observed_dim] = model.dim_lengths[observed_dim] | ||
|
||
rv = cls.dist(*dist_params, dims_dict=dims_dict, **kwargs) | ||
|
||
# User provided dims must specify all dims or use ellipsis | ||
if dims is not None: | ||
if (... not in dims) and (set(dims) != set(rv.type.dims)): | ||
raise ValueError( | ||
f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. " | ||
"Use ellipsis to specify all other dimensions." | ||
) | ||
# Use provided dims to transpose the output to the desired order | ||
rv = rv.transpose(*dims) | ||
|
||
rv_dims = rv.type.dims | ||
if observed is None: | ||
if default_transform is UNSET: | ||
default_transform = cls.default_transform | ||
else: | ||
# Align observed dims with those of the RV | ||
observed = observed.transpose(*rv_dims).values | ||
|
||
rv = model.register_rv( | ||
rv.values, | ||
name=name, | ||
observed=observed, | ||
total_size=total_size, | ||
dims=rv_dims, | ||
transform=transform, | ||
default_transform=default_transform, | ||
initval=initval, | ||
) | ||
|
||
xrv = as_xtensor(rv, dims=rv_dims) | ||
return xrv | ||
|
||
@classmethod | ||
def dist( | ||
cls, | ||
dist_params, | ||
*, | ||
dims_dict: dict[str, int] | None = None, | ||
core_dims: str | Sequence[str] | None = None, | ||
**kwargs, | ||
) -> XTensorVariable: | ||
for invalid_kwarg in ("size", "shape", "dims"): | ||
if invalid_kwarg in kwargs: | ||
raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.") | ||
|
||
# XRV requires only extra_dims, not dims | ||
dist_params = [cls._as_xtensor(param) for param in dist_params] | ||
|
||
if dims_dict is None: | ||
extra_dims = None | ||
else: | ||
parameter_implied_dims = set( | ||
chain.from_iterable(param.type.dims for param in dist_params) | ||
) | ||
extra_dims = { | ||
dim: length | ||
for dim, length in dims_dict.items() | ||
if dim not in parameter_implied_dims | ||
} | ||
return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs) | ||
|
||
|
||
class ContinuousDimDistribution(DimDistribution): | ||
"""Base class for real-valued distributions.""" | ||
|
||
default_transform = None | ||
|
||
|
||
class PositiveContinuousDimDistribution(DimDistribution): | ||
"""Base class for positive continuous distributions.""" | ||
|
||
default_transform = transforms.log | ||
|
||
|
||
class UnitDimDistribution(DimDistribution): | ||
"""Base class for unit-valued distributions.""" | ||
|
||
default_transform = transforms.logodds |
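Two pieces of dims bookkeeping in `DimDistribution` can be read in isolation: `dist` passes to `xrv_op` only the dims that no parameter already carries, and `__new__` requires user-provided dims to cover every output dim unless an ellipsis is present. The sketch below mirrors that logic under simplifying assumptions: plain tuples and a dict stand in for `XTensorVariable` types and `model.dim_lengths`, and the helper names `split_extra_dims` and `check_user_dims` are hypothetical, not part of this PR.

```python
from itertools import chain


def split_extra_dims(dims_dict, param_dims):
    """Keep only the dims that no parameter already carries (hypothetical helper)."""
    implied = set(chain.from_iterable(param_dims))
    return {dim: length for dim, length in dims_dict.items() if dim not in implied}


def check_user_dims(user_dims, output_dims):
    """Without `...`, user dims must name exactly the output dims (hypothetical helper)."""
    if (... not in user_dims) and (set(user_dims) != set(output_dims)):
        raise ValueError(
            f"Provided dims {user_dims} do not match output dims {output_dims}."
        )


# One parameter with dims ("city",) and one scalar parameter with no dims:
# only "year" has to be requested as an extra dim.
extra = split_extra_dims({"city": 3, "year": 5}, [("city",), ()])

# Passes: the ellipsis stands in for the remaining "city" dim.
check_user_dims(("year", ...), ("city", "year"))
```

Calling `check_user_dims(("year",), ("city", "year"))` raises `ValueError`, which corresponds to the "Use ellipsis to specify all other dimensions" error path in the diff.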
wtf 😆
This has two purposes: distract reviewers so they don't focus on the critical changes, and prove that OSS libraries can't be fun.