39
39
Apply ,
40
40
Constant ,
41
41
Variable ,
42
+ ancestors ,
42
43
clone_get_equiv ,
43
44
graph_inputs ,
44
45
walk ,
@@ -963,18 +964,19 @@ def compile_pymc(
963
964
this function is called within a model context and the model `check_bounds` flag
964
965
is set to False.
965
966
"""
966
- # Set the default update of RandomVariable's RNG so that it is automatically
967
+ # Create an update mapping of RandomVariable's RNG so that it is automatically
967
968
# updated after every function call
968
969
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
970
+ rng_updates = {}
969
971
output_to_list = outputs if isinstance (outputs , (list , tuple )) else [outputs ]
970
972
for rv in (
971
973
node
972
- for node in walk_model (output_to_list , walk_past_rvs = True )
974
+ for node in ancestors (output_to_list )
973
975
if node .owner and isinstance (node .owner .op , RandomVariable )
974
976
):
975
977
rng = rv .owner .inputs [0 ]
976
978
if not hasattr (rng , "default_update" ):
977
- rng . default_update = rv .owner .outputs [0 ]
979
+ rng_updates [ rng ] = rv .owner .outputs [0 ]
978
980
979
981
# If called inside a model context, see if check_bounds flag is set to False
980
982
try :
@@ -991,5 +993,11 @@ def compile_pymc(
991
993
mode = get_mode (mode )
992
994
opt_qry = mode .provided_optimizer .including ("random_make_inplace" , check_parameter_opt )
993
995
mode = Mode (linker = mode .linker , optimizer = opt_qry )
994
- aesara_function = aesara .function (inputs , outputs , mode = mode , ** kwargs )
996
+ aesara_function = aesara .function (
997
+ inputs ,
998
+ outputs ,
999
+ updates = {** rng_updates , ** kwargs .pop ("updates" , {})},
1000
+ mode = mode ,
1001
+ ** kwargs ,
1002
+ )
995
1003
return aesara_function
0 commit comments