Skip to content

Commit

Permalink
Add feedback for flat predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
americast committed Oct 5, 2023
1 parent b9a3c7d commit 006d19e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
9 changes: 6 additions & 3 deletions docs/source/reference/ai/model-forecasting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@ Next, we create a function of `TYPE Forecasting`. We must enter the column name
CREATE FUNCTION IF NOT EXISTS Forecast FROM
(SELECT y FROM AirData)
TYPE Forecasting
HORIZON 12
PREDICT 'y';
This trains a forecasting model. The model can be called by providing the horizon for forecasting.

.. code-block:: sql
SELECT Forecast(12);
SELECT Forecast();
Here, the horizon is `12`, which represents the forecast 12 steps into the future.
.. note::

`Forecasting` function also provides suggestions by default. If you wish to turn it off, send "False" as an optional argument while calling the function. Eg. `SELECT Forecast("False");`

Forecast Parameters
-------------------
Expand Down Expand Up @@ -90,4 +92,5 @@ Below is an example query with `neuralforecast` with `trend` column as exogenous
PREDICT 'y'
LIBRARY 'neuralforecast'
AUTO 'f'
FREQUENCY 'M';
FREQUENCY 'M';
22 changes: 19 additions & 3 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,29 @@ def setup(
self.id_column_rename = id_column_rename
self.horizon = int(horizon)
self.library = library
self.suggestion_dict = {
1: "Predictions are flat. Consider using LIBRARY 'neuralforecast' for more accrate predictions.",
}

def forward(self, data) -> pd.DataFrame:
if self.library == "statsforecast":
forecast_df = self.model.predict(h=self.horizon)
forecast_df = self.model.predict(h=self.horizon).reset_index()
else:
forecast_df = self.model.predict()
forecast_df.reset_index(inplace=True)
forecast_df = self.model.predict().reset_index()

# Suggestions
if len(data) == 0 or list(data[0])[0].lower()[0] == "t":
suggestion_list = []
# 1: Flat predictions
if self.library == "statsforecast":
for type_here in forecast_df["unique_id"].unique():
if forecast_df.loc[forecast_df['unique_id'] == type_here][self.model_name].nunique() == 1:
suggestion_list.append(1)

for suggestion in set(suggestion_list):
print("\nSUGGESTION: " + self.suggestion_dict[suggestion])


forecast_df = forecast_df.rename(
columns={
"unique_id": self.id_column_rename,
Expand Down

0 comments on commit 006d19e

Please # to comment.