Skip to content

Commit

Permalink
Decimal places control implemented (#1161)
Browse files Browse the repository at this point in the history
* Decimal places control implemented
  • Loading branch information
nicl-nno authored Sep 4, 2023
1 parent 5da1447 commit 1a96f38
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
5 changes: 2 additions & 3 deletions examples/simple/classification/api_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from fedot.api.main import Fedot
from fedot.core.utils import fedot_project_root
from fedot.core.utils import set_random_seed
from fedot.core.utils import fedot_project_root, set_random_seed


def run_classification_example(timeout: float = None, visualization=False, with_tuning=True):
Expand All @@ -19,7 +18,7 @@ def run_classification_example(timeout: float = None, visualization=False, with_
auto_model.fit(features=train_data_path, target='target')
prediction = auto_model.predict_proba(features=test_data_path)

print(auto_model.get_metrics())
print(auto_model.get_metrics(rounding_order=4)) # we can control the rounding of metrics
if visualization:
auto_model.plot_prediction()
return prediction
Expand Down
11 changes: 7 additions & 4 deletions fedot/api/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Tuple, Union, Callable
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from golem.core.dag.graph_utils import graph_structure
from golem.core.log import default_log, Log
from golem.core.log import Log, default_log
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
from golem.core.tuning.simultaneous import SimultaneousTuner
from golem.visualisation.opt_viz_extra import visualise_pareto
Expand Down Expand Up @@ -475,7 +475,8 @@ def get_metrics(self,
target: Union[np.ndarray, pd.Series] = None,
metric_names: Union[str, List[str]] = None,
in_sample: Optional[bool] = None,
validation_blocks: Optional[int] = None) -> dict:
validation_blocks: Optional[int] = None,
rounding_order: int = 3) -> dict:
"""Gets quality metrics for the fitted graph
Args:
Expand All @@ -484,6 +485,7 @@ def get_metrics(self,
in_sample: used for time series forecasting.
If True prediction will be obtained as ``.predict(..., in_sample=True)``.
validation_blocks: number of validation blocks for time series in-sample forecast.
rounding_order: number of decimal places for metrics
Returns:
The values of quality metrics.
Expand Down Expand Up @@ -517,7 +519,8 @@ def get_metrics(self,
do_unfit=False)

metrics = obj_eval.evaluate(self.current_pipeline).values
metrics = {metric_name: round(abs(metric), 3) for (metric_name, metric) in zip(metric_names, metrics)}
metrics = {metric_name: round(abs(metric), rounding_order) for (metric_name, metric) in
zip(metric_names, metrics)}

return metrics

Expand Down
4 changes: 2 additions & 2 deletions test/integration/api/test_main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from fedot.core.utils import fedot_project_root
from test.integration.models.test_split_train_test import get_synthetic_input_data
from test.unit.common_tests import is_predict_ignores_target
from test.unit.tasks.test_classification import get_synthetic_classification_data, get_iris_data
from test.unit.tasks.test_classification import get_iris_data, get_synthetic_classification_data
from test.unit.tasks.test_forecasting import get_ts_data
from test.unit.tasks.test_multi_ts_forecast import get_multi_ts_data
from test.unit.tasks.test_regression import get_synthetic_regression_data
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_api_predict_correct(task_type, predefined_model, metric_name):
model = Fedot(problem=task_type, **TESTS_MAIN_API_DEFAULT_PARAMS)
fedot_model = model.fit(features=train_data, predefined_model=predefined_model)
prediction = model.predict(features=test_data)
metric = model.get_metrics(metric_names=metric_name)
metric = model.get_metrics(metric_names=metric_name, rounding_order=5)
assert isinstance(fedot_model, Pipeline)
assert len(prediction) == len(test_data.target)
assert all(value > 0 for value in metric.values())
Expand Down
5 changes: 2 additions & 3 deletions test/unit/validation/test_table_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from datetime import timedelta

import pytest

from golem.core.tuning.simultaneous import SimultaneousTuner

from fedot.api.main import Fedot
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.optimisers.objective import PipelineObjectiveEvaluate
from fedot.core.optimisers.objective.data_source_splitter import DataSourceSplitter
from fedot.core.optimisers.objective.metrics_objective import MetricsObjective
from fedot.core.pipelines.node import PipelineNode
from fedot.core.pipelines.pipeline import Pipeline
Expand All @@ -18,7 +18,6 @@
from fedot.core.repository.quality_metrics_repository import ClassificationMetricsEnum
from fedot.core.repository.tasks import Task, TaskTypesEnum
from fedot.core.utils import fedot_project_root
from fedot.core.optimisers.objective.data_source_splitter import DataSourceSplitter
from test.integration.models.test_model import classification_dataset
from test.unit.tasks.test_classification import get_iris_data, pipeline_simple

Expand Down Expand Up @@ -91,7 +90,7 @@ def test_cv_api_correct():
model = Fedot(problem='classification', logging_level=logging.DEBUG, **composer_params)
fedot_model = model.fit(features=dataset_to_compose)
prediction = model.predict(features=dataset_to_validate)
metric = model.get_metrics(metric_names='f1')
metric = model.get_metrics(metric_names='f1', rounding_order=1)

assert isinstance(fedot_model, Pipeline)
assert len(prediction) == len(dataset_to_validate.target)
Expand Down

0 comments on commit 1a96f38

Please # to comment.