|
12 | 12 | import unittest.mock
|
13 | 13 | import warnings
|
14 | 14 | from abc import abstractmethod
|
15 |
| -from typing import Any, Callable, Dict, List, Optional, Union, cast |
| 15 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast |
16 | 16 |
|
17 | 17 | from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
|
18 | 18 |
|
@@ -223,9 +223,7 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline:
|
223 | 223 | """
|
224 | 224 | raise NotImplementedError
|
225 | 225 |
|
226 |
| - def set_pipeline_config( |
227 |
| - self, |
228 |
| - **pipeline_config_kwargs: Any) -> None: |
| 226 | + def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None: |
229 | 227 | """
|
230 | 228 | Check whether arguments are valid and
|
231 | 229 | then sets them to the current pipeline
|
@@ -259,12 +257,6 @@ def get_pipeline_options(self) -> dict:
|
259 | 257 | """
|
260 | 258 | return self.pipeline_options
|
261 | 259 |
|
262 |
| - # def set_search_space(self, search_space: ConfigurationSpace) -> None: |
263 |
| - # """ |
264 |
| - # Update the search space. |
265 |
| - # """ |
266 |
| - # raise NotImplementedError |
267 |
| - # |
268 | 260 | def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace:
|
269 | 261 | """
|
270 | 262 | Returns the current search space as ConfigurationSpace object.
|
@@ -406,9 +398,9 @@ def _close_dask_client(self) -> None:
|
406 | 398 | None
|
407 | 399 | """
|
408 | 400 | if (
|
409 |
| - hasattr(self, '_is_dask_client_internally_created') |
410 |
| - and self._is_dask_client_internally_created |
411 |
| - and self._dask_client |
| 401 | + hasattr(self, '_is_dask_client_internally_created') |
| 402 | + and self._is_dask_client_internally_created |
| 403 | + and self._dask_client |
412 | 404 | ):
|
413 | 405 | self._dask_client.shutdown()
|
414 | 406 | self._dask_client.close()
|
@@ -661,10 +653,11 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
|
661 | 653 | f"Fitting {cls} took {runtime}s, performance:{cost}/{additional_info}")
|
662 | 654 | configuration = additional_info['pipeline_configuration']
|
663 | 655 | origin = additional_info['configuration_origin']
|
| 656 | + additional_info.pop('pipeline_configuration') |
664 | 657 | run_history.add(config=configuration, cost=cost,
|
665 | 658 | time=runtime, status=status, seed=self.seed,
|
666 | 659 | starttime=starttime, endtime=starttime + runtime,
|
667 |
| - origin=origin) |
| 660 | + origin=origin, additional_info=additional_info) |
668 | 661 | else:
|
669 | 662 | if additional_info.get('exitcode') == -6:
|
670 | 663 | self._logger.error(
|
@@ -710,6 +703,7 @@ def _search(
|
710 | 703 | memory_limit: Optional[int] = 4096,
|
711 | 704 | smac_scenario_args: Optional[Dict[str, Any]] = None,
|
712 | 705 | get_smac_object_callback: Optional[Callable] = None,
|
| 706 | + tae_func: Optional[Callable] = None, |
713 | 707 | all_supported_metrics: bool = True,
|
714 | 708 | precision: int = 32,
|
715 | 709 | disable_file_output: List = [],
|
@@ -777,6 +771,10 @@ def _search(
|
777 | 771 | instances, num_params, runhistory, seed and ta. This is
|
778 | 772 | an advanced feature. Use only if you are familiar with
|
779 | 773 | [SMAC](https://automl.github.io/SMAC3/master/index.html).
|
| 774 | + tae_func (Optional[Callable]): |
| 775 | + TargetAlgorithm to be optimised. If None, `eval_function` |
| 776 | + available in autoPyTorch/evaluation/train_evaluator is used. |
| 777 | + Must be child class of AbstractEvaluator. |
780 | 778 | all_supported_metrics (bool), (default=True): if True, all
|
781 | 779 | metrics supporting current task will be calculated
|
782 | 780 | for each pipeline and results will be available via cv_results
|
@@ -988,7 +986,7 @@ def _search(
|
988 | 986 | )
|
989 | 987 | try:
|
990 | 988 | run_history, self.trajectory, budget_type = \
|
991 |
| - _proc_smac.run_smbo() |
| 989 | + _proc_smac.run_smbo(func=tae_func) |
992 | 990 | self.run_history.update(run_history, DataOrigin.INTERNAL)
|
993 | 991 | trajectory_filename = os.path.join(
|
994 | 992 | self._backend.get_smac_output_directory_for_run(self.seed),
|
@@ -1042,10 +1040,10 @@ def _search(
|
1042 | 1040 | return self
|
1043 | 1041 |
|
1044 | 1042 | def refit(
|
1045 |
| - self, |
1046 |
| - dataset: BaseDataset, |
1047 |
| - budget_config: Dict[str, Union[int, str]] = {}, |
1048 |
| - split_id: int = 0 |
| 1043 | + self, |
| 1044 | + dataset: BaseDataset, |
| 1045 | + budget_config: Dict[str, Union[int, str]] = {}, |
| 1046 | + split_id: int = 0 |
1049 | 1047 | ) -> "BaseTask":
|
1050 | 1048 | """
|
1051 | 1049 | Refit all models found with fit to new data.
|
@@ -1181,10 +1179,10 @@ def fit(self,
|
1181 | 1179 | return pipeline
|
1182 | 1180 |
|
1183 | 1181 | def predict(
|
1184 |
| - self, |
1185 |
| - X_test: np.ndarray, |
1186 |
| - batch_size: Optional[int] = None, |
1187 |
| - n_jobs: int = 1 |
| 1182 | + self, |
| 1183 | + X_test: np.ndarray, |
| 1184 | + batch_size: Optional[int] = None, |
| 1185 | + n_jobs: int = 1 |
1188 | 1186 | ) -> np.ndarray:
|
1189 | 1187 | """Generate the estimator predictions.
|
1190 | 1188 | Generate the predictions based on the given examples from the test set.
|
@@ -1234,9 +1232,9 @@ def predict(
|
1234 | 1232 | return predictions
|
1235 | 1233 |
|
1236 | 1234 | def score(
|
1237 |
| - self, |
1238 |
| - y_pred: np.ndarray, |
1239 |
| - y_test: Union[np.ndarray, pd.DataFrame] |
| 1235 | + self, |
| 1236 | + y_pred: np.ndarray, |
| 1237 | + y_test: Union[np.ndarray, pd.DataFrame] |
1240 | 1238 | ) -> Dict[str, float]:
|
1241 | 1239 | """Calculate the score on the test set.
|
1242 | 1240 | Calculate the evaluation measure on the test set.
|
@@ -1277,17 +1275,37 @@ def __del__(self) -> None:
|
1277 | 1275 | if hasattr(self, '_backend'):
|
1278 | 1276 | self._backend.context.delete_directories(force=False)
|
1279 | 1277 |
|
1280 |
| - @typing.no_type_check |
1281 | 1278 | def get_incumbent_results(
|
1282 |
| - self |
1283 |
| - ): |
1284 |
| - pass |
| 1279 | + self, |
| 1280 | + include_traditional: bool = False |
| 1281 | + ) -> Tuple[Configuration, Dict[str, Union[int, str, float]]]: |
| 1282 | + """ |
| 1283 | + Get Incumbent config and the corresponding results |
| 1284 | + Args: |
| 1285 | + include_traditional: Whether to include results from traditional pipelines |
1285 | 1286 |
|
1286 |
| - @typing.no_type_check |
1287 |
| - def get_incumbent_config( |
1288 |
| - self |
1289 |
| - ): |
1290 |
| - pass |
| 1287 | + Returns: |
| 1288 | + Tuple of the incumbent Configuration and a dict with its run results |
|
| 1289 | + """ |
| 1290 | + assert self.run_history is not None, "No Run History found, search has not been called." |
| 1291 | + if self.run_history.empty(): |
| 1292 | + raise ValueError("Run History is empty. Something went wrong, " |
| 1293 | + "smac was not able to fit any model?") |
| 1294 | + |
| 1295 | + run_history_data = self.run_history.data |
| 1296 | + if not include_traditional: |
| 1297 | + # traditional pipelines are identified by configuration_origin == 'traditional' in additional info |
| 1298 | + run_history_data = dict( |
| 1299 | + filter(lambda elem: elem[1].additional_info is not None and elem[1]. |
| 1300 | + additional_info['configuration_origin'] != 'traditional', |
| 1301 | + run_history_data.items())) |
| 1302 | + run_history_data = dict( |
| 1303 | + filter(lambda elem: 'SUCCESS' in str(elem[1].status), run_history_data.items())) |
| 1304 | + sorted_runvalue_by_cost = sorted(run_history_data.items(), key=lambda item: item[1].cost) |
| 1305 | + incumbent_run_key, incumbent_run_value = sorted_runvalue_by_cost[0] |
| 1306 | + incumbent_config = self.run_history.ids_config[incumbent_run_key.config_id] |
| 1307 | + incumbent_results = incumbent_run_value.additional_info |
| 1308 | + return incumbent_config, incumbent_results |
1291 | 1309 |
|
1292 | 1310 | def get_models_with_weights(self) -> List:
|
1293 | 1311 | if self.models_ is None or len(self.models_) == 0 or \
|
|
0 commit comments