Skip to content

Commit fad8864

Browse files
committed
Rename add_random_variable to add_named_variable and _RV_dims to named_vars_to_dims
1 parent ba41e95 commit fad8864

File tree

11 files changed

+50
-50
lines changed

11 files changed

+50
-50
lines changed

Diff for: pymc/backends/arviz.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,11 @@ def __init__(
215215
}
216216

217217
self.dims = {} if dims is None else dims
218-
if hasattr(self.model, "RV_dims"):
219-
model_dims = {
220-
var_name: [dim for dim in dims if dim is not None]
221-
for var_name, dims in self.model.RV_dims.items()
222-
}
223-
self.dims = {**model_dims, **self.dims}
218+
model_dims = {
219+
var_name: [dim for dim in dims if dim is not None]
220+
for var_name, dims in self.model.named_vars_to_dims.items()
221+
}
222+
self.dims = {**model_dims, **self.dims}
224223
if sample_dims is None:
225224
sample_dims = ["chain", "draw"]
226225
self.sample_dims = sample_dims

Diff for: pymc/data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,6 @@ def Data(
718718
length=xshape[d],
719719
)
720720

721-
model.add_random_variable(x, dims=dims)
721+
model.add_named_variable(x, dims=dims)
722722

723723
return x

Diff for: pymc/model.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ def __init__(
554554

555555
if self.parent is not None:
556556
self.named_vars = treedict(parent=self.parent.named_vars)
557+
self.named_vars_to_dims = treedict(parent=self.parent.named_vars_to_dims)
557558
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
558559
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
559560
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
@@ -564,10 +565,10 @@ def __init__(
564565
self.deterministics = treelist(parent=self.parent.deterministics)
565566
self.potentials = treelist(parent=self.parent.potentials)
566567
self._coords = self.parent._coords
567-
self._RV_dims = treedict(parent=self.parent._RV_dims)
568568
self._dim_lengths = self.parent._dim_lengths
569569
else:
570570
self.named_vars = treedict()
571+
self.named_vars_to_dims = treedict()
571572
self.values_to_rvs = treedict()
572573
self.rvs_to_values = treedict()
573574
self.rvs_to_transforms = treedict()
@@ -578,7 +579,6 @@ def __init__(
578579
self.deterministics = treelist()
579580
self.potentials = treelist()
580581
self._coords = {}
581-
self._RV_dims = treedict()
582582
self._dim_lengths = {}
583583
self.add_coords(coords)
584584

@@ -972,7 +972,11 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
972972
973973
Entries in the tuples may be ``None``, if the RV dimension was not given a name.
974974
"""
975-
return self._RV_dims
975+
warnings.warn(
976+
"Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
977+
FutureWarning,
978+
)
979+
return self.named_vars_to_dims
976980

977981
@property
978982
def coords(self) -> Dict[str, Union[Tuple, None]]:
@@ -1167,7 +1171,7 @@ def set_data(
11671171
if isinstance(values, list):
11681172
values = np.array(values)
11691173
values = convert_observed_data(values)
1170-
dims = self.RV_dims.get(name, None) or ()
1174+
dims = self.named_vars_to_dims.get(name, None) or ()
11711175
coords = coords or {}
11721176

11731177
if values.ndim != shared_object.ndim:
@@ -1297,7 +1301,7 @@ def register_rv(
12971301
if observed is None:
12981302
self.free_RVs.append(rv_var)
12991303
self.create_value_var(rv_var, transform)
1300-
self.add_random_variable(rv_var, dims)
1304+
self.add_named_variable(rv_var, dims)
13011305
self.set_initval(rv_var, initval)
13021306
else:
13031307
if (
@@ -1424,7 +1428,7 @@ def make_obs_var(
14241428
observed_rv_var.tag.observations = nonmissing_data
14251429

14261430
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
1427-
self.add_random_variable(observed_rv_var)
1431+
self.add_named_variable(observed_rv_var)
14281432
self.observed_RVs.append(observed_rv_var)
14291433

14301434
# Create deterministic that combines observed and missing
@@ -1440,7 +1444,7 @@ def make_obs_var(
14401444
data = at.as_tensor_variable(data, name=name)
14411445
rv_var.tag.observations = data
14421446
self.create_value_var(rv_var, transform=None, value_var=data)
1443-
self.add_random_variable(rv_var, dims)
1447+
self.add_named_variable(rv_var, dims)
14441448
self.observed_RVs.append(rv_var)
14451449

14461450
return rv_var
@@ -1486,8 +1490,12 @@ def create_value_var(
14861490

14871491
return value_var
14881492

1489-
def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
1490-
"""Add a random variable to the named variables of the model."""
1493+
def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
1494+
"""Add a random graph variable to the named variables of the model.
1495+
1496+
This can include several types of variables such basic_RVs, Data, Deterministics,
1497+
and Potentials.
1498+
"""
14911499
if self.named_vars.tree_contains(var.name):
14921500
raise ValueError(f"Variable name {var.name} already exists.")
14931501

@@ -1499,7 +1507,7 @@ def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]]
14991507
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
15001508
if any(var.name == dim for dim in dims):
15011509
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
1502-
self._RV_dims[var.name] = dims
1510+
self.named_vars_to_dims[var.name] = dims
15031511

15041512
self.named_vars[var.name] = var
15051513
if not hasattr(self, self.name_of(var.name)):
@@ -1967,7 +1975,7 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
19671975
model.auto_deterministics.append(var)
19681976
else:
19691977
model.deterministics.append(var)
1970-
model.add_random_variable(var, dims)
1978+
model.add_named_variable(var, dims)
19711979

19721980
from pymc.printing import str_for_potential_or_deterministic
19731981

@@ -1999,7 +2007,7 @@ def Potential(name, var, model=None):
19992007
model = modelcontext(model)
20002008
var.name = model.name_for(name)
20012009
model.potentials.append(var)
2002-
model.add_random_variable(var)
2010+
model.add_named_variable(var)
20032011

20042012
from pymc.printing import str_for_potential_or_deterministic
20052013

Diff for: pymc/model_graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
200200

201201
for var_name in self.vars_to_plot(var_names):
202202
v = self.model[var_name]
203-
if var_name in self.model.RV_dims:
203+
if var_name in self.model.named_vars_to_dims:
204204
plate_label = " x ".join(
205205
f"{d} ({self._eval(self.model.dim_lengths[d])})"
206-
for d in self.model.RV_dims[var_name]
206+
for d in self.model.named_vars_to_dims[var_name]
207207
)
208208
else:
209209
plate_label = " x ".join(map(str, self._eval(v.shape)))

Diff for: pymc/sampling/forward.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,10 @@ def sample_prior_predictive(
371371
)
372372

373373
if var_names is None:
374-
prior_pred_vars = model.observed_RVs + model.auto_deterministics
375-
prior_vars = (
376-
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
377-
)
378-
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
374+
vars_: Set[str] = {
375+
var.name
376+
for var in model.basic_RVs + model.deterministics + model.auto_deterministics
377+
}
379378
else:
380379
vars_ = set(var_names)
381380

Diff for: pymc/sampling/jax.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -343,13 +343,10 @@ def sample_blackjax_nuts(
343343
if cvals is not None
344344
}
345345

346-
if hasattr(model, "RV_dims"):
347-
dims = {
348-
var_name: [dim for dim in dims if dim is not None]
349-
for var_name, dims in model.RV_dims.items()
350-
}
351-
else:
352-
dims = {}
346+
dims = {
347+
var_name: [dim for dim in dims if dim is not None]
348+
for var_name, dims in model.named_vars_to_dims.items()
349+
}
353350

354351
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
355352

@@ -559,13 +556,10 @@ def sample_numpyro_nuts(
559556
if cvals is not None
560557
}
561558

562-
if hasattr(model, "RV_dims"):
563-
dims = {
564-
var_name: [dim for dim in dims if dim is not None]
565-
for var_name, dims in model.RV_dims.items()
566-
}
567-
else:
568-
dims = {}
559+
dims = {
560+
var_name: [dim for dim in dims if dim is not None]
561+
for var_name, dims in model.named_vars_to_dims.items()
562+
}
569563

570564
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
571565

Diff for: pymc/tests/distributions/test_shape_utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def test_simultaneous_shape_and_dims(self):
288288
# The shape and dims tuples correspond to each other.
289289
# Note: No checks are performed that implied shape (x), shape and dims actually match.
290290
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
291-
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
291+
assert pmodel.named_vars_to_dims["y"] == ("dshape", "ddata")
292292

293293
assert "dshape" in pmodel.dim_lengths
294294
assert y.eval().shape == (2, 3)
@@ -301,7 +301,7 @@ def test_simultaneous_size_and_dims(self):
301301
# Size does not include support dims, so this test must use a dist with support dims.
302302
kwargs = dict(name="y", size=(2, 3), mu=at.ones((3, 4)), cov=at.eye(4))
303303
y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport"))
304-
assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport")
304+
assert pmodel.named_vars_to_dims["y"] == ("dsize", "ddata", "dsupport")
305305

306306
assert "dsize" in pmodel.dim_lengths
307307
assert y.eval().shape == (2, 3, 4)
@@ -313,7 +313,7 @@ def test_simultaneous_dims_and_observed(self):
313313

314314
# Note: No checks are performed that observed and dims actually match.
315315
y = pm.Normal("y", observed=[0, 0, 0], dims="ddata")
316-
assert pmodel.RV_dims["y"] == ("ddata",)
316+
assert pmodel.named_vars_to_dims["y"] == ("ddata",)
317317
assert y.eval().shape == (3,)
318318

319319
def test_define_dims_on_the_fly_raises(self):

Diff for: pymc/tests/sampling/test_jax.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_idata_kwargs(
235235

236236
posterior = idata.get("posterior")
237237
assert posterior is not None
238-
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.RV_dims)["x"][0]
238+
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.named_vars_to_dims)["x"][0]
239239
assert x_dim_expected is not None
240240
assert posterior["x"].dims[-1] == x_dim_expected
241241

Diff for: pymc/tests/test_data.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_explicit_coords(self):
331331
assert pmodel.dim_lengths["rows"].eval() == 5
332332
assert "columns" in pmodel.coords
333333
assert pmodel.coords["columns"] == ("C1", "C2", "C3", "C4", "C5", "C6", "C7")
334-
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
334+
assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")}
335335
assert "columns" in pmodel.dim_lengths
336336
assert pmodel.dim_lengths["columns"].eval() == 7
337337

@@ -382,7 +382,7 @@ def test_implicit_coords_series(self):
382382

383383
assert "date" in pmodel.coords
384384
assert len(pmodel.coords["date"]) == 22
385-
assert pmodel.RV_dims == {"sales": ("date",)}
385+
assert pmodel.named_vars_to_dims == {"sales": ("date",)}
386386

387387
def test_implicit_coords_dataframe(self):
388388
pd = pytest.importorskip("pandas")
@@ -402,7 +402,7 @@ def test_implicit_coords_dataframe(self):
402402

403403
assert "rows" in pmodel.coords
404404
assert "columns" in pmodel.coords
405-
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
405+
assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")}
406406

407407
def test_data_kwargs(self):
408408
strict_value = True

Diff for: pymc/tests/test_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def test_nested_model_coords():
726726
e = pm.Normal("e", a[None] + d[:, None], dims=("dim2", "dim1"))
727727
assert m1.coords is m2.coords
728728
assert m1.dim_lengths is m2.dim_lengths
729-
assert set(m2.RV_dims) < set(m1.RV_dims)
729+
assert set(m2.named_vars_to_dims) < set(m1.named_vars_to_dims)
730730

731731

732732
def test_shapeerror_from_set_data_dimensionality():
@@ -1378,7 +1378,7 @@ def test_dims(self):
13781378
with pm.Model(coords={"observed": range(10)}) as model:
13791379
with pytest.warns(ImputationWarning):
13801380
x = pm.Normal("x", observed=data, dims=("observed",))
1381-
assert model.RV_dims == {"x": ("observed",)}
1381+
assert model.named_vars_to_dims == {"x": ("observed",)}
13821382

13831383
def test_error_non_random_variable(self):
13841384
data = np.array([np.nan] * 3 + [0] * 7)

Diff for: pymc/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(self, iterable=(), parent=None, **kwargs):
120120
update = withparent(dict.update)
121121

122122
def tree_contains(self, item):
123-
# needed for `add_random_variable` method
123+
# needed for `add_named_variable` method
124124
if isinstance(self.parent, treedict):
125125
return dict.__contains__(self, item) or self.parent.tree_contains(item)
126126
elif isinstance(self.parent, dict):

0 commit comments

Comments
 (0)