diff --git a/docs/source/reference/ai/model-forecasting.rst b/docs/source/reference/ai/model-forecasting.rst index 8285ad76b6..0e6adf4f5c 100644 --- a/docs/source/reference/ai/model-forecasting.rst +++ b/docs/source/reference/ai/model-forecasting.rst @@ -28,16 +28,18 @@ Next, we create a function of `TYPE Forecasting`. We must enter the column name CREATE FUNCTION IF NOT EXISTS Forecast FROM (SELECT y FROM AirData) TYPE Forecasting + HORIZON 12 PREDICT 'y'; This trains a forecasting model. The model can be called by providing the horizon for forecasting. .. code-block:: sql - SELECT Forecast(12); + SELECT Forecast(); -Here, the horizon is `12`, which represents the forecast 12 steps into the future. +.. note:: + `Forecasting` function also provides suggestions by default. If you wish to turn it off, send "False" as an optional argument while calling the function. Eg. `SELECT Forecast("False");` Forecast Parameters ------------------- @@ -90,4 +92,5 @@ Below is an example query with `neuralforecast` with `trend` column as exogenous PREDICT 'y' LIBRARY 'neuralforecast' AUTO 'f' - FREQUENCY 'M'; \ No newline at end of file + FREQUENCY 'M'; + diff --git a/evadb/functions/forecast.py b/evadb/functions/forecast.py index 1571f6c4fc..305eba5c44 100644 --- a/evadb/functions/forecast.py +++ b/evadb/functions/forecast.py @@ -48,13 +48,29 @@ def setup( self.id_column_rename = id_column_rename self.horizon = int(horizon) self.library = library + self.suggestion_dict = { + 1: "Predictions are flat. Consider using LIBRARY 'neuralforecast' for more accrate predictions.", + } def forward(self, data) -> pd.DataFrame: if self.library == "statsforecast": - forecast_df = self.model.predict(h=self.horizon) + forecast_df = self.model.predict(h=self.horizon).reset_index() else: - forecast_df = self.model.predict() - forecast_df.reset_index(inplace=True) + forecast_df = self.model.predict().reset_index() + + # Suggestions + if len(data) == 0 or list(data[0])[0].lower()[0] == "t": + suggestion_list = [] + # 1: Flat predictions + if self.library == "statsforecast": + for type_here in forecast_df["unique_id"].unique(): + if forecast_df.loc[forecast_df['unique_id'] == type_here][self.model_name].nunique() == 1: + suggestion_list.append(1) + + for suggestion in set(suggestion_list): + print("\nSUGGESTION: " + self.suggestion_dict[suggestion]) + + forecast_df = forecast_df.rename( columns={ "unique_id": self.id_column_rename,