Skip to content

Commit 9008a32

Browse files
committed
Modified the code structure based on suggestions
1 parent 6e2efaa commit 9008a32

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

Diff for: pytensor/graph/basic.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -605,13 +605,12 @@ def convert_string_keys_to_variables(input_to_values):
605605
if not matching_vars:
606606
raise Exception(f"{key} not found in graph")
607607
elif len(matching_vars) > 1:
608-
raise Exception(
609-
f"Found multiple variables with name {key}"
610-
)
608+
raise Exception(f"Found multiple variables with name {key}")
611609
new_input_to_values[matching_vars[0]] = value
612610
else:
613611
new_input_to_values[key] = value
614612
return new_input_to_values
613+
615614
inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
616615

617616
if not hasattr(self, "_fn_cache"):

Diff for: tests/graph/test_basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,14 @@ def test_eval_with_strings(self):
306306
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0
307307
assert self.w.eval({self.z: 3}) == 6.0
308308

309-
def test_eval_errors_having_mulitple_variables_same_name(self):
309+
def test_eval_with_strings_multiple_matches(self):
310310
e = scalars("e")
311311
t = e + 1
312312
t.name = "e"
313313
with pytest.raises(Exception, match="Found multiple variables with name e"):
314314
t.eval({"e": 1})
315315

316-
def test_eval_errors_with_no_name_exists(self):
316+
def test_eval_with_strings_no_match(self):
317317
e = scalars("e")
318318
t = e + 1
319319
t.name = "p"

0 commit comments

Comments
 (0)