Skip to content

Commit

Permalink
feat: added arithmetic expression support, closes georgia-tech-db#1093
Browse files Browse the repository at this point in the history
  • Loading branch information
aayushacharya committed Feb 7, 2024
1 parent e5a9190 commit 287ca98
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 4 deletions.
50 changes: 47 additions & 3 deletions evadb/expression/arithmetic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,57 @@ def __init__(
super().__init__(exp_type, rtype=ExpressionReturnType.FLOAT, children=children)

def evaluate(self, *args, **kwargs):
vl = self.get_child(0).evaluate(*args, **kwargs)
vr = self.get_child(1).evaluate(*args, **kwargs)
lbatch = self.get_child(0).evaluate(*args, **kwargs)
rbatch = self.get_child(1).evaluate(*args, **kwargs)

return Batch.combine_batches(vl, vr, self.etype)
assert len(lbatch) == len(
rbatch
), f"Left and Right batch does not have equal elements: left: {len(lbatch)} right: {len(rbatch)}"

assert self.etype in [
ExpressionType.ARITHMETIC_ADD,
ExpressionType.ARITHMETIC_SUBTRACT,
ExpressionType.ARITHMETIC_DIVIDE,
ExpressionType.ARITHMETIC_MULTIPLY,
], f"Expression type not supported {self.etype}"

if self.etype == ExpressionType.ARITHMETIC_ADD:
return Batch.from_add(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT:
return Batch.from_subtract(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY:
return Batch.from_multiply(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_DIVIDE:
return Batch.from_divide(lbatch, rbatch)

return Batch.combine_batches(lbatch, rbatch, self.etype)

def get_symbol(self) -> str:
if self.etype == ExpressionType.ARITHMETIC_ADD:
return "+"
elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT:
return "-"
elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY:
return "*"
elif self.etype == ExpressionType.ARITHMETIC_DIVIDE:
return "/"

def __str__(self) -> str:
expr_str = "("
if self.get_child(0):
expr_str += f"{self.get_child(0)}"
if self.etype:
expr_str += f" {self.get_symbol()} "
if self.get_child(1):
expr_str += f"{self.get_child(1)}"
expr_str += ")"
return expr_str

def __eq__(self, other):
is_subtree_equal = super().__eq__(other)
if not isinstance(other, ArithmeticExpression):
return False
return is_subtree_equal and self.etype == other.etype

def __hash__(self) -> int:
return super().__hash__()
16 changes: 16 additions & 0 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def deserialize(cls, data):
obj = PickleSerializer.deserialize(data)
return cls(frames=obj["frames"])

@classmethod
def from_add(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() + batch2.to_numpy()))

@classmethod
def from_subtract(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() - batch2.to_numpy()))

@classmethod
def from_multiply(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() * batch2.to_numpy()))

@classmethod
def from_divide(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.Dataframe(batch1.to_numpy() / batch2.to_numpy()))

@classmethod
def from_eq(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() == batch2.to_numpy()))
Expand Down
2 changes: 1 addition & 1 deletion evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ expression_atom.2: constant ->constant_expression_atom
| "(" expression ("," expression)* ")" ->nested_expression_atom
| "(" select_statement ")" ->subquery_expession_atom
| expression_atom bit_operator expression_atom ->bit_expression_atom
| expression_atom math_operator expression_atom
| expression_atom math_operator expression_atom -> arithmetic_expression_atom

unary_operator: EXCLAMATION_SYMBOL | BIT_NOT_OP | PLUS | MINUS | NOT

Expand Down
18 changes: 18 additions & 0 deletions evadb/parser/lark_visitor/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from evadb.catalog.catalog_type import ColumnType
from evadb.expression.abstract_expression import ExpressionType
from evadb.expression.arithmetic_expression import ArithmeticExpression
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.logical_expression import LogicalExpression
Expand Down Expand Up @@ -60,6 +61,23 @@ def constant(self, tree):

return self.visit_children(tree)

def arithmetic_expression_atom(self, tree):
left = self.visit(tree.children[0])
op = self.visit(tree.children[1])
right = self.visit(tree.children[2])
return ArithmeticExpression(op, left, right)

def math_operator(self, tree):
op = str(tree.children[0])
if op == "+":
return ExpressionType.ARITHMETIC_ADD
elif op == "-":
return ExpressionType.ARITHMETIC_SUBTRACT
elif op == "*":
return ExpressionType.ARITHMETIC_MULTIPLY
elif op == "/":
return ExpressionType.ARITHMETIC_DIVIDE

def logical_expression(self, tree):
left = self.visit(tree.children[0])
op = self.visit(tree.children[1])
Expand Down

0 comments on commit 287ca98

Please # to comment.