From db2f8aedc583f687312ff0ae8376471522cfadcc Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Thu, 17 Nov 2016 18:45:13 +0000 Subject: [PATCH 01/11] initial operator overloading --- collate/collate.py | 34 +++++++++++++++++++++++++++++++++- tests/test_collate.py | 7 +++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/collate/collate.py b/collate/collate.py index e19380a..e93982d 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -39,7 +39,39 @@ def to_sql_name(name): return name.replace('"', '') -class Aggregate(object): +class AggregateExpression(object): + def __init__(self, aggregates, operator): + self.aggregates = aggregates + self.operator = operator + + def get_columns(self, when=None, prefix=None): + if prefix is None: + prefix = "" + + columns0 = self.aggregates[0].get_columns() + columns1 = self.aggregates[1].get_columns() + + for c0, c1 in product(columns0, columns1): + c = ex.literal_column("({} {} {})".format(c0,self.operator,c1))\ + .label("{}{}{}{}".format(prefix, c0.name, self.operator, c1.name)) + yield c + + # TODO: floordiv and truediv for py3 + def __add__(self, other): + return AggregateExpression([self, other], "+") + + def __sub__(self, other): + return AggregateExpression([self, other], "-") + + def __mul__(self, other): + return AggregateExpression([self, other], "*") + + def __div__(self, other): + return AggregateExpression([self, other], "/") + + + +class Aggregate(AggregateExpression): """ An object representing one or more SQL aggregate columns in a groupby """ diff --git a/tests/test_collate.py b/tests/test_collate.py index 2ce68c8..b24c68f 100755 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -38,3 +38,10 @@ def test_aggregate_tuple_quantity_when(): assert str(list(agg.get_columns(when="date < '2012-01-01'"))[0]) == ( "corr(CASE WHEN date < '2012-01-01' THEN x END, " "CASE WHEN date < '2012-01-01' THEN y END)") + +def test_aggregate_div(): + n = collate.Aggregate("x", "sum") + d = collate.Aggregate("1", "count") + + print list((n/d + d - n).get_columns())[0].name + From 3f35a15faa8ce011a9da7f02a4df48e3ffe5bfa7 Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Thu, 17 Nov 2016 18:55:48 +0000 Subject: [PATCH 02/11] syntax, when --- collate/collate.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/collate/collate.py b/collate/collate.py index e93982d..1d3f439 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -40,21 +40,23 @@ def to_sql_name(name): class AggregateExpression(object): - def __init__(self, aggregates, operator): + def __init__(self, aggregates, operator, cast=None): self.aggregates = aggregates self.operator = operator + self.cast = cast if cast else "" def get_columns(self, when=None, prefix=None): if prefix is None: prefix = "" - columns0 = self.aggregates[0].get_columns() - columns1 = self.aggregates[1].get_columns() + columns0 = self.aggregates[0].get_columns(when) + columns1 = self.aggregates[1].get_columns(when) for c0, c1 in product(columns0, columns1): - c = ex.literal_column("({} {} {})".format(c0,self.operator,c1))\ - .label("{}{}{}{}".format(prefix, c0.name, self.operator, c1.name)) - yield c + c = ex.literal_column("({}{} {} {})".format( + c0, self.cast, self.operator, c1)) + yield c.label("{}{}{}{}".format( + prefix, c0.name, self.operator, c1.name)) # TODO: floordiv and truediv for py3 def __add__(self, other): @@ -67,8 +69,7 @@ def __mul__(self, other): return AggregateExpression([self, other], "*") def __div__(self, other): - return AggregateExpression([self, other], "/") - + return AggregateExpression([self, other], "/", "*1.0") class Aggregate(AggregateExpression): From b6cbaaa0a3a70e7f95ace72053c73d9f2696a57e Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Thu, 17 Nov 2016 19:54:37 +0000 Subject: [PATCH 03/11] real test --- tests/test_collate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_collate.py b/tests/test_collate.py index b24c68f..facedb4 100755 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -42,6 +42,7 @@ def test_aggregate_tuple_quantity_when(): def test_aggregate_div(): n = collate.Aggregate("x", "sum") d = collate.Aggregate("1", "count") + m = collate.Aggregate("y", "avg") - print list((n/d + d - n).get_columns())[0].name + assert str(list((n/d + m).get_columns())[0]) == "((sum(x)*1.0 / count(1)) + avg(y))" From b25ab2505038407c4aae545108f8578099eafa52 Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Mon, 12 Dec 2016 19:20:42 +0000 Subject: [PATCH 04/11] truediv --- collate/collate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/collate/collate.py b/collate/collate.py index 1d3f439..42a1eab 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -58,7 +58,6 @@ def get_columns(self, when=None, prefix=None): yield c.label("{}{}{}{}".format( prefix, c0.name, self.operator, c1.name)) - # TODO: floordiv and truediv for py3 def __add__(self, other): return AggregateExpression([self, other], "+") @@ -71,6 +70,9 @@ def __mul__(self, other): def __div__(self, other): return AggregateExpression([self, other], "/", "*1.0") + def __truediv__(self, other): + return AggregateExpression([self, other], "/", "*1.0") + class Aggregate(AggregateExpression): """ From c8432159d16c5a0b2bb974471c2a65ec36248907 Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Mon, 27 Feb 2017 19:30:21 +0000 Subject: [PATCH 05/11] arithmetic --- collate/collate.py | 40 ++++++++++++++++++++++++++++++++++++++-- tests/test_collate.py | 6 ++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/collate/collate.py b/collate/collate.py index d0fb6d3..50784ee 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -34,6 +34,18 @@ def __init__(self, aggregates, operator, cast=None): self.aggregates = aggregates self.operator = operator self.cast = cast if cast else "" + self.expression_template = "{name1}{operator}{name2}" + + def alias(self, expression_template): + """ + Set the expression template used for naming columns of an AggregateExpression + Args: + expression_template: formatting template with the following keywords: + name0, operator, name1 + Returns: self, for chaining + """ + self.expression_template = expression_template + return self def get_columns(self, when=None, prefix=None): if prefix is None: @@ -45,8 +57,8 @@ def get_columns(self, when=None, prefix=None): for c0, c1 in product(columns0, columns1): c = ex.literal_column("({}{} {} {})".format( c0, self.cast, self.operator, c1)) - yield c.label("{}{}{}{}".format( - prefix, c0.name, self.operator, c1.name)) + yield c.label(prefix + self.expression_template.format( + name1=c0.name, operator=self.operator, name2=c1.name)) def __add__(self, other): return AggregateExpression([self, other], "+") @@ -63,6 +75,30 @@ def __div__(self, other): def __truediv__(self, other): return AggregateExpression([self, other], "/", "*1.0") + def __lt__(self, other): + return AggregateExpression([self, other], "<") + + def __le__(self, other): + return AggregateExpression([self, other], "<=") + + def __eq__(self, other): + return AggregateExpression([self, other], "=") + + def __ne__(self, other): + return AggregateExpression([self, other], "!=") + + def __gt__(self, other): + return AggregateExpression([self, other], ">") + + def __ge__(self, other): + return AggregateExpression([self, other], ">=") + + def __or__(self, other): + return AggregateExpression([self, other], "|") + + def __and__(self, other): + return AggregateExpression([self, other], "&") + class Aggregate(AggregateExpression): """ diff --git a/tests/test_collate.py b/tests/test_collate.py index 5444afd..b2e04e9 100755 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -39,12 +39,14 @@ def test_aggregate_tuple_quantity_when(): "corr(CASE WHEN date < '2012-01-01' THEN x END, " "CASE WHEN date < '2012-01-01' THEN y END)") -def test_aggregate_div(): +def test_aggregate_arithmetic(): n = collate.Aggregate("x", "sum") d = collate.Aggregate("1", "count") m = collate.Aggregate("y", "avg") - assert str(list((n/d + m).get_columns())[0]) == "((sum(x)*1.0 / count(1)) + avg(y))" + e = list((n/d + m).get_columns(prefix="prefix_"))[0] + assert str(e) == "((sum(x)*1.0 / count(1)) + avg(y))" + assert e.name == "prefix_x_sum/1_count+y_avg" def test_aggregate_format_kwargs(): agg = collate.Aggregate("'{collate_date}' - date", "min") From e210332c906d6e728cdadf2ebf14c424caa16522 Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Mon, 27 Feb 2017 19:48:24 +0000 Subject: [PATCH 06/11] operator_str --- collate/collate.py | 60 +++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/collate/collate.py b/collate/collate.py index 50784ee..3346e9c 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -30,18 +30,30 @@ def split_distinct(quantity): class AggregateExpression(object): - def __init__(self, aggregates, operator, cast=None): - self.aggregates = aggregates + def __init__(self, aggregate1, aggregate2, operator, + cast=None, operator_str=None, expression_template=None): + """ + Args: + aggregate1: + aggregate2: + operator: string of SQL operator, e.g. "+" + cast: string to put after aggregate1, e.g. "*1.0", "::decimal" + defaults to empty + operator_str: name of operator to use + expression_template: formatting template with the following keywords: + name1, operator, name2 + """ + self.aggregate1 = aggregate1 + self.aggregate2 = aggregate2 self.operator = operator self.cast = cast if cast else "" - self.expression_template = "{name1}{operator}{name2}" + self.operator_str = operator if operator_str else operator + self.expression_template = expression_template \ + if expression_template else "{name1}{operator}{name2}" def alias(self, expression_template): """ Set the expression template used for naming columns of an AggregateExpression - Args: - expression_template: formatting template with the following keywords: - name0, operator, name1 Returns: self, for chaining """ self.expression_template = expression_template @@ -51,53 +63,53 @@ def get_columns(self, when=None, prefix=None): if prefix is None: prefix = "" - columns0 = self.aggregates[0].get_columns(when) - columns1 = self.aggregates[1].get_columns(when) + columns1 = self.aggregate1.get_columns(when) + columns2 = self.aggregate2.get_columns(when) - for c0, c1 in product(columns0, columns1): + for c1, c2 in product(columns1, columns2): c = ex.literal_column("({}{} {} {})".format( - c0, self.cast, self.operator, c1)) + c1, self.cast, self.operator, c2)) yield c.label(prefix + self.expression_template.format( - name1=c0.name, operator=self.operator, name2=c1.name)) + name1=c1.name, operator=self.operator_str, name2=c2.name)) def __add__(self, other): - return AggregateExpression([self, other], "+") + return AggregateExpression(self, other, "+") def __sub__(self, other): - return AggregateExpression([self, other], "-") + return AggregateExpression(self, other, "-") def __mul__(self, other): - return AggregateExpression([self, other], "*") + return AggregateExpression(self, other, "*") def __div__(self, other): - return AggregateExpression([self, other], "/", "*1.0") + return AggregateExpression(self, other, "/", "*1.0") def __truediv__(self, other): - return AggregateExpression([self, other], "/", "*1.0") + return AggregateExpression(self, other, "/", "*1.0") def __lt__(self, other): - return AggregateExpression([self, other], "<") + return AggregateExpression(self, other, "<") def __le__(self, other): - return AggregateExpression([self, other], "<=") + return AggregateExpression(self, other, "<=") def __eq__(self, other): - return AggregateExpression([self, other], "=") + return AggregateExpression(self, other, "=") def __ne__(self, other): - return AggregateExpression([self, other], "!=") + return AggregateExpression(self, other, "!=") def __gt__(self, other): - return AggregateExpression([self, other], ">") + return AggregateExpression(self, other, ">") def __ge__(self, other): - return AggregateExpression([self, other], ">=") + return AggregateExpression(self, other, ">=") def __or__(self, other): - return AggregateExpression([self, other], "|") + return AggregateExpression(self, other, "or", operator_str="|") def __and__(self, other): - return AggregateExpression([self, other], "&") + return AggregateExpression(self, other, "and", operator_str="&") class Aggregate(AggregateExpression): From 5f9e2ac7ebb4ce08b5244f2d2cf3872496896ea3 Mon Sep 17 00:00:00 2001 From: Eric Potash Date: Mon, 27 Feb 2017 19:53:37 +0000 Subject: [PATCH 07/11] docstring --- collate/collate.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/collate/collate.py b/collate/collate.py index 3346e9c..6862600 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -34,13 +34,12 @@ def __init__(self, aggregate1, aggregate2, operator, cast=None, operator_str=None, expression_template=None): """ Args: - aggregate1: - aggregate2: + aggregate1: first aggregate + aggregate2: second aggregate operator: string of SQL operator, e.g. "+" - cast: string to put after aggregate1, e.g. "*1.0", "::decimal" - defaults to empty - operator_str: name of operator to use - expression_template: formatting template with the following keywords: + cast: optional string to put after aggregate1, e.g. "*1.0", "::decimal" + operator_str: optional name of operator to use, defaults to operator + expression_template: optional formatting template with the following keywords: name1, operator, name2 """ self.aggregate1 = aggregate1 From 192f133e57ed075b2e458a3f0e8e93562d013aed Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Mon, 27 Feb 2017 23:22:24 +0000 Subject: [PATCH 08/11] Be less lazy for Python 2.7 --- collate/collate.py | 4 ++-- collate/spacetime.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/collate/collate.py b/collate/collate.py index 6862600..d8a62e8 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -310,8 +310,8 @@ def _get_aggregates_sql(self, group): prefix = "{prefix}_{group}_".format( prefix=self.prefix, group=group) - return chain(*(a.get_columns(prefix=prefix) - for a in self.aggregates)) + return chain(*[a.get_columns(prefix=prefix) + for a in self.aggregates]) def get_selects(self): """ diff --git a/collate/spacetime.py b/collate/spacetime.py index 0ae0e51..ae1d48a 100644 --- a/collate/spacetime.py +++ b/collate/spacetime.py @@ -58,8 +58,8 @@ def _get_aggregates_sql(self, interval, date, group): prefix=self.prefix, interval=interval, group=group) - return chain(*(a.get_columns(when, prefix, format_kwargs={"collate_date": date}) - for a in self.aggregates)) + return chain(*[a.get_columns(when, prefix, format_kwargs={"collate_date": date}) + for a in self.aggregates]) def get_selects(self): """ @@ -78,8 +78,8 @@ def get_selects(self): columns = [groupby, ex.literal_column("'%s'::date" % date).label(self.output_date_column)] - columns += list(chain(*(self._get_aggregates_sql( - i, date, group) for i in intervals))) + columns += list(chain(*[self._get_aggregates_sql( + i, date, group) for i in intervals])) gb_clause = make_sql_clause(groupby, ex.literal_column) query = ex.select(columns=columns, from_obj=self.from_obj)\ From 8890de18b081a5e99b63e0dba5523476060f1512 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Mon, 27 Feb 2017 23:26:38 +0000 Subject: [PATCH 09/11] Support format_kwargs when getting columns of AggregateExpressions --- collate/collate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/collate/collate.py b/collate/collate.py index d8a62e8..b60623d 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -58,9 +58,11 @@ def alias(self, expression_template): self.expression_template = expression_template return self - def get_columns(self, when=None, prefix=None): + def get_columns(self, when=None, prefix=None, format_kwargs=None): if prefix is None: prefix = "" + if format_kwargs is None: + format_kwargs = {} columns1 = self.aggregate1.get_columns(when) columns2 = self.aggregate2.get_columns(when) @@ -69,7 +71,8 @@ def get_columns(self, when=None, prefix=None): c = ex.literal_column("({}{} {} {})".format( c1, self.cast, self.operator, c2)) yield c.label(prefix + self.expression_template.format( - name1=c1.name, operator=self.operator_str, name2=c2.name)) + name1=c1.name, operator=self.operator_str, name2=c2.name, + **format_kwargs)) def __add__(self, other): return AggregateExpression(self, other, "+") From 8a9325de662df116dc018e56277358e14788fd56 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 28 Feb 2017 19:05:31 +0000 Subject: [PATCH 10/11] Add trivial integration test --- tests/test_integration.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index b4e9e65..867b90e 100755 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -24,8 +24,8 @@ def test_engine(): assert len(engine.execute("SELECT * FROM food_inspections").fetchall()) == 966 def test_st_explicit_execute(): - agg = Aggregate("results='Fail'",["count"]) - st = SpacetimeAggregation([agg], + agg = Aggregate("results='Fail'",["sum"]) + st = SpacetimeAggregation([agg, agg+agg], from_obj = ex.table('food_inspections'), groups = {'license':ex.column('license_no'), 'zip':ex.column('zip')}, @@ -38,7 +38,7 @@ def test_st_explicit_execute(): st.execute(engine.connect()) def test_st_lazy_execute(): - agg = Aggregate("results='Fail'",["count"]) + agg = Aggregate("results='Fail'",["sum"]) st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], @@ -50,7 +50,7 @@ def test_st_lazy_execute(): st.execute(engine.connect()) def test_st_execute_broadcast_intervals(): - agg = Aggregate("results='Fail'",["count"]) + agg = Aggregate("results='Fail'",["sum"]) st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], @@ -61,7 +61,7 @@ def test_st_execute_broadcast_intervals(): st.execute(engine.connect()) def test_execute(): - agg = Aggregate("results='Fail'",["count"]) + agg = Aggregate("results='Fail'",["sum"]) st = Aggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip']) @@ -69,7 +69,7 @@ def test_execute(): st.execute(engine.connect()) def test_execute_schema_output_date_column(): - agg = Aggregate("results='Fail'",["count"]) + agg = Aggregate("results='Fail'",["sum"]) st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], From 80b8b631dc90f96f04e5651eaf6389743a37e885 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Wed, 1 Mar 2017 09:47:21 -0600 Subject: [PATCH 11/11] Go back to using the count function it's nonsensical, but it works for now. --- tests/test_integration.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index 867b90e..d68786e 100755 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -24,7 +24,7 @@ def test_engine(): assert len(engine.execute("SELECT * FROM food_inspections").fetchall()) == 966 def test_st_explicit_execute(): - agg = Aggregate("results='Fail'",["sum"]) + agg = Aggregate("results='Fail'",["count"]) st = SpacetimeAggregation([agg, agg+agg], from_obj = ex.table('food_inspections'), groups = {'license':ex.column('license_no'), @@ -38,7 +38,7 @@ def test_st_explicit_execute(): st.execute(engine.connect()) def test_st_lazy_execute(): - agg = Aggregate("results='Fail'",["sum"]) + agg = Aggregate("results='Fail'",["count"]) st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], @@ -50,7 +50,7 @@ def test_st_lazy_execute(): st.execute(engine.connect()) def test_st_execute_broadcast_intervals(): - agg = Aggregate("results='Fail'",["sum"]) + agg = Aggregate("results='Fail'",["count"]) st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], @@ -61,7 +61,7 @@ def test_st_execute_broadcast_intervals(): st.execute(engine.connect()) def test_execute(): - agg = Aggregate("results='Fail'",["sum"]) + agg = Aggregate("results='Fail'",["count"]) st = Aggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip']) @@ -69,7 +69,7 @@ def test_execute(): st.execute(engine.connect()) def test_execute_schema_output_date_column(): - agg = Aggregate("results='Fail'",["sum"]) + agg = Aggregate("results='Fail'",["count"]) st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'],