diff --git a/collate/collate.py b/collate/collate.py index a8431ff..b60623d 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -29,7 +29,92 @@ def split_distinct(quantity): return "", (q,) -class Aggregate(object): +class AggregateExpression(object): + def __init__(self, aggregate1, aggregate2, operator, + cast=None, operator_str=None, expression_template=None): + """ + Args: + aggregate1: first aggregate + aggregate2: second aggregate + operator: string of SQL operator, e.g. "+" + cast: optional string to put after aggregate1, e.g. "*1.0", "::decimal" + operator_str: optional name of operator to use, defaults to operator + expression_template: optional formatting template with the following keywords: + name1, operator, name2 + """ + self.aggregate1 = aggregate1 + self.aggregate2 = aggregate2 + self.operator = operator + self.cast = cast if cast else "" + self.operator_str = operator if operator_str else operator + self.expression_template = expression_template \ + if expression_template else "{name1}{operator}{name2}" + + def alias(self, expression_template): + """ + Set the expression template used for naming columns of an AggregateExpression + Returns: self, for chaining + """ + self.expression_template = expression_template + return self + + def get_columns(self, when=None, prefix=None, format_kwargs=None): + if prefix is None: + prefix = "" + if format_kwargs is None: + format_kwargs = {} + + columns1 = self.aggregate1.get_columns(when) + columns2 = self.aggregate2.get_columns(when) + + for c1, c2 in product(columns1, columns2): + c = ex.literal_column("({}{} {} {})".format( + c1, self.cast, self.operator, c2)) + yield c.label(prefix + self.expression_template.format( + name1=c1.name, operator=self.operator_str, name2=c2.name, + **format_kwargs)) + + def __add__(self, other): + return AggregateExpression(self, other, "+") + + def __sub__(self, other): + return AggregateExpression(self, other, "-") + + def __mul__(self, other): + return AggregateExpression(self, other, "*") + + def __div__(self, other): + return AggregateExpression(self, other, "/", "*1.0") + + def __truediv__(self, other): + return AggregateExpression(self, other, "/", "*1.0") + + def __lt__(self, other): + return AggregateExpression(self, other, "<") + + def __le__(self, other): + return AggregateExpression(self, other, "<=") + + def __eq__(self, other): + return AggregateExpression(self, other, "=") + + def __ne__(self, other): + return AggregateExpression(self, other, "!=") + + def __gt__(self, other): + return AggregateExpression(self, other, ">") + + def __ge__(self, other): + return AggregateExpression(self, other, ">=") + + def __or__(self, other): + return AggregateExpression(self, other, "or", operator_str="|") + + def __and__(self, other): + return AggregateExpression(self, other, "and", operator_str="&") + + +class Aggregate(AggregateExpression): """ An object representing one or more SQL aggregate columns in a groupby """ @@ -228,8 +313,8 @@ def _get_aggregates_sql(self, group): prefix = "{prefix}_{group}_".format( prefix=self.prefix, group=group) - return chain(*(a.get_columns(prefix=prefix) - for a in self.aggregates)) + return chain(*[a.get_columns(prefix=prefix) + for a in self.aggregates]) def get_selects(self): """ diff --git a/collate/spacetime.py b/collate/spacetime.py index be9b117..ab83790 100644 --- a/collate/spacetime.py +++ b/collate/spacetime.py @@ -58,9 +58,9 @@ def _get_aggregates_sql(self, interval, date, group): prefix=self.prefix, interval=interval, group=group) - return chain(*(a.get_columns(when, prefix, format_kwargs={"collate_date": date, + return chain(*[a.get_columns(when, prefix, format_kwargs={"collate_date": date, "collate_interval": interval}) - for a in self.aggregates)) + for a in self.aggregates]) def get_selects(self): """ @@ -79,8 +79,8 @@ def get_selects(self): columns = [groupby, ex.literal_column("'%s'::date" % date).label(self.output_date_column)] - columns += list(chain(*(self._get_aggregates_sql( - i, date, group) for i in intervals))) + columns += list(chain(*[self._get_aggregates_sql( + i, date, group) for i in intervals])) gb_clause = make_sql_clause(groupby, ex.literal_column) query = ex.select(columns=columns, from_obj=self.from_obj)\ diff --git a/tests/test_collate.py b/tests/test_collate.py index 796da5c..b2e04e9 100755 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -39,6 +39,15 @@ def test_aggregate_tuple_quantity_when(): "corr(CASE WHEN date < '2012-01-01' THEN x END, " "CASE WHEN date < '2012-01-01' THEN y END)") +def test_aggregate_arithmetic(): + n = collate.Aggregate("x", "sum") + d = collate.Aggregate("1", "count") + m = collate.Aggregate("y", "avg") + + e = list((n/d + m).get_columns(prefix="prefix_"))[0] + assert str(e) == "((sum(x)*1.0 / count(1)) + avg(y))" + assert e.name == "prefix_x_sum/1_count+y_avg" + def test_aggregate_format_kwargs(): agg = collate.Aggregate("'{collate_date}' - date", "min") assert str(list(agg.get_columns(format_kwargs={"collate_date":"2012-01-01"}))[0]) == ( diff --git a/tests/test_integration.py b/tests/test_integration.py index b4e9e65..d68786e 100755 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -25,7 +25,7 @@ def test_engine(): def test_st_explicit_execute(): agg = Aggregate("results='Fail'",["count"]) - st = SpacetimeAggregation([agg], + st = SpacetimeAggregation([agg, agg+agg], from_obj = ex.table('food_inspections'), groups = {'license':ex.column('license_no'), 'zip':ex.column('zip')},