From e03c4d04f648729e88d89ab5084a79dd95b17d80 Mon Sep 17 00:00:00 2001 From: americast Date: Sat, 11 Nov 2023 15:27:07 -0500 Subject: [PATCH] optuna for non auto models --- evadb/executor/create_function_executor.py | 5 +++-- test/integration_tests/long/test_model_forecasting.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 948819592..dec2ff328 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -499,7 +499,8 @@ def get_optuna_config(trial): return model_args_config model_args["config"] = get_optuna_config - model_args["backend"] = "optuna" + + model_args["backend"] = "optuna" model_args["h"] = horizon model_args["loss"] = MQLoss(level=[conf]) @@ -542,7 +543,7 @@ def get_optuna_config(trial): raise FunctionIODefinitionError(err_msg) model = StatsForecast( - [model_here(season_length=season_length)], freq=new_freq + [model_here(season_length=season_length)], freq=new_freq, n_jobs=-1 ) data["ds"] = pd.to_datetime(data["ds"]) diff --git a/test/integration_tests/long/test_model_forecasting.py b/test/integration_tests/long/test_model_forecasting.py index d309700d3..76e556235 100644 --- a/test/integration_tests/long/test_model_forecasting.py +++ b/test/integration_tests/long/test_model_forecasting.py @@ -110,6 +110,8 @@ def test_forecast(self): TYPE Forecasting HORIZON 12 PREDICT 'y' + LIBRARY 'neuralforecast' + AUTO 'false' FREQUENCY 'M'; """ execute_query_fetch_all(self.evadb, create_predict_udf)