diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 091b045109..236ec93ed0 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -597,6 +597,22 @@ def eval(self, inputs_to_values=None): if inputs_to_values is None: inputs_to_values = {} + def convert_string_keys_to_variables(input_to_values): + new_input_to_values = {} + for key, value in inputs_to_values.items(): + if isinstance(key, str): + matching_vars = get_var_by_name([self], key) + if not matching_vars: + raise Exception(f"{key} not found in graph") + elif len(matching_vars) > 1: + raise Exception(f"Found multiple variables with name {key}") + new_input_to_values[matching_vars[0]] = value + else: + new_input_to_values[key] = value + return new_input_to_values + + inputs_to_values = convert_string_keys_to_variables(inputs_to_values) + if not hasattr(self, "_fn_cache"): self._fn_cache = dict() diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index a4779c8299..935301be05 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -302,6 +302,24 @@ def test_eval(self): pickle.loads(pickle.dumps(self.w)), "_fn_cache" ), "temporary functions must not be serialized" + def test_eval_with_strings(self): + assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0 + assert self.w.eval({self.z: 3}) == 6.0 + + def test_eval_with_strings_multiple_matches(self): + e = scalars("e") + t = e + 1 + t.name = "e" + with pytest.raises(Exception, match="Found multiple variables with name e"): + t.eval({"e": 1}) + + def test_eval_with_strings_no_match(self): + e = scalars("e") + t = e + 1 + t.name = "p" + with pytest.raises(Exception, match="o not found in graph"): + t.eval({"o": 1}) + class TestAutoName: def test_auto_name(self):