Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Decimal places control implemented #1161

Merged
merged 2 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions examples/simple/classification/api_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from fedot.api.main import Fedot
from fedot.core.utils import fedot_project_root
from fedot.core.utils import set_random_seed
from fedot.core.utils import fedot_project_root, set_random_seed


def run_classification_example(timeout: float = None, visualization=False, with_tuning=True):
Expand All @@ -19,7 +18,7 @@ def run_classification_example(timeout: float = None, visualization=False, with_
auto_model.fit(features=train_data_path, target='target')
prediction = auto_model.predict_proba(features=test_data_path)

print(auto_model.get_metrics())
print(auto_model.get_metrics(decimal_places_num=4)) # we can control the rounding of metrics
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Мб мне так кажется, но название громоздкое и неочевидное. Не знаю какое название используют на практике, но предложу, например rounding_order, чтобы была связь с округлением

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Пусть будет, переименовал.

if visualization:
auto_model.plot_prediction()
return prediction
Expand Down
11 changes: 7 additions & 4 deletions fedot/api/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Tuple, Union, Callable
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from golem.core.dag.graph_utils import graph_structure
from golem.core.log import default_log, Log
from golem.core.log import Log, default_log
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
from golem.core.tuning.simultaneous import SimultaneousTuner
from golem.visualisation.opt_viz_extra import visualise_pareto
Expand Down Expand Up @@ -475,7 +475,8 @@ def get_metrics(self,
target: Union[np.ndarray, pd.Series] = None,
metric_names: Union[str, List[str]] = None,
in_sample: Optional[bool] = None,
validation_blocks: Optional[int] = None) -> dict:
validation_blocks: Optional[int] = None,
decimal_places_num: int = 3) -> dict:
"""Gets quality metrics for the fitted graph

Args:
Expand All @@ -484,6 +485,7 @@ def get_metrics(self,
in_sample: used for time series forecasting.
If True prediction will be obtained as ``.predict(..., in_sample=True)``.
validation_blocks: number of validation blocks for time series in-sample forecast.
decimal_places_num: number of decimal places for metrics

Returns:
The values of quality metrics.
Expand Down Expand Up @@ -517,7 +519,8 @@ def get_metrics(self,
do_unfit=False)

metrics = obj_eval.evaluate(self.current_pipeline).values
metrics = {metric_name: round(abs(metric), 3) for (metric_name, metric) in zip(metric_names, metrics)}
metrics = {metric_name: round(abs(metric), decimal_places_num) for (metric_name, metric) in
zip(metric_names, metrics)}

return metrics

Expand Down
4 changes: 2 additions & 2 deletions test/integration/api/test_main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from fedot.core.utils import fedot_project_root
from test.integration.models.test_split_train_test import get_synthetic_input_data
from test.unit.common_tests import is_predict_ignores_target
from test.unit.tasks.test_classification import get_synthetic_classification_data, get_iris_data
from test.unit.tasks.test_classification import get_iris_data, get_synthetic_classification_data
from test.unit.tasks.test_forecasting import get_ts_data
from test.unit.tasks.test_multi_ts_forecast import get_multi_ts_data
from test.unit.tasks.test_regression import get_synthetic_regression_data
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_api_predict_correct(task_type, predefined_model, metric_name):
model = Fedot(problem=task_type, **TESTS_MAIN_API_DEFAULT_PARAMS)
fedot_model = model.fit(features=train_data, predefined_model=predefined_model)
prediction = model.predict(features=test_data)
metric = model.get_metrics(metric_names=metric_name)
metric = model.get_metrics(metric_names=metric_name, decimal_places_num=5)
assert isinstance(fedot_model, Pipeline)
assert len(prediction) == len(test_data.target)
assert all(value > 0 for value in metric.values())
Expand Down
5 changes: 2 additions & 3 deletions test/unit/validation/test_table_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from datetime import timedelta

import pytest

from golem.core.tuning.simultaneous import SimultaneousTuner

from fedot.api.main import Fedot
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.optimisers.objective import PipelineObjectiveEvaluate
from fedot.core.optimisers.objective.data_source_splitter import DataSourceSplitter
from fedot.core.optimisers.objective.metrics_objective import MetricsObjective
from fedot.core.pipelines.node import PipelineNode
from fedot.core.pipelines.pipeline import Pipeline
Expand All @@ -18,7 +18,6 @@
from fedot.core.repository.quality_metrics_repository import ClassificationMetricsEnum
from fedot.core.repository.tasks import Task, TaskTypesEnum
from fedot.core.utils import fedot_project_root
from fedot.core.optimisers.objective.data_source_splitter import DataSourceSplitter
from test.integration.models.test_model import classification_dataset
from test.unit.tasks.test_classification import get_iris_data, pipeline_simple

Expand Down Expand Up @@ -91,7 +90,7 @@ def test_cv_api_correct():
model = Fedot(problem='classification', logging_level=logging.DEBUG, **composer_params)
fedot_model = model.fit(features=dataset_to_compose)
prediction = model.predict(features=dataset_to_validate)
metric = model.get_metrics(metric_names='f1')
metric = model.get_metrics(metric_names='f1', decimal_places_num=1)

assert isinstance(fedot_model, Pipeline)
assert len(prediction) == len(dataset_to_validate.target)
Expand Down