Skip to content

Commit

Permalink
Fixes date and frequency issues in forecasting (#1094)
Browse files Browse the repository at this point in the history
Fixes #1081 pt 2.
  • Loading branch information
americast authored Sep 12, 2023
1 parent ea74021 commit 788d65e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/reference/ai/model-forecasting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ EvaDB's default forecast framework is `statsforecast <https://nixtla.github.io/s
* - PREDICT (required)
- The name of the column we wish to forecast.
* - TIME
- The name of the column that contains the datestamp, which should be of a format expected by Pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp. If not provided, an auto increasing ID column will be used.
- The name of the column that contains the datestamp, which should be of a format expected by Pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp. Please visit the `pandas documentation <https://pandas.pydata.org/docs/reference/api/pandas.to_datetime.html>`_ for details. If not provided, an auto increasing ID column will be used.
* - ID
- The name of the column that represents an identifier for the series. If not provided, the whole table is considered as one series of data.
* - MODEL
Expand Down
13 changes: 8 additions & 5 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,11 @@ def handle_forecasting_function(self):
impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
else:
impl_path = self.node.impl_path.absolute().as_posix()
arg_map = {arg.key: arg.value for arg in self.node.metadata}

if "model" not in arg_map.keys():
arg_map["model"] = "AutoARIMA"
if "frequency" not in arg_map.keys():
arg_map["frequency"] = "M"

model_name = arg_map["model"]
frequency = arg_map["frequency"]

"""
The following rename is needed for statsforecast, which requires the column name to be the following:
Expand All @@ -179,6 +175,10 @@ def handle_forecasting_function(self):
if "ds" not in list(data.columns):
data["ds"] = [x + 1 for x in range(len(data))]

if "frequency" not in arg_map.keys():
arg_map["frequency"] = pd.infer_freq(data["ds"])
frequency = arg_map["frequency"]

try_to_import_forecast()
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta
Expand Down Expand Up @@ -220,7 +220,7 @@ def handle_forecasting_function(self):
)

weight_file = Path(model_path)

data["ds"] = pd.to_datetime(data["ds"])
if not weight_file.exists():
model.fit(data)
f = open(model_path, "wb")
Expand All @@ -233,6 +233,9 @@ def handle_forecasting_function(self):
FunctionMetadataCatalogEntry("model_name", model_name),
FunctionMetadataCatalogEntry("model_path", model_path),
FunctionMetadataCatalogEntry("output_column_rename", arg_map["predict"]),
FunctionMetadataCatalogEntry(
"time_column_rename", arg_map["time"] if "time" in arg_map else "ds"
),
]

return (
Expand Down
14 changes: 12 additions & 2 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,20 @@ def name(self) -> str:
return "ForecastModel"

@setup(cacheable=False, function_type="Forecasting", batchable=True)
def setup(self, model_name: str, model_path: str, output_column_rename: str):
def setup(
self,
model_name: str,
model_path: str,
output_column_rename: str,
time_column_rename: str,
):
f = open(model_path, "rb")
loaded_model = pickle.load(f)
f.close()
self.model = loaded_model
self.model_name = model_name
self.output_column_rename = output_column_rename
self.time_column_rename = time_column_rename

def forward(self, data) -> pd.DataFrame:
horizon = list(data.iloc[:, -1])[0]
Expand All @@ -43,6 +50,9 @@ def forward(self, data) -> pd.DataFrame:
), "Forecast UDF expects integral horizon in parameter."
forecast_df = self.model.predict(h=horizon)
forecast_df = forecast_df.rename(
columns={self.model_name: self.output_column_rename}
columns={
self.model_name: self.output_column_rename,
"ds": self.time_column_rename,
}
)
return forecast_df

0 comments on commit 788d65e

Please sign in to comment.