Skip to content

Commit 8b063f9

Browse files
committed
Do not set RNG updates inplace in compile_pymc
1 parent 80f8195 commit 8b063f9

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

pymc/aesaraf.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
Apply,
4040
Constant,
4141
Variable,
42+
ancestors,
4243
clone_get_equiv,
4344
graph_inputs,
4445
walk,
@@ -963,18 +964,19 @@ def compile_pymc(
963964
this function is called within a model context and the model `check_bounds` flag
964965
is set to False.
965966
"""
966-
# Set the default update of RandomVariable's RNG so that it is automatically
967+
# Create an update mapping of RandomVariable's RNG so that it is automatically
967968
# updated after every function call
968969
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
970+
rng_updates = {}
969971
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
970972
for rv in (
971973
node
972-
for node in walk_model(output_to_list, walk_past_rvs=True)
974+
for node in ancestors(output_to_list)
973975
if node.owner and isinstance(node.owner.op, RandomVariable)
974976
):
975977
rng = rv.owner.inputs[0]
976978
if not hasattr(rng, "default_update"):
977-
rng.default_update = rv.owner.outputs[0]
979+
rng_updates[rng] = rv.owner.outputs[0]
978980

979981
# If called inside a model context, see if check_bounds flag is set to False
980982
try:
@@ -991,5 +993,11 @@ def compile_pymc(
991993
mode = get_mode(mode)
992994
opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
993995
mode = Mode(linker=mode.linker, optimizer=opt_qry)
994-
aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs)
996+
aesara_function = aesara.function(
997+
inputs,
998+
outputs,
999+
updates={**rng_updates, **kwargs.pop("updates", {})},
1000+
mode=mode,
1001+
**kwargs,
1002+
)
9951003
return aesara_function

pymc/tests/test_aesaraf.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,21 @@ def test_check_bounds_flag():
576576
assert np.all(compile_pymc([], bound)() == -np.inf)
577577

578578

579-
def test_compile_pymc_sets_default_updates():
579+
def test_compile_pymc_sets_rng_updates():
580580
rng = aesara.shared(np.random.default_rng(0))
581581
x = pm.Normal.dist(rng=rng)
582582
assert x.owner.inputs[0] is rng
583583
f = compile_pymc([], x)
584584
assert not np.isclose(f(), f())
585+
586+
# Check that update was not done inplace
587+
assert not hasattr(rng, "default_update")
588+
f = aesara.function([], x)
589+
assert f() == f()
590+
591+
592+
def test_compile_pymc_with_updates():
593+
x = aesara.shared(0)
594+
f = compile_pymc([], x, updates={x: x + 1})
595+
assert f() == 0
596+
assert f() == 1

0 commit comments

Comments
 (0)