Skip to content

Allow string keys in eval utility #242

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 7 commits into from
Mar 22, 2023
16 changes: 16 additions & 0 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,22 @@ def eval(self, inputs_to_values=None):
if inputs_to_values is None:
inputs_to_values = {}

def convert_string_keys_to_variables(input_to_values):
new_input_to_values = {}
for key, value in inputs_to_values.items():
if isinstance(key, str):
matching_vars = get_var_by_name([self], key)
if not matching_vars:
raise Exception(f"{key} not found in graph")
elif len(matching_vars) > 1:
raise Exception(f"Found multiple variables with name {key}")
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[key] = value
return new_input_to_values

inputs_to_values = convert_string_keys_to_variables(inputs_to_values)

if not hasattr(self, "_fn_cache"):
self._fn_cache = dict()

Expand Down
18 changes: 18 additions & 0 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,24 @@ def test_eval(self):
pickle.loads(pickle.dumps(self.w)), "_fn_cache"
), "temporary functions must not be serialized"

def test_eval_with_strings(self):
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0
assert self.w.eval({self.z: 3}) == 6.0

def test_eval_with_strings_multiple_matches(self):
e = scalars("e")
t = e + 1
t.name = "e"
with pytest.raises(Exception, match="Found multiple variables with name e"):
t.eval({"e": 1})

def test_eval_with_strings_no_match(self):
e = scalars("e")
t = e + 1
t.name = "p"
with pytest.raises(Exception, match="o not found in graph"):
t.eval({"o": 1})


class TestAutoName:
def test_auto_name(self):
Expand Down