Skip to content

Commit

Permalink
added a sub query function to collate. The sub-query is passed as SQL…
Browse files Browse the repository at this point in the history
…Alchemy select object. The date where filters are added to the sub query object within collate
  • Loading branch information
k1aus committed Dec 23, 2016
1 parent c2a63d0 commit 626a25a
Showing 1 changed file with 149 additions and 21 deletions.
170 changes: 149 additions & 21 deletions collate/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Aggregate(object):
"""
An object representing one or more SQL aggregate columns in a groupby
"""

def __init__(self, quantity, function, order=None):
"""
Args:
Expand Down Expand Up @@ -73,7 +74,7 @@ def get_columns(self, when=None, prefix=None, format_kwargs=None):
arg_template = "CASE WHEN {when} THEN {quantity} END"

for function, (quantity_name, quantity), order in product(
self.functions, self.quantities.items(), self.orders):
self.functions, self.quantities.items(), self.orders):
args = str.join(", ", (arg_template.format(when=when, quantity=q)
for q in make_tuple(quantity)))
order_clause = order_template.format(when=when, order=order)
Expand All @@ -93,7 +94,7 @@ def __init__(self, aggregates, groups, from_obj, prefix=None, suffix=None, schem
"""
Args:
aggregates: collection of Aggregate objects.
from_obj: defines the from clause, e.g. the name of the table. can use
from_obj: defines the from clause, e.g. the name of the table. can use
groups: a list of expressions to group by in the aggregation or a dictionary
pairs group: expr pairs where group is the alias (used in column names)
prefix: prefix for aggregation tables and column names, defaults to from_obj
Expand Down Expand Up @@ -123,7 +124,7 @@ def _get_aggregates_sql(self, group):
Returns: collection of aggregate column SQL strings
"""
prefix = "{prefix}_{group}_".format(
prefix=self.prefix, group=group)
prefix=self.prefix, group=group)

return chain(*(a.get_columns(prefix=prefix)
for a in self.aggregates))
Expand All @@ -143,8 +144,8 @@ def get_selects(self):
columns += self._get_aggregates_sql(group)

gb_clause = make_sql_clause(groupby, ex.literal_column)
query = ex.select(columns=columns, from_obj=self.from_obj)\
.group_by(gb_clause)
query = ex.select(columns=columns, from_obj=self.from_obj) \
.group_by(gb_clause)

queries[group] = [query]

Expand Down Expand Up @@ -214,15 +215,15 @@ def get_indexes(self):
index is a raw create index query for the corresponding table
"""
return {group: "CREATE INDEX ON %s (%s);" %
(self.get_table_name(group), groupby)
(self.get_table_name(group), groupby)
for group, groupby in self.groups.items()}

def get_join_table(self):
    """
    Generate a query for a join table.

    Returns: a SQLAlchemy Select over the group expressions
        (the values of self.groups) taken from self.from_obj and
        grouped by those same expressions
    """
    group_exprs = self.groups.values()
    return ex.Select(columns=group_exprs, from_obj=self.from_obj) \
        .group_by(*group_exprs)

def get_create(self, join_table=None):
"""
Expand All @@ -236,7 +237,7 @@ def get_create(self, join_table=None):
query = "SELECT * FROM %s\n" % join_table
for group, groupby in self.groups.items():
query += "LEFT JOIN %s USING (%s)" % (
self.get_table_name(group), groupby)
self.get_table_name(group), groupby)

return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query)

Expand Down Expand Up @@ -334,13 +335,13 @@ def _get_aggregates_sql(self, interval, date, group):
"""
if interval != 'all':
when = "{date_column} >= '{date}'::date - interval '{interval}'".format(
interval=interval, date=date, date_column=self.date_column)
interval=interval, date=date, date_column=self.date_column)
else:
when = None

prefix = "{prefix}_{group}_{interval}_".format(
prefix=self.prefix, interval=interval,
group=group)
prefix=self.prefix, interval=interval,
group=group)

return chain(*(a.get_columns(when, prefix, format_kwargs={"collate_date": date})
for a in self.aggregates))
Expand All @@ -363,20 +364,20 @@ def get_selects(self):
ex.literal_column("'%s'::date"
% date).label(self.output_date_column)]
columns += list(chain(*(self._get_aggregates_sql(
i, date, group) for i in intervals)))
i, date, group) for i in intervals)))

# upper bound on date_column by date
where = ex.text("{date_column} < '{date}'".format(
date_column=self.date_column, date=date))
date_column=self.date_column, date=date))

gb_clause = make_sql_clause(groupby, ex.literal_column)
query = ex.select(columns=columns, from_obj=self.from_obj)\
.where(where)\
.group_by(gb_clause)
query = ex.select(columns=columns, from_obj=self.from_obj) \
.where(where) \
.group_by(gb_clause)

if 'all' not in intervals:
greatest = "greatest(%s)" % str.join(
",", ["interval '%s'" % i for i in intervals])
",", ["interval '%s'" % i for i in intervals])
query = query.where(ex.text(
"{date_column} >= '{date}'::date - {greatest}".format(
date_column=self.date_column, date=date,
Expand All @@ -395,7 +396,7 @@ def get_indexes(self):
index is a raw create index query for the corresponding table
"""
return {group: "CREATE INDEX ON %s (%s, %s);" %
(self.get_table_name(group), groupby, self.output_date_column)
(self.get_table_name(group), groupby, self.output_date_column)
for group, groupby in self.groups.items()}

def get_create(self, join_table=None):
Expand All @@ -409,9 +410,136 @@ def get_create(self, join_table=None):

query = ("SELECT * FROM %s\n"
"CROSS JOIN (select unnest('{%s}'::date[]) as %s) t2\n") % (
join_table, str.join(',', self.dates), self.output_date_column)
join_table, str.join(',', self.dates), self.output_date_column)
for group, groupby in self.groups.items():
query += "LEFT JOIN %s USING (%s, %s)" % (
self.get_table_name(group), groupby, self.output_date_column)

return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query)


class SpacetimeSubQueryAggregation(SpacetimeAggregation):
    """
    A spacetime aggregation that reads from a sub query instead of a table.

    The date filters are attached to the sub query itself; the outer
    aggregation query only selects and groups.
    """

    def __init__(self, aggregates, groups, intervals, from_obj, dates,
                 prefix=None, suffix=None, schema=None, date_column=None,
                 output_date_column=None, sub_query=None, join_table=None):
        """
        Args:
            aggregates: collection of Aggregate objects
            from_obj: defines the name of the sub query (used as its alias)
            groups: a list of expressions to group by in the aggregation or a dictionary
                pairs group: expr pairs where group is the alias (used in column names)
            intervals: the intervals to aggregate over. either a list of
                datetime intervals, e.g. ["1 month", "1 year"], or
                a dictionary of group : intervals pairs where
                group is a group in groups and intervals is a collection
                of datetime intervals, e.g. {"address_id": ["1 month", "1 year"]}
            dates: list of PostgreSQL date strings,
                e.g. ["2012-01-01", "2013-01-01"]
            prefix: prefix for column names, defaults to from_obj
            suffix: suffix for aggregation table, defaults to "aggregation"
            schema: forwarded unchanged to Aggregation.__init__
            date_column: name of date column in the sub query, defaults to "date"
            output_date_column: name of date column in aggregated output, defaults to "date"
            sub_query: the sub query to aggregate over. get_selects() calls
                .where() and .alias() on it, so it has to be a SQLAlchemy
                select object. For details see:
                http://docs.sqlalchemy.org/en/latest/core/selectable.html
            join_table: specify a join table, i.e. a table containing unique sets of all
                possible valid groups to left join the aggregations onto.
                Defaults to None, in which case the groups are selected from from_obj.
                It is passed as from_obj to a SQLAlchemy Select, so it can be
                anything supported there.
        """
        Aggregation.__init__(self,
                             aggregates=aggregates,
                             from_obj=from_obj,
                             groups=groups,
                             prefix=prefix,
                             suffix=suffix,
                             schema=schema)

        if isinstance(intervals, dict):
            self.intervals = intervals
        else:
            # the same intervals apply to every group
            self.intervals = {g: intervals for g in self.groups}
        self.dates = dates
        self.date_column = date_column or "date"
        self.output_date_column = output_date_column or "date"
        self.sub_query = sub_query
        self.join_table = join_table

    def get_selects(self):
        """
        Constructs select queries for this aggregation using a sub query
        Returns: a dictionary of group : queries pairs where
            group are the same keys as groups
            queries is a list of Select queries, one for each date in dates
        """
        queries = {}

        for group, groupby in self.groups.items():
            intervals = self.intervals[group]

            # The widest interval only depends on the group, not on the date,
            # so build it once. With 'all' there is no lower bound on the date.
            if 'all' in intervals:
                greatest = None
            else:
                greatest = "greatest(%s)" % ",".join(
                    "interval '%s'" % i for i in intervals)

            queries[group] = []
            for date in self.dates:
                # The date filters are applied to the sub query (not the outer
                # query) as the sub query can make use of indices.
                # upper bound on date_column by date
                sub_query = self.sub_query.where(ex.text(
                    "{date_column} < '{date}'".format(
                        date_column=self.date_column, date=date)))

                if greatest is not None:
                    # lower bound on date_column by the widest interval
                    sub_query = sub_query.where(ex.text(
                        "{date_column} >= '{date}'::date - {greatest}".format(
                            date_column=self.date_column, date=date,
                            greatest=greatest)))

                # name the sub query
                sub_query = sub_query.alias(str(self.from_obj))

                # main query
                columns = [groupby,
                           ex.literal_column("'%s'::date"
                                             % date).label(self.output_date_column)]
                columns.extend(chain.from_iterable(
                    self._get_aggregates_sql(i, date, group) for i in intervals))

                gb_clause = make_sql_clause(groupby, ex.literal_column)

                # note: there is no where clause as the filtering is applied
                # at the sub query level
                query = ex.select(columns=columns, from_obj=sub_query) \
                    .group_by(gb_clause)

                queries[group].append(query)

        return queries

    def get_join_table(self):
        """
        Generate a query for a join table
        Returns: a string of the form "(<select>) t1" where the select takes
            the group expressions from self.join_table if one was given and
            from self.from_obj otherwise
        """
        source = self.join_table if self.join_table is not None else self.from_obj
        group_exprs = self.groups.values()
        return '(%s) t1' % ex.Select(columns=group_exprs, from_obj=source) \
            .group_by(*group_exprs)

    def get_create(self, join_table=None):
        """
        Generate a single aggregation table creation query by joining
        the per-group aggregation tables onto the join table
        Args:
            join_table: optional FROM expression to use as the join table.
                Accepted so the signature matches the parent class's
                get_create(); when omitted the result of get_join_table()
                is used.
        Returns: a CREATE TABLE AS query
        """
        if not join_table:
            join_table = self.get_join_table()

        query = ("SELECT * FROM %s\n"
                 "CROSS JOIN (select unnest('{%s}'::date[]) as %s) t2\n") % (
            join_table, ','.join(self.dates), self.output_date_column)
        for group, groupby in self.groups.items():
            query += "LEFT JOIN %s USING (%s, %s)" % (
                self.get_table_name(group), groupby, self.output_date_column)

        return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query)

0 comments on commit 626a25a

Please sign in to comment.