diff --git a/test/integration/models/test_model.py b/test/integration/models/test_model.py index 09873db356..a8684e8964 100644 --- a/test/integration/models/test_model.py +++ b/test/integration/models/test_model.py @@ -461,22 +461,44 @@ def test_locf_forecast_correctly(): assert np.array_equal(predict_forecast.predict, np.array([[110, 120, 130, 110]])) -def test_models_does_not_fall_on_constant_data(): +@pytest.mark.parametrize('operation', OperationTypesRepository('all')._repo, ids=lambda x: x.id) +def test_models_does_not_fall_on_constant_data(operation): """ Run models on constant data """ # models that raise exception to_skip = ['custom', 'arima', 'catboost', 'catboostreg', 'cgru', 'lda', 'fast_ica', 'decompose', 'class_decompose'] + if operation.id in to_skip: + return - for operation in OperationTypesRepository('all')._repo: - if operation.id in to_skip: - continue - for task_type in operation.task_type: - for data_type in operation.input_types: - data = get_data_for_testing(task_type, data_type, - length=100, features_count=2, - random=False) - if data is not None: + for task_type in operation.task_type: + for data_type in operation.input_types: + data = get_data_for_testing(task_type, data_type, + length=100, features_count=2, + random=False) + if data is not None: + nodes_from = [] + if task_type is TaskTypesEnum.ts_forecasting: + if 'non_lagged' not in operation.tags: + nodes_from = [PipelineNode('lagged')] + node = PipelineNode(operation.id, nodes_from=nodes_from) + pipeline = Pipeline(node) + pipeline.fit(data) + assert pipeline.predict(data) is not None + + +@pytest.mark.parametrize('operation', OperationTypesRepository('all')._repo, ids=lambda x: x.id) +def test_operations_are_serializable(operation): + to_skip = ['custom', 'decompose', 'class_decompose'] + if operation.id in to_skip: + return + for task_type in operation.task_type: + for data_type in operation.input_types: + data = get_data_for_testing(task_type, data_type, + length=100, features_count=2, + random=True) + if data is not None: + try: nodes_from = [] if task_type is TaskTypesEnum.ts_forecasting: if 'non_lagged' not in operation.tags: @@ -484,33 +506,10 @@ def test_models_does_not_fall_on_constant_data(): node = PipelineNode(operation.id, nodes_from=nodes_from) pipeline = Pipeline(node) pipeline.fit(data) - assert pipeline.predict(data) is not None - - -def test_operations_are_serializable(): - to_skip = ['custom', 'decompose', 'class_decompose'] - - for operation in OperationTypesRepository('all')._repo: - if operation.id in to_skip: - continue - for task_type in operation.task_type: - for data_type in operation.input_types: - data = get_data_for_testing(task_type, data_type, - length=100, features_count=2, - random=True) - if data is not None: - try: - nodes_from = [] - if task_type is TaskTypesEnum.ts_forecasting: - if 'non_lagged' not in operation.tags: - nodes_from = [PipelineNode('lagged')] - node = PipelineNode(operation.id, nodes_from=nodes_from) - pipeline = Pipeline(node) - pipeline.fit(data) - serialized = pickle.dumps(pipeline, pickle.HIGHEST_PROTOCOL) - assert isinstance(serialized, bytes) - except NotImplementedError: - pass + serialized = pickle.dumps(pipeline, pickle.HIGHEST_PROTOCOL) + assert isinstance(serialized, bytes) + except NotImplementedError: + pass def test_operations_are_fast(): @@ -534,7 +533,7 @@ def test_operations_are_fast(): reference_time = tuple(map(min, zip(perfomance_values, reference_time))) for operation in OperationTypesRepository('all')._repo: - if (operation.id not in to_skip and operation.presets and FAST_TRAIN_PRESET_NAME in operation.presets): + if operation.id not in to_skip and operation.presets and FAST_TRAIN_PRESET_NAME in operation.presets: for _ in range(attempt): perfomance_values = get_operation_perfomance(operation, data_lengths) # if attempt is successful then stop @@ -548,7 +547,6 @@ def test_all_operations_are_documented(): # All operations and presets should be listed in `docs/source/introduction/fedot_features/automation_features.rst` to_skip = {'custom', 'data_source_img', 'data_source_text', 'data_source_table', 'data_source_ts', 'exog_ts'} path_to_docs = fedot_project_root() / 'docs/source/introduction/fedot_features/automation_features.rst' - docs_lines = None with open(path_to_docs, 'r') as docs_: docs_lines = docs_.readlines()