Skip to content

Commit 48870ad

Browse files
committed
Rename Model initial_values to rvs_to_initial_values
1 parent fad8864 commit 48870ad

File tree

4 files changed

+14
-11
lines changed

4 files changed

+14
-11
lines changed

pymc/initial_point.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def make_initial_point_fn(
133133

134134
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
135135
initval_strats = {
136-
**model.initial_values,
136+
**model.rvs_to_initial_values,
137137
**sdict_overrides,
138138
}
139139

pymc/model.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -550,15 +550,14 @@ def __init__(
550550
self.name = self._validate_name(name)
551551
self.check_bounds = check_bounds
552552

553-
self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {}
554-
555553
if self.parent is not None:
556554
self.named_vars = treedict(parent=self.parent.named_vars)
557555
self.named_vars_to_dims = treedict(parent=self.parent.named_vars_to_dims)
558556
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
559557
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
560558
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
561559
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
560+
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
562561
self.free_RVs = treelist(parent=self.parent.free_RVs)
563562
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
564563
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
@@ -573,6 +572,7 @@ def __init__(
573572
self.rvs_to_values = treedict()
574573
self.rvs_to_transforms = treedict()
575574
self.rvs_to_total_sizes = treedict()
575+
self.rvs_to_initial_values = treedict()
576576
self.free_RVs = treelist()
577577
self.observed_RVs = treelist()
578578
self.auto_deterministics = treelist()
@@ -1128,15 +1128,18 @@ def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Vari
11281128
Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and
11291129
values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None.
11301130
"""
1131-
return self._initial_values
1131+
warnings.warn(
1132+
"Model.initial_values is deprecated. Use Model.rvs_to_initial_values instead."
1133+
)
1134+
return self.rvs_to_initial_values
11321135

11331136
def set_initval(self, rv_var, initval):
11341137
"""Sets an initial value (strategy) for a random variable."""
11351138
if initval is not None and not isinstance(initval, (Variable, str)):
11361139
# Convert scalars or array-like inputs to ndarrays
11371140
initval = rv_var.type.filter(initval)
11381141

1139-
self.initial_values[rv_var] = initval
1142+
self.rvs_to_initial_values[rv_var] = initval
11401143

11411144
def set_data(
11421145
self,

pymc/tests/test_initial_point.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_dependent_initvals(self):
8686
assert ip["B2_interval__"] == 0
8787

8888
# Modify initval of L and re-evaluate
89-
pmodel.initial_values[U] = 9.9
89+
pmodel.rvs_to_initial_values[U] = 9.9
9090
ip = pmodel.initial_point(random_seed=0)
9191
assert ip["B1_interval__"] < 0
9292
assert ip["B2_interval__"] == 0
@@ -108,7 +108,7 @@ def test_nested_initvals(self):
108108
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
109109
assert np.allclose(ip_vals, [1, 2, 4, 8, 16, 32], rtol=1e-3)
110110

111-
pmodel.initial_values[four] = 1
111+
pmodel.rvs_to_initial_values[four] = 1
112112

113113
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
114114
assert np.allclose(np.exp(ip_vals), [1, 2, 4, 1, 2, 4], rtol=1e-3)

pymc/tests/test_model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -586,11 +586,11 @@ def test_initial_point():
586586
with model:
587587
y = pm.Normal("y", initval=y_initval)
588588

589-
assert a in model.initial_values
590-
assert x in model.initial_values
591-
assert model.initial_values[b] == b_initval
589+
assert a in model.rvs_to_initial_values
590+
assert x in model.rvs_to_initial_values
591+
assert model.rvs_to_initial_values[b] == b_initval
592592
assert model.initial_point(0)["b_interval__"] == b_initval_trans
593-
assert model.initial_values[y] == y_initval
593+
assert model.rvs_to_initial_values[y] == y_initval
594594

595595

596596
def test_point_logps():

0 commit comments

Comments
 (0)