Skip to content

Commit

Permalink
bsq updates
Browse files Browse the repository at this point in the history
  • Loading branch information
rajeee committed Jan 13, 2025
1 parent 8cfc9bd commit 424e40b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
6 changes: 2 additions & 4 deletions buildstock_query/aggregate_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def aggregate_annual(self, *,
upgrade_id = self._bsq._validate_upgrade(params.upgrade_id)
enduse_cols = self._bsq._get_enduse_cols(params.enduses, table='upgrade')
total_weight = self._bsq._get_weight(weights)
agg_func, agg_weight = self._bsq._get_agg_func_and_weight(weights, params.agg_func)
enduse_selection = [agg_func(enduse * agg_weight).label(self._bsq._simple_label(enduse.name, params.agg_func))
enduse_selection = [self._bsq._agg_column(enduse, total_weight, params.agg_func)
for enduse in enduse_cols]
if params.get_quartiles:
enduse_selection += [sa.func.approx_percentile(enduse, [0, 0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98, 1]).label(
Expand Down Expand Up @@ -140,8 +139,7 @@ def aggregate_timeseries(self, params: TSQuery):
[self._bsq._get_table(jl[0]) for jl in params.join_list] # ingress all tables in join list
enduses_cols = self._bsq._get_enduse_cols(params.enduses, table='timeseries')
total_weight = self._bsq._get_weight(params.weights)
agg_func, agg_weight = self._bsq._get_agg_func_and_weight(params.weights, params.agg_func)
enduse_selection = [agg_func(enduse * agg_weight).label(self._bsq._simple_label(enduse.name, params.agg_func))
enduse_selection = [self._bsq._agg_column(enduse, total_weight, params.agg_func)
for enduse in enduses_cols]
group_by = list(params.group_by)
if self._bsq.timestamp_column_name not in group_by and params.collapse_ts:
Expand Down
2 changes: 1 addition & 1 deletion buildstock_query/db_schema/resstock_default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ timestamp = "time"
completed_status = "completed_status"
unmet_hours_cooling_hr = "report_simulation_output.unmet_hours_cooling_hr"
unmet_hours_heating_hr = "report_simulation_output.unmet_hours_heating_hr"
upgrade = "apply_upgrade.upgrade"
upgrade = "upgrade"
fuel_totals = [
'report_simulation_output.energy_use_total_m_btu',
'report_simulation_output.fuel_use_coal_total_m_btu',
Expand Down
5 changes: 4 additions & 1 deletion buildstock_query/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ def get_upgrade_names(self, get_query_only: Literal[True]) -> str:
def get_upgrade_names(self, get_query_only: bool = False) -> Union[str, dict]:
if self.up_table is None:
raise ValueError("This run has no upgrades")
upgrade_table = self._compile(self.up_table)
if isinstance(self.up_table, sa.Table):
upgrade_table = self.up_table.name
else:
upgrade_table = self._compile(self.up_table)
upgrade_col = self.db_schema.column_names.upgrade
query = f"""
Select cast(upgrade as integer) as upgrade, arbitrary("{upgrade_col}") as upgrade_name
Expand Down
17 changes: 10 additions & 7 deletions buildstock_query/query_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,24 +1043,27 @@ def _add_order_by(self, query, order_by_selection):
query = query.order_by(*a)
return query

def _get_weight(self, weights):
def _get_weight(self, weight_cols):
total_weight = self.sample_wt
for weight_col in weights:
for weight_col in weight_cols:
if isinstance(weight_col, tuple):
tbl = self._get_table(weight_col[1])
total_weight *= tbl.c[weight_col[0]]
else:
total_weight *= self._get_column(weight_col)
return total_weight

def _get_agg_func_and_weight(self, weights, agg_func=None):
if agg_func is None or agg_func == 'sum':
return safunc.sum, self._get_weight(weights)
def _agg_column(self, column: DBColType, weights, agg_func=None):
label = self._simple_label(column.name, agg_func)
if callable(agg_func):
return agg_func, 1
return agg_func(column).label(label)
if agg_func is None or agg_func in ['sum']:
return safunc.sum(column * weights).label(label)
if agg_func in ['avg']:
return (safunc.sum(column * weights) / safunc.sum(weights)).label(label)
assert isinstance(agg_func, str), f"agg_func {agg_func} is not a string or callable"
agg_func = getattr(safunc, agg_func)
return agg_func, 1
return agg_func(column).label(label)

def delete_everything(self):
"""Deletes the athena tables and data in s3 for the run.
Expand Down

0 comments on commit 424e40b

Please # to comment.