From a3c9fc6e98388d80fb7b46913444d50bcacb4d11 Mon Sep 17 00:00:00 2001 From: Eunwoo Shin Date: Tue, 15 Oct 2024 18:13:43 +0900 Subject: [PATCH] Update HPO interface (#4035) * update hpo interface * update unit test * update CHANGELOG.md --- CHANGELOG.md | 2 ++ src/otx/core/config/hpo.py | 11 +++++++++-- src/otx/engine/hpo/hpo_api.py | 25 +++++++++++++++++++------ tests/unit/engine/hpo/test_hpo_api.py | 25 ++++++++++++++++++++++--- 4 files changed, 52 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c70c9ab18e6..7970722a0bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,8 @@ All notable changes to this project will be documented in this file. () - Prevent using too low confidence thresholds in detection () +- Update HPO interface + () ### Bug fixes diff --git a/src/otx/core/config/hpo.py b/src/otx/core/config/hpo.py index 8d4dd085955..29695631ef8 100644 --- a/src/otx/core/config/hpo.py +++ b/src/otx/core/config/hpo.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from pathlib import Path # noqa: TCH003 -from typing import Any, Literal +from typing import Any, Callable, Literal import torch @@ -23,7 +23,12 @@ @dataclass class HpoConfig: - """DTO for HPO configuration.""" + """DTO for HPO configuration. + + progress_update_callback (Callable[[int | float], None] | None): + callback to update progress. If it's given, it's called with progress every second. + callbacks_to_exclude (list[str] | str | None): Class name(s) of callbacks to exclude during HPO. 
+ """ search_space: dict[str, dict[str, Any]] | str | Path | None = None save_path: str | None = None @@ -40,3 +45,5 @@ class HpoConfig: asynchronous_sha: bool = num_workers > 1 metric_name: str | None = None adapt_bs_search_space_max_val: Literal["None", "Safe", "Full"] = "None" + progress_update_callback: Callable[[int | float], None] | None = None + callbacks_to_exclude: list[str] | str | None = None diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/engine/hpo/hpo_api.py index 7fd5983618e..25f097817d9 100644 --- a/src/otx/engine/hpo/hpo_api.py +++ b/src/otx/engine/hpo/hpo_api.py @@ -9,6 +9,7 @@ import json import logging import time +from copy import copy from functools import partial from pathlib import Path from threading import Thread @@ -16,6 +17,7 @@ import torch import yaml +from lightning import Callback from otx.core.config.hpo import HpoConfig from otx.core.optimizer.callable import OptimizerCallableSupportHPO @@ -35,7 +37,6 @@ from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric if TYPE_CHECKING: - from lightning import Callback from lightning.pytorch.cli import OptimizerCallable from otx.engine.engine import Engine @@ -48,7 +49,6 @@ def execute_hpo( engine: Engine, max_epochs: int, hpo_config: HpoConfig, - progress_update_callback: Callable[[int | float], None] | None = None, callbacks: list[Callback] | Callback | None = None, **train_args, ) -> tuple[dict[str, Any] | None, Path | None]: @@ -58,8 +58,6 @@ def execute_hpo( engine (Engine): engine instnace. max_epochs (int): max epochs to train. hpo_config (HpoConfig): Configuration for HPO. - progress_update_callback (Callable[[int | float], None] | None, optional): - callback to update progress. If it's given, it's called with progress every second. Defaults to None. callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None. 
Returns: @@ -97,8 +95,23 @@ def execute_hpo( logger.warning("HPO is skipped.") return None, None - if progress_update_callback is not None: - Thread(target=_update_hpo_progress, args=[progress_update_callback, hpo_algo], daemon=True).start() + if hpo_config.progress_update_callback is not None: + Thread(target=_update_hpo_progress, args=[hpo_config.progress_update_callback, hpo_algo], daemon=True).start() + + if hpo_config.callbacks_to_exclude is not None and callbacks is not None: + if isinstance(hpo_config.callbacks_to_exclude, str): + hpo_config.callbacks_to_exclude = [hpo_config.callbacks_to_exclude] + if isinstance(callbacks, Callback): + callbacks = [callbacks] + + callbacks = copy(callbacks) + callback_names = [callback.__class__.__name__ for callback in callbacks] + callback_idx_to_exclude = [ + callback_names.index(cb_name) for cb_name in hpo_config.callbacks_to_exclude if cb_name in callback_names + ] + sorted(callback_idx_to_exclude, reverse=True) + for idx in callback_idx_to_exclude: + callbacks.pop(idx) run_hpo_loop( hpo_algo, diff --git a/tests/unit/engine/hpo/test_hpo_api.py b/tests/unit/engine/hpo/test_hpo_api.py index bcc71d8bc9a..8b24dffcf00 100644 --- a/tests/unit/engine/hpo/test_hpo_api.py +++ b/tests/unit/engine/hpo/test_hpo_api.py @@ -119,7 +119,7 @@ def mock_find_trial_file(mocker) -> MagicMock: @pytest.fixture() def hpo_config() -> HpoConfig: - return HpoConfig(metric_name="val/accuracy") + return HpoConfig(metric_name="val/accuracy", callbacks_to_exclude="UselessCallback") @pytest.fixture() @@ -127,6 +127,19 @@ def mock_progress_update_callback() -> MagicMock: return MagicMock() +class UsefullCallback: + pass + + +class UselessCallback: + pass + + +@pytest.fixture() +def mock_callback() -> list: + return [UsefullCallback(), UselessCallback()] + + def test_execute_hpo( mock_engine: MagicMock, hpo_config: HpoConfig, @@ -138,12 +151,14 @@ def test_execute_hpo( mock_get_best_hpo_weight: MagicMock, mock_find_trial_file: MagicMock, 
 mock_progress_update_callback: MagicMock, + mock_callback: list, ): + hpo_config.progress_update_callback = mock_progress_update_callback best_config, best_hpo_weight = execute_hpo( engine=mock_engine, max_epochs=10, hpo_config=hpo_config, - progress_update_callback=mock_progress_update_callback, + callbacks=mock_callback, ) # check hpo workdir exists @@ -152,12 +167,16 @@ def test_execute_hpo( # check a case where progress_update_callback exists mock_thread.assert_called_once() assert mock_thread.call_args.kwargs["target"] == _update_hpo_progress - assert mock_thread.call_args.kwargs["args"][0] == mock_progress_update_callback assert mock_thread.call_args.kwargs["daemon"] is True mock_thread.return_value.start.assert_called_once() # check whether run_hpo_loop is called well mock_run_hpo_loop.assert_called_once() assert mock_run_hpo_loop.call_args.args[0] == mock_hpo_algo + # check UselessCallback is excluded + for callback in mock_run_hpo_loop.call_args.args[1].keywords["callbacks"]: + assert not isinstance(callback, UselessCallback) + # check the original callback list isn't changed. + assert len(mock_callback) == 2 # print_result is called after HPO is done mock_hpo_algo.print_result.assert_called_once() # best_config and best_hpo_weight are returned well