Skip to content

Commit

Permalink
Merge pull request #63 from dssg/ep/aggregate_arithmetic
Browse files Browse the repository at this point in the history
Ep/aggregate arithmetic
  • Loading branch information
mbauman authored Mar 14, 2017
2 parents 23361d8 + 80b8b63 commit 4ae3990
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 8 deletions.
91 changes: 88 additions & 3 deletions collate/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,92 @@ def split_distinct(quantity):
return "", (q,)


class Aggregate(object):
class AggregateExpression(object):
    """
    A binary SQL expression combining the columns of two aggregates.

    Instances are normally created through the overloaded Python operators
    defined below (e.g. ``a / b + c``) rather than by calling the
    constructor directly.
    """

    def __init__(self, aggregate1, aggregate2, operator,
                 cast=None, operator_str=None, expression_template=None):
        """
        Args:
            aggregate1: first aggregate
            aggregate2: second aggregate
            operator: string of SQL operator, e.g. "+"
            cast: optional string to put after aggregate1, e.g. "*1.0", "::decimal"
            operator_str: optional name of operator to use, defaults to operator
            expression_template: optional formatting template with the following keywords:
                name1, operator, name2
        """
        self.aggregate1 = aggregate1
        self.aggregate2 = aggregate2
        self.operator = operator
        self.cast = cast if cast else ""
        # Use the caller-supplied display name when one is given; only fall
        # back to the raw SQL operator when it is missing/empty.
        self.operator_str = operator_str if operator_str else operator
        self.expression_template = expression_template \
            if expression_template else "{name1}{operator}{name2}"

    def alias(self, expression_template):
        """
        Set the expression template used for naming columns of an AggregateExpression
        Returns: self, for chaining
        """
        self.expression_template = expression_template
        return self

    def get_columns(self, when=None, prefix=None, format_kwargs=None):
        """
        Yield one labeled column per pair of columns of the two operands.

        Args:
            when: passed through unchanged to both operands' get_columns
            prefix: optional string prepended to every column label
            format_kwargs: optional dict of extra keywords made available
                to the expression template when building the label

        Returns: generator of labeled literal columns of the form
            "(<col1><cast> <operator> <col2>)"
        """
        if prefix is None:
            prefix = ""
        if format_kwargs is None:
            format_kwargs = {}

        # NOTE(review): prefix and format_kwargs are not forwarded to the
        # operands, only `when` is -- confirm this is intended.
        columns1 = self.aggregate1.get_columns(when)
        columns2 = self.aggregate2.get_columns(when)

        for c1, c2 in product(columns1, columns2):
            c = ex.literal_column("({}{} {} {})".format(
                c1, self.cast, self.operator, c2))
            # The label uses operator_str (not operator) so the column name
            # can differ from the SQL keyword, e.g. "|" instead of "or".
            yield c.label(prefix + self.expression_template.format(
                name1=c1.name, operator=self.operator_str, name2=c2.name,
                **format_kwargs))

    def __add__(self, other):
        return AggregateExpression(self, other, "+")

    def __sub__(self, other):
        return AggregateExpression(self, other, "-")

    def __mul__(self, other):
        return AggregateExpression(self, other, "*")

    # Division appends "*1.0" to the left operand in the generated SQL.
    # __div__ is the Python 2 spelling, __truediv__ the Python 3 one.
    def __div__(self, other):
        return AggregateExpression(self, other, "/", "*1.0")

    def __truediv__(self, other):
        return AggregateExpression(self, other, "/", "*1.0")

    def __lt__(self, other):
        return AggregateExpression(self, other, "<")

    def __le__(self, other):
        return AggregateExpression(self, other, "<=")

    # NOTE: comparison operators build SQL expressions instead of returning
    # booleans, so `a == b` is always a (truthy) AggregateExpression.
    def __eq__(self, other):
        return AggregateExpression(self, other, "=")

    def __ne__(self, other):
        return AggregateExpression(self, other, "!=")

    def __gt__(self, other):
        return AggregateExpression(self, other, ">")

    def __ge__(self, other):
        return AggregateExpression(self, other, ">=")

    def __or__(self, other):
        return AggregateExpression(self, other, "or", operator_str="|")

    def __and__(self, other):
        return AggregateExpression(self, other, "and", operator_str="&")


class Aggregate(AggregateExpression):
"""
An object representing one or more SQL aggregate columns in a groupby
"""
Expand Down Expand Up @@ -228,8 +313,8 @@ def _get_aggregates_sql(self, group):
prefix = "{prefix}_{group}_".format(
prefix=self.prefix, group=group)

return chain(*(a.get_columns(prefix=prefix)
for a in self.aggregates))
return chain(*[a.get_columns(prefix=prefix)
for a in self.aggregates])

def get_selects(self):
"""
Expand Down
8 changes: 4 additions & 4 deletions collate/spacetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def _get_aggregates_sql(self, interval, date, group):
prefix=self.prefix, interval=interval,
group=group)

return chain(*(a.get_columns(when, prefix, format_kwargs={"collate_date": date,
return chain(*[a.get_columns(when, prefix, format_kwargs={"collate_date": date,
"collate_interval": interval})
for a in self.aggregates))
for a in self.aggregates])

def get_selects(self):
"""
Expand All @@ -79,8 +79,8 @@ def get_selects(self):
columns = [groupby,
ex.literal_column("'%s'::date"
% date).label(self.output_date_column)]
columns += list(chain(*(self._get_aggregates_sql(
i, date, group) for i in intervals)))
columns += list(chain(*[self._get_aggregates_sql(
i, date, group) for i in intervals]))

gb_clause = make_sql_clause(groupby, ex.literal_column)
query = ex.select(columns=columns, from_obj=self.from_obj)\
Expand Down
9 changes: 9 additions & 0 deletions tests/test_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ def test_aggregate_tuple_quantity_when():
"corr(CASE WHEN date < '2012-01-01' THEN x END, "
"CASE WHEN date < '2012-01-01' THEN y END)")

def test_aggregate_arithmetic():
    """Check the SQL text and the label of an expression built from aggregates."""
    numerator = collate.Aggregate("x", "sum")
    denominator = collate.Aggregate("1", "count")
    mean = collate.Aggregate("y", "avg")

    # Division binds tighter than addition, so this is (numerator/denominator) + mean.
    expression = numerator / denominator + mean
    columns = list(expression.get_columns(prefix="prefix_"))
    first = columns[0]

    assert str(first) == "((sum(x)*1.0 / count(1)) + avg(y))"
    assert first.name == "prefix_x_sum/1_count+y_avg"

def test_aggregate_format_kwargs():
agg = collate.Aggregate("'{collate_date}' - date", "min")
assert str(list(agg.get_columns(format_kwargs={"collate_date":"2012-01-01"}))[0]) == (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_engine():

def test_st_explicit_execute():
agg = Aggregate("results='Fail'",["count"])
st = SpacetimeAggregation([agg],
st = SpacetimeAggregation([agg, agg+agg],
from_obj = ex.table('food_inspections'),
groups = {'license':ex.column('license_no'),
'zip':ex.column('zip')},
Expand Down

0 comments on commit 4ae3990

Please sign in to comment.