From d382361c855bef4f7421cab2e94813b688cf8e4e Mon Sep 17 00:00:00 2001 From: americast Date: Tue, 24 Oct 2023 10:28:47 -0400 Subject: [PATCH] fix tests --- evadb/binder/statement_binder.py | 32 +++++++-------- evadb/executor/create_function_executor.py | 2 +- .../long/test_model_forecasting.py | 25 ++++++++++-- .../binder/test_statement_binder.py | 40 ++++++++++++++++++- 4 files changed, 77 insertions(+), 22 deletions(-) diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index 0af675e110..4cb8aeb6a9 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -130,6 +130,22 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement): assert ( len(required_columns) == 0 ), f"Missing required {required_columns} columns for forecasting function." + outputs.extend( + [ + ColumnDefinition( + arg_map.get("predict", "y") + "-lo", + ColumnType.INTEGER, + None, + None, + ), + ColumnDefinition( + arg_map.get("predict", "y") + "-hi", + ColumnType.INTEGER, + None, + None, + ), + ] + ) else: raise BinderError( f"Unsupported type of function: {node.function_type}." @@ -137,22 +153,6 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement): assert ( len(node.inputs) == 0 and len(node.outputs) == 0 ), f"{node.function_type} functions' input and output are auto assigned" - outputs.extend( - [ - ColumnDefinition( - arg_map.get("predict", "y") + "-lo", - ColumnType.INTEGER, - None, - None, - ), - ColumnDefinition( - arg_map.get("predict", "y") + "-hi", - ColumnType.INTEGER, - None, - None, - ), - ] - ) node.inputs, node.outputs = inputs, outputs @bind.register(SelectStatement) diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index c31a1fb8bf..8538b11056 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -493,7 +493,7 @@ def get_optuna_config(trial): model_save_dir_name = ( library + "_" + arg_map["model"] + "_" + new_freq if "statsforecast" in library - else library + "_" + conf + "_" + arg_map["model"] + "_" + new_freq + else library + "_" + str(conf) + "_" + arg_map["model"] + "_" + new_freq ) if len(data.columns) >= 4 and library == "neuralforecast": model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns)) diff --git a/test/integration_tests/long/test_model_forecasting.py b/test/integration_tests/long/test_model_forecasting.py index 47ffe65a83..76e5562357 100644 --- a/test/integration_tests/long/test_model_forecasting.py +++ b/test/integration_tests/long/test_model_forecasting.py @@ -94,7 +94,14 @@ def test_forecast(self): result = execute_query_fetch_all(self.evadb, predict_query) self.assertEqual(len(result), 12) self.assertEqual( - result.columns, ["airforecast.unique_id", "airforecast.ds", "airforecast.y"] + result.columns, + [ + "airforecast.unique_id", + "airforecast.ds", + "airforecast.y", + "airforecast.y-lo", + "airforecast.y-hi", + ], ) create_predict_udf = """ @@ -116,7 +123,13 @@ def test_forecast(self): self.assertEqual(len(result), 24) self.assertEqual( result.columns, - ["airpanelforecast.unique_id", "airpanelforecast.ds", "airpanelforecast.y"], + [ + "airpanelforecast.unique_id", + "airpanelforecast.ds", + "airpanelforecast.y", + "airpanelforecast.y-lo", + "airpanelforecast.y-hi", + ], ) @forecast_skip_marker @@ -143,7 +156,13 @@ def test_forecast_with_column_rename(self): self.assertEqual(len(result), 24) self.assertEqual( result.columns, - ["homeforecast.type", "homeforecast.saledate", "homeforecast.ma"], + [ + "homeforecast.type", + "homeforecast.saledate", + "homeforecast.ma", + "homeforecast.ma-lo", + "homeforecast.ma-hi", + ], ) diff --git a/test/unit_tests/binder/test_statement_binder.py b/test/unit_tests/binder/test_statement_binder.py index d6642ea9a2..d57a08064e 100644 --- a/test/unit_tests/binder/test_statement_binder.py +++ b/test/unit_tests/binder/test_statement_binder.py @@ -475,6 +475,18 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self): array_type=MagicMock(), array_dimensions=MagicMock(), ) + y_lo_col_obj = ColumnCatalogEntry( + name="y-lo", + type=MagicMock(), + array_type=MagicMock(), + array_dimensions=MagicMock(), + ) + y_hi_col_obj = ColumnCatalogEntry( + name="y-hi", + type=MagicMock(), + array_type=MagicMock(), + array_dimensions=MagicMock(), + ) create_function_statement.query.target_list = [ TupleValueExpression( name=id_col_obj.name, table_alias="a", col_object=id_col_obj @@ -506,7 +518,13 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self): col_obj.array_type, col_obj.array_dimensions, ) - for col_obj in (id_col_obj, ds_col_obj, y_col_obj) + for col_obj in ( + id_col_obj, + ds_col_obj, + y_col_obj, + y_lo_col_obj, + y_hi_col_obj, + ) ] ) self.assertEqual(create_function_statement.inputs, expected_inputs) @@ -534,6 +552,18 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self): array_type=MagicMock(), array_dimensions=MagicMock(), ) + y_lo_col_obj = ColumnCatalogEntry( + name="y-lo", + type=MagicMock(), + array_type=MagicMock(), + array_dimensions=MagicMock(), + ) + y_hi_col_obj = ColumnCatalogEntry( + name="y-hi", + type=MagicMock(), + array_type=MagicMock(), + array_dimensions=MagicMock(), + ) create_function_statement.query.target_list = [ TupleValueExpression( name=id_col_obj.name, table_alias="a", col_object=id_col_obj @@ -569,7 +599,13 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self): col_obj.array_type, col_obj.array_dimensions, ) - for col_obj in (id_col_obj, ds_col_obj, y_col_obj) + for col_obj in ( + id_col_obj, + ds_col_obj, + y_col_obj, + y_lo_col_obj, + y_hi_col_obj, + ) ] ) self.assertEqual(create_function_statement.inputs, expected_inputs)