Skip to content

Commit 389f818

Browse files
michaelosthegetwiecki
authored andcommitted
Add unit tests for initval evaluation
1 parent 00ffb3b commit 389f818

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

pymc3/tests/test_initvals.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
1415
import pytest
1516

1617
import pymc3 as pm
@@ -39,6 +40,47 @@ def test_new_warnings(self):
3940
pass
4041

4142

43+
class TestInitvalEvaluation:
44+
def test_random_draws(self):
45+
pmodel = pm.Model()
46+
rv = pm.Uniform.dist(lower=1, upper=2)
47+
iv = pmodel._eval_initval(
48+
rv_var=rv,
49+
initval=None,
50+
test_value=None,
51+
transform=None,
52+
)
53+
assert isinstance(iv, np.ndarray)
54+
assert 1 <= iv <= 2
55+
pass
56+
57+
def test_applies_transform(self):
58+
pmodel = pm.Model()
59+
rv = pm.Uniform.dist()
60+
tf = pm.Uniform.default_transform()
61+
iv = pmodel._eval_initval(
62+
rv_var=rv,
63+
initval=0.5,
64+
test_value=None,
65+
transform=tf,
66+
)
67+
assert isinstance(iv, np.ndarray)
68+
assert iv == 0
69+
pass
70+
71+
def test_falls_back_to_test_value(self):
72+
pmodel = pm.Model()
73+
rv = pm.Flat.dist()
74+
iv = pmodel._eval_initval(
75+
rv_var=rv,
76+
initval=None,
77+
test_value=0.6,
78+
transform=None,
79+
)
80+
assert iv == 0.6
81+
pass
82+
83+
4284
class TestSpecialDistributions:
4385
def test_automatically_assigned_test_values(self):
4486
# ...because they don't have random number generators.

0 commit comments

Comments
 (0)