Skip to content

Commit

Permalink
support parallel MinTrace optimization (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Nov 16, 2023
1 parent 1ee3133 commit 2067699
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 38 deletions.
4 changes: 2 additions & 2 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def _build_fn_name(fn) -> str:
func_params = fn.__dict__

# Take default parameter out of names
args_to_remove = ['insample']
args_to_remove = ['insample', 'num_threads']
if not func_params.get('nonnegative', False):
args_to_remove += ['nonnegative']
args_to_remove.append('nonnegative')

if fn_name == 'MinTrace' and \
func_params['method']=='mint_shrink':
Expand Down
43 changes: 29 additions & 14 deletions hierarchicalforecast/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# %% ../nbs/methods.ipynb 3
import warnings
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -136,7 +137,7 @@ class BottomUp(HReconciler):
**References:**<br>
- [Orcutt, G.H., Watts, H.W., & Edwards, J.B.(1968). \"Data aggregation and information loss\". The American
Economic Review, 58 , 773{787)](http://www.jstor.org/stable/1815532).
Economic Review, 58, 773-787](http://www.jstor.org/stable/1815532).
"""
insample = False

Expand Down Expand Up @@ -565,13 +566,16 @@ class MinTrace(HReconciler):
minimizes the squared errors for the coherent forecasts under an unbiasedness assumption; the solution has a
closed form.<br>
$$\mathbf{P}_{\\text{MinT}}=\\left(\mathbf{S}^{\intercal}\mathbf{W}_{h}\mathbf{S}\\right)^{-1}
\mathbf{S}^{\intercal}\mathbf{W}^{-1}_{h}$$
$$
\mathbf{P}_{\\text{MinT}}=\\left(\mathbf{S}^{\intercal}\mathbf{W}^{-1}_{h}\mathbf{S}\\right)^{-1}
\mathbf{S}^{\intercal}\mathbf{W}^{-1}_{h}
$$
**Parameters:**<br>
`method`: str, one of `ols`, `wls_struct`, `wls_var`, `mint_shrink`, `mint_cov`.<br>
`nonnegative`: bool, reconciled forecasts should be nonnegative?<br>
`mint_shr_ridge`: float=2e-8, ridge numeric protection to MinTrace-shr covariance estimator.<br>
`num_threads`: int=1, number of threads to use for solving the optimization problems.<br>
**References:**<br>
- [Wickramasuriya, S. L., Athanasopoulos, G., & Hyndman, R. J. (2019). \"Optimal forecast reconciliation for
Expand All @@ -584,12 +588,14 @@ class MinTrace(HReconciler):
def __init__(self,
method: str,
nonnegative: bool = False,
mint_shr_ridge: Optional[float] = 2e-8):
mint_shr_ridge: Optional[float] = 2e-8,
num_threads: int = 1):
self.method = method
self.nonnegative = nonnegative
self.insample = method in ['wls_var', 'mint_cov', 'mint_shrink']
if method == 'mint_shrink':
self.mint_shr_ridge = mint_shr_ridge
self.num_threads = num_threads

def _get_PW_matrices(self,
S: np.ndarray,
Expand Down Expand Up @@ -708,9 +714,11 @@ def fit(self,
if self.nonnegative:
_, n_bottom = S.shape
W_inv = np.linalg.pinv(self.W)
warnings.warn('Replacing negative forecasts with zero.')
y_hat = np.copy(y_hat)
y_hat[y_hat < 0] = 0.
negatives = y_hat < 0
if negatives.any():
warnings.warn('Replacing negative forecasts with zero.')
y_hat = np.copy(y_hat)
y_hat[negatives] = 0.
# Quadratic programming formulation
# here we are solving the quadratic programming problem
# formulated in the original paper
Expand All @@ -724,8 +732,16 @@ def fit(self,
b = np.zeros(n_bottom)
# the quadratic programming problem
# returns the forecasts of the bottom series
bottom_fcts = np.apply_along_axis(lambda y_hat: solve_qp(G=G, a=a @ y_hat, C=C, b=b)[0],
axis=0, arr=y_hat)
if self.num_threads == 1:
bottom_fcts = np.apply_along_axis(lambda y_hat: solve_qp(G=G, a=a @ y_hat, C=C, b=b)[0],
axis=0, arr=y_hat)
else:
futures = []
with ThreadPoolExecutor(self.num_threads) as executor:
for j in range(y_hat.shape[1]):
future = executor.submit(solve_qp, G=G, a=a @ y_hat[:, j], C=C, b=b)
futures.append(future)
bottom_fcts = np.hstack([f.result()[0][:, None] for f in futures])
if not np.all(bottom_fcts > -1e-8):
raise Exception('nonnegative optimization failed')
# remove negative values close to zero
Expand Down Expand Up @@ -933,13 +949,12 @@ class OptimalCombination(MinTrace):
"""
def __init__(self,
method: str,
nonnegative: bool = False):
nonnegative: bool = False,
num_threads: int = 1):
comb_methods = ['ols', 'wls_struct']
if method not in comb_methods:
raise ValueError(f"Optimal Combination class does not support method: \"{method}\"")

self.method = method
self.nonnegative = nonnegative
super().__init__(method=method, nonnegative=nonnegative, num_threads=num_threads)
self.insample = False

# %% ../nbs/methods.ipynb 58
Expand Down Expand Up @@ -996,7 +1011,7 @@ class ERM(HReconciler):
**References:**<br>
- [Ben Taieb, S., & Koo, B. (2019). Regularized regression for hierarchical forecasting without
unbiasedness conditions. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge
Discovery & Data Mining KDD '19 (p. 1337{1347). New York, NY, USA: Association for Computing Machinery.](https://doi.org/10.1145/3292500.3330976).<br>
Discovery & Data Mining KDD '19 (p. 1337-1347). New York, NY, USA: Association for Computing Machinery.](https://doi.org/10.1145/3292500.3330976).<br>
"""
def __init__(self,
method: str,
Expand Down
5 changes: 4 additions & 1 deletion hierarchicalforecast/probabilistic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,10 @@ def get_samples(self, num_samples: int=None):

# Initialize PERMBU utility
rec_samples = base_samples.copy()
encoder = OneHotEncoder(sparse=False, dtype=np.float32)
try:
encoder = OneHotEncoder(sparse_output=False, dtype=np.float32)
except TypeError:
encoder = OneHotEncoder(sparse=False, dtype=np.float32)
hier_links = np.vstack(self._nonzero_indexes_by_row(self.S.T))

# BottomUp hierarchy traversing
Expand Down
16 changes: 6 additions & 10 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@
" func_params = fn.__dict__\n",
"\n",
" # Take default parameter out of names\n",
" args_to_remove = ['insample']\n",
" args_to_remove = ['insample', 'num_threads']\n",
" if not func_params.get('nonnegative', False):\n",
" args_to_remove += ['nonnegative']\n",
" args_to_remove.append('nonnegative')\n",
"\n",
" if fn_name == 'MinTrace' and \\\n",
" func_params['method']=='mint_shrink':\n",
Expand Down Expand Up @@ -921,8 +921,7 @@
" ]\n",
")\n",
"\n",
"fcst.fit(df=train_df)\n",
"fcst_df = fcst.forecast(h=12, fitted=True)\n",
"fcst_df = fcst.forecast(df=train_df, h=12, fitted=True)\n",
"fitted_df = fcst.forecast_fitted_values()\n",
"\n",
"fcst_df = hrec.reconcile(\n",
Expand Down Expand Up @@ -1156,11 +1155,8 @@
"\n",
"# Compute base auto-ETS predictions\n",
"# Careful identifying correct data freq, this data quarterly 'Q'\n",
"fcst = StatsForecast(df=Y_train_df,\n",
" #models=[ETS(season_length=12), Naive()],\n",
" models=[Naive()],\n",
" freq='Q', n_jobs=-1)\n",
"Y_hat_df = fcst.forecast(h=4, fitted=True)\n",
"fcst = StatsForecast(models=[Naive()], freq='Q', n_jobs=-1)\n",
"Y_hat_df = fcst.forecast(df=Y_train_df, h=4, fitted=True)\n",
"Y_fitted_df = fcst.forecast_fitted_values()\n",
"\n",
"# Reconcile the base predictions\n",
Expand All @@ -1172,7 +1168,7 @@
"Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, \n",
" Y_df=Y_fitted_df,\n",
" S=S_df, tags=tags)\n",
"Y_rec_df.groupby('unique_id').head(2)"
"Y_rec_df.groupby('unique_id', observed=True).head(2)"
]
}
],
Expand Down
36 changes: 26 additions & 10 deletions nbs/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"#| export\n",
"import warnings\n",
"from collections import OrderedDict\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from copy import deepcopy\n",
"from typing import Callable, Dict, List, Optional, Union\n",
"\n",
Expand Down Expand Up @@ -1028,6 +1029,7 @@
" `method`: str, one of `ols`, `wls_struct`, `wls_var`, `mint_shrink`, `mint_cov`.<br>\n",
" `nonnegative`: bool, reconciled forecasts should be nonnegative?<br>\n",
" `mint_shr_ridge`: float=2e-8, ridge numeric protection to MinTrace-shr covariance estimator.<br>\n",
" `num_threads`: int=1, number of threads to use for solving the optimization problems.\n",
"\n",
" **References:**<br>\n",
" - [Wickramasuriya, S. L., Athanasopoulos, G., & Hyndman, R. J. (2019). \\\"Optimal forecast reconciliation for\n",
Expand All @@ -1040,12 +1042,14 @@
" def __init__(self, \n",
" method: str,\n",
" nonnegative: bool = False,\n",
" mint_shr_ridge: Optional[float] = 2e-8):\n",
" mint_shr_ridge: Optional[float] = 2e-8,\n",
" num_threads: int = 1):\n",
" self.method = method\n",
" self.nonnegative = nonnegative\n",
" self.insample = method in ['wls_var', 'mint_cov', 'mint_shrink']\n",
" if method == 'mint_shrink':\n",
" self.mint_shr_ridge = mint_shr_ridge\n",
" self.num_threads = num_threads\n",
"\n",
" def _get_PW_matrices(self, \n",
" S: np.ndarray,\n",
Expand Down Expand Up @@ -1164,9 +1168,11 @@
" if self.nonnegative:\n",
" _, n_bottom = S.shape\n",
" W_inv = np.linalg.pinv(self.W)\n",
" warnings.warn('Replacing negative forecasts with zero.')\n",
" y_hat = np.copy(y_hat)\n",
" y_hat[y_hat < 0] = 0.\n",
" negatives = y_hat < 0\n",
" if negatives.any():\n",
" warnings.warn('Replacing negative forecasts with zero.')\n",
" y_hat = np.copy(y_hat)\n",
" y_hat[negatives] = 0.\n",
" # Quadratic progamming formulation\n",
" # here we are solving the quadratic programming problem\n",
" # formulated in the origial paper\n",
Expand All @@ -1180,8 +1186,16 @@
" b = np.zeros(n_bottom)\n",
" # the quadratic programming problem\n",
" # returns the forecasts of the bottom series\n",
" bottom_fcts = np.apply_along_axis(lambda y_hat: solve_qp(G=G, a=a @ y_hat, C=C, b=b)[0], \n",
" axis=0, arr=y_hat)\n",
" if self.num_threads == 1:\n",
" bottom_fcts = np.apply_along_axis(lambda y_hat: solve_qp(G=G, a=a @ y_hat, C=C, b=b)[0], \n",
" axis=0, arr=y_hat)\n",
" else:\n",
" futures = []\n",
" with ThreadPoolExecutor(self.num_threads) as executor:\n",
" for j in range(y_hat.shape[1]):\n",
" future = executor.submit(solve_qp, G=G, a=a @ y_hat[:, j], C=C, b=b)\n",
" futures.append(future)\n",
" bottom_fcts = np.hstack([f.result()[0][:, None] for f in futures])\n",
" if not np.all(bottom_fcts > -1e-8):\n",
" raise Exception('nonnegative optimization failed')\n",
" # remove negative values close to zero\n",
Expand Down Expand Up @@ -1451,6 +1465,9 @@
" )['mean'],\n",
" S @ y_hat_bottom\n",
" )\n",
"mintrace_1thr = MinTrace(method='ols', nonnegative=False, num_threads=1).fit(S=S, y_hat=S @ y_hat_bottom)\n",
"mintrace_2thr = MinTrace(method='ols', nonnegative=False, num_threads=2).fit(S=S, y_hat=S @ y_hat_bottom)\n",
"np.testing.assert_allclose(mintrace_1thr.y_hat, mintrace_2thr.y_hat)\n",
"with ExceptionExpected(regex='min_trace (mint_cov)*'):\n",
" for nonnegative in [False, True]:\n",
" cls_min_trace = MinTrace(method='mint_cov', nonnegative=nonnegative)\n",
Expand Down Expand Up @@ -1548,13 +1565,12 @@
" \"\"\"\n",
" def __init__(self,\n",
" method: str,\n",
" nonnegative: bool = False):\n",
" nonnegative: bool = False,\n",
" num_threads: int = 1):\n",
" comb_methods = ['ols', 'wls_struct']\n",
" if method not in comb_methods:\n",
" raise ValueError(f\"Optimal Combination class does not support method: \\\"{method}\\\"\")\n",
"\n",
" self.method = method\n",
" self.nonnegative = nonnegative\n",
" super().__init__(method=method, nonnegative=nonnegative, num_threads=num_threads)\n",
" self.insample = False"
]
},
Expand Down
5 changes: 4 additions & 1 deletion nbs/probabilistic_methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,10 @@
"\n",
" # Initialize PERMBU utility\n",
" rec_samples = base_samples.copy()\n",
" encoder = OneHotEncoder(sparse=False, dtype=np.float32)\n",
" try:\n",
" encoder = OneHotEncoder(sparse_output=False, dtype=np.float32)\n",
" except TypeError:\n",
" encoder = OneHotEncoder(sparse=False, dtype=np.float32)\n",
" hier_links = np.vstack(self._nonzero_indexes_by_row(self.S.T))\n",
"\n",
" # BottomUp hierarchy traversing\n",
Expand Down

0 comments on commit 2067699

Please sign in to comment.