Skip to content

Commit

Permalink
Fix plot_pareto (#1023)
Browse files Browse the repository at this point in the history
* Fix plot_pareto

* Revert changes in get_metrics

* Correct docstring

* Correct docstring
  • Loading branch information
YamLyubov authored Feb 10, 2023
1 parent df1e56c commit 8c4efb3
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 0 additions & 2 deletions examples/simple/time_series_forecasting/api_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from fedot.api.main import Fedot
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.pipelines.node import PrimaryNode, SecondaryNode
from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import TsForecastingParams, Task, TaskTypesEnum
from fedot.core.utils import fedot_project_root
Expand Down
3 changes: 3 additions & 0 deletions fedot/api/api_utils/api_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, problem: str):
self.preprocessing_cache: Optional[PreprocessingCache] = None
self.preset_name = None
self.timer = None
self.metric_names = None
# status flag indicating that composer step was applied
self.was_optimised = False
# status flag indicating that tuner step was applied
Expand Down Expand Up @@ -290,6 +291,8 @@ def compose_pipeline(self, task: Task,
.with_graph_generation_param(graph_generation_params=graph_generation_params) \
.build()

self.metric_names = gp_composer.optimizer.objective.metric_names

n_jobs = determine_n_jobs(composer_requirements.n_jobs)

if self.timer.have_time_for_composing(composer_params['pop_size'], n_jobs):
Expand Down
2 changes: 1 addition & 1 deletion fedot/api/api_utils/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _parse_input_params(self, input_params: Dict[str, Any]):
input_params['task_params'] = TsForecastingParams(forecast_length=DEFAULT_FORECAST_LENGTH)

if self.api_params['problem'] == 'clustering':
raise ValueError('This type of task is not not supported in API now')
raise ValueError('This type of task is not supported in API now')

self.task = self.get_task_params(self.api_params['problem'], input_params['task_params'])
self.metric_name = self.get_default_metric(self.api_params['problem'])
Expand Down
11 changes: 5 additions & 6 deletions fedot/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union, Callable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -392,7 +392,7 @@ def load(self, path):
self.data_processor.preprocessor = self.current_pipeline.preprocessor

def plot_pareto(self):
metric_names = self.params.metric_to_compose
metric_names = self.api_composer.metric_names
# archive_history stores archives of the best models.
# Each archive is sorted from the best to the worst model,
# so the best_candidates is sorted too.
Expand Down Expand Up @@ -434,12 +434,12 @@ def get_metrics(self,
"""Gets quality metrics for the fitted graph
Args:
target: the array with target values of test data
metric_names: the names of required metrics
target: the array with target values of test data. If None, target specified for fit is used
metric_names: the names of required metrics.
in_sample: used for time-series forecasting.
If True prediction will be obtained as ``.predict(..., in_sample=True)``.
validation_blocks: number of validation blocks for in-sample forecast.
If ``validation_blocks = None`` uses number of validation blocks set during model initialization
If None uses number of validation blocks set during model initialisation
(default is 2).
Returns:
Expand All @@ -458,7 +458,6 @@ def get_metrics(self,
else:
self.test_data.target = target[:len(self.prediction.predict)]

# TODO change to sklearn metrics
metric_names = ensure_wrapped_in_sequence(metric_names)

calculated_metrics = dict()
Expand Down

0 comments on commit 8c4efb3

Please # to comment.