This repository has been archived by the owner on Oct 7, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(Baseline): added all baseline models from
fold
(#42)
* feature(Baseline): added all baseline models from `fold` * Create test-baselines.yaml
- Loading branch information
1 parent
2730f22
commit f1fac2d
Showing
4 changed files
with
194 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
name: test-baselines | ||
|
||
on: push | ||
|
||
jobs: | ||
|
||
run-tests: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: checkout | ||
uses: actions/checkout@v3 | ||
|
||
- name: setup-python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: 3.9 | ||
|
||
- name: install-dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install ".[tests]" | ||
- name: run-tests | ||
run: pytest tests/test_baselines.py -s --durations 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Union | ||
|
||
import pandas as pd | ||
from fold.models.base import Model | ||
from fold.transformations.base import fit_noop | ||
|
||
|
||
class Naive(Model): | ||
""" | ||
A model that predicts the last value seen. | ||
""" | ||
|
||
name = "Naive" | ||
properties = Model.Properties(mode=Model.Properties.Mode.online, memory_size=1) | ||
|
||
def predict(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]: | ||
# it's an online transformation, so len(X) will be always 1, | ||
return pd.Series( | ||
self._state.memory_y.iloc[-1].squeeze(), index=X.index[-1:None] | ||
) | ||
|
||
def predict_in_sample(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]: | ||
return self._state.memory_y.shift(1) | ||
|
||
fit = fit_noop | ||
update = fit | ||
|
||
|
||
class NaiveSeasonal(Model): | ||
""" | ||
A model that predicts the last value seen in the same season. | ||
""" | ||
|
||
name = "NaiveSeasonal" | ||
|
||
def __init__(self, seasonal_length: int) -> None: | ||
assert seasonal_length > 1, "seasonal_length must be greater than 1" | ||
self.seasonal_length = seasonal_length | ||
self.properties = Model.Properties( | ||
mode=Model.Properties.Mode.online, | ||
memory_size=seasonal_length, | ||
_internal_supports_minibatch_backtesting=True, | ||
) | ||
|
||
def predict(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]: | ||
# it's an online transformation, so len(X) will be always 1, | ||
return pd.Series( | ||
self._state.memory_y.iloc[-self.seasonal_length].squeeze(), | ||
index=X.index[-1:None], | ||
) | ||
|
||
def predict_in_sample(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]: | ||
return self._state.memory_y.shift(self.seasonal_length) | ||
|
||
fit = fit_noop | ||
update = fit | ||
|
||
|
||
class MovingAverage(Model): | ||
""" | ||
A model that predicts the mean of the last values seen. | ||
""" | ||
|
||
name = "MovingAverage" | ||
|
||
def __init__(self, window_size: int) -> None: | ||
self.window_size = window_size | ||
self.properties = Model.Properties( | ||
mode=Model.Properties.Mode.online, | ||
memory_size=window_size, | ||
_internal_supports_minibatch_backtesting=True, | ||
) | ||
|
||
def predict(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]: | ||
# it's an online transformation, so len(X) will be always 1, | ||
return pd.Series( | ||
self._state.memory_y[-self.window_size :].mean(), index=X.index[-1:None] | ||
) | ||
|
||
def predict_in_sample(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]: | ||
return self._state.memory_y.shift(1).rolling(self.window_size).mean() | ||
|
||
fit = fit_noop | ||
update = fit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
from fold.loop import backtest, train | ||
from fold.splitters import ExpandingWindowSplitter | ||
from fold.transformations.columns import OnlyPredictions | ||
from fold.transformations.dev import Test | ||
from fold.utils.tests import generate_sine_wave_data | ||
|
||
from fold_models.baseline import MovingAverage, Naive, NaiveSeasonal | ||
|
||
|
||
def test_baseline_naive() -> None: | ||
X, y = generate_sine_wave_data( | ||
cycles=10, length=120, freq="M" | ||
) # create a sine wave with yearly seasonality | ||
|
||
def check_if_not_nan(x): | ||
assert not x.isna().squeeze().any() | ||
|
||
splitter = ExpandingWindowSplitter(initial_train_window=0.2, step=0.1) | ||
transformations = [ | ||
Naive(), | ||
Test(fit_func=check_if_not_nan, transform_func=lambda X: X), | ||
OnlyPredictions(), | ||
] | ||
transformations_over_time = train(transformations, X, y, splitter) | ||
pred = backtest(transformations_over_time, X, y, splitter) | ||
assert ( | ||
pred.squeeze() == y.shift(1)[pred.index] | ||
).all() # last year's value should match this year's value, with the sine wave we generated | ||
assert ( | ||
len(pred) == 120 * 0.8 | ||
) # should return non-NaN predictions for the all out-of-sample sets | ||
|
||
|
||
def test_baseline_naive_seasonal() -> None: | ||
X, y = generate_sine_wave_data( | ||
cycles=10, length=120, freq="M" | ||
) # create a sine wave with yearly seasonality | ||
|
||
def check_if_not_nan(x): | ||
assert not x.isna().squeeze().any() | ||
|
||
splitter = ExpandingWindowSplitter(initial_train_window=0.2, step=0.1) | ||
transformations = [ | ||
NaiveSeasonal(seasonal_length=12), | ||
Test(fit_func=check_if_not_nan, transform_func=lambda X: X), | ||
OnlyPredictions(), | ||
] | ||
transformations_over_time = train(transformations, X, y, splitter) | ||
pred = backtest(transformations_over_time, X, y, splitter) | ||
assert np.isclose( | ||
pred.squeeze(), y[pred.index], atol=0.02 | ||
).all() # last year's value should match this year's value, with the sine wave we generated | ||
assert ( | ||
len(pred) == 120 * 0.8 | ||
) # should return non-NaN predictions for the all out-of-sample sets | ||
|
||
|
||
def test_baseline_mean() -> None: | ||
X, y = generate_sine_wave_data(cycles=10, length=400) | ||
|
||
def check_if_not_nan(x): | ||
assert not x.isna().squeeze().any() | ||
|
||
splitter = ExpandingWindowSplitter(initial_train_window=0.2, step=0.1) | ||
transformations = [ | ||
MovingAverage(window_size=12), | ||
Test(fit_func=check_if_not_nan, transform_func=lambda X: X), | ||
OnlyPredictions(), | ||
] | ||
transformations_over_time = train(transformations, X, y, splitter) | ||
pred = backtest(transformations_over_time, X, y, splitter) | ||
assert np.isclose( | ||
y.shift(1).rolling(12).mean()[pred.index], pred.squeeze(), atol=0.01 | ||
).all() | ||
assert ( | ||
len(pred) == 400 * 0.8 | ||
) # should return non-NaN predictions for the all out-of-sample sets |