diff --git a/dasy/parser/comparisons.py b/dasy/parser/comparisons.py index 0052a40..b6d877b 100644 --- a/dasy/parser/comparisons.py +++ b/dasy/parser/comparisons.py @@ -1,28 +1,34 @@ +from typing import List, Union from hy import models from dasy import parser from .utils import next_nodeid import vyper.ast.nodes as vy_nodes +from hy.models import Expression, Symbol COMP_FUNCS = ["<", "<=", ">", ">=", "==", "!="] - -def chain_comps(expr): +def chain_comps(chain_expr: Expression) -> Expression: + """ + Creates a new expression chaining comparisons. + """ new_node = models.Expression() - new_expr = [models.Symbol("and")] - for vals in zip(expr[1:], expr[2:]): - new_expr.append(models.Expression((expr[0], vals[0], vals[1]))) + new_expr: List[Union[Symbol, Expression]] = [models.Symbol("and")] + for vals in zip(chain_expr[1:], chain_expr[2:]): + new_expr.append(models.Expression((chain_expr[0], vals[0], vals[1]))) new_node += tuple(new_expr) return new_node +def parse_comparison(comparison_expr: Expression) -> vy_nodes.Compare: + """ + Parses a comparison expression, chaining comparisons if necessary. + """ + assert str(comparison_expr[0]) in COMP_FUNCS, f"Invalid comparison operator {comparison_expr[0]}" -def parse_comparison(comp_tree): - if ( - len(comp_tree[1:]) > 2 - ): # comparing more than 2 things; chain comps for (< 2 3 4 ) - return parser.parse_node(chain_comps(comp_tree)) - left = parser.parse_node(comp_tree[1]) - right = parser.parse_node(comp_tree[2]) - op = parser.parse_node(comp_tree[0]) + # Always apply chain comps for consistency + chained_expr = chain_comps(comparison_expr) + left = parser.parse_node(chained_expr[1]) + right = parser.parse_node(chained_expr[2]) + op = parser.parse_node(chained_expr[0]) return vy_nodes.Compare( left=left, ops=[op],