Skip to content

Commit

Permalink
Update HPO interface (#4035)
Browse files Browse the repository at this point in the history
* update hpo interface

* update unit test

* update CHANGELOG.md
  • Loading branch information
eunwoosh authored Oct 15, 2024
1 parent 91d2df2 commit a3c9fc6
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4011>)
- Prevent using too low confidence thresholds in detection
(<https://github.com/openvinotoolkit/training_extensions/pull/4018>)
- Update HPO interface
(<https://github.com/openvinotoolkit/training_extensions/pull/4035>)

### Bug fixes

Expand Down
11 changes: 9 additions & 2 deletions src/otx/core/config/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dataclasses import dataclass
from pathlib import Path # noqa: TCH003
from typing import Any, Literal
from typing import Any, Callable, Literal

import torch

Expand All @@ -23,7 +23,12 @@

@dataclass
class HpoConfig:
"""DTO for HPO configuration."""
"""DTO for HPO configuration.
progress_update_callback (Callable[[int | float], None] | None):
callback to update progress. If it's given, it's called with progress every second.
callbacks_to_exclude (list[str] | str | None): List of name of callbacks to exclude during HPO.
"""

search_space: dict[str, dict[str, Any]] | str | Path | None = None
save_path: str | None = None
Expand All @@ -40,3 +45,5 @@ class HpoConfig:
asynchronous_sha: bool = num_workers > 1
metric_name: str | None = None
adapt_bs_search_space_max_val: Literal["None", "Safe", "Full"] = "None"
progress_update_callback: Callable[[int | float], None] | None = None
callbacks_to_exclude: list[str] | str | None = None
25 changes: 19 additions & 6 deletions src/otx/engine/hpo/hpo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import json
import logging
import time
from copy import copy
from functools import partial
from pathlib import Path
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Literal

import torch
import yaml
from lightning import Callback

from otx.core.config.hpo import HpoConfig
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
Expand All @@ -35,7 +37,6 @@
from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric

if TYPE_CHECKING:
from lightning import Callback
from lightning.pytorch.cli import OptimizerCallable

from otx.engine.engine import Engine
Expand All @@ -48,7 +49,6 @@ def execute_hpo(
engine: Engine,
max_epochs: int,
hpo_config: HpoConfig,
progress_update_callback: Callable[[int | float], None] | None = None,
callbacks: list[Callback] | Callback | None = None,
**train_args,
) -> tuple[dict[str, Any] | None, Path | None]:
Expand All @@ -58,8 +58,6 @@ def execute_hpo(
engine (Engine): engine instnace.
max_epochs (int): max epochs to train.
hpo_config (HpoConfig): Configuration for HPO.
progress_update_callback (Callable[[int | float], None] | None, optional):
callback to update progress. If it's given, it's called with progress every second. Defaults to None.
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
Returns:
Expand Down Expand Up @@ -97,8 +95,23 @@ def execute_hpo(
logger.warning("HPO is skipped.")
return None, None

if progress_update_callback is not None:
Thread(target=_update_hpo_progress, args=[progress_update_callback, hpo_algo], daemon=True).start()
if hpo_config.progress_update_callback is not None:
Thread(target=_update_hpo_progress, args=[hpo_config.progress_update_callback, hpo_algo], daemon=True).start()

if hpo_config.callbacks_to_exclude is not None and callbacks is not None:
if isinstance(hpo_config.callbacks_to_exclude, str):
hpo_config.callbacks_to_exclude = [hpo_config.callbacks_to_exclude]
if isinstance(callbacks, Callback):
callbacks = [callbacks]

callbacks = copy(callbacks)
callback_names = [callback.__class__.__name__ for callback in callbacks]
callback_idx_to_exclude = [
callback_names.index(cb_name) for cb_name in hpo_config.callbacks_to_exclude if cb_name in callback_names
]
sorted(callback_idx_to_exclude, reverse=True)
for idx in callback_idx_to_exclude:
callbacks.pop(idx)

run_hpo_loop(
hpo_algo,
Expand Down
25 changes: 22 additions & 3 deletions tests/unit/engine/hpo/test_hpo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,27 @@ def mock_find_trial_file(mocker) -> MagicMock:

@pytest.fixture()
def hpo_config() -> HpoConfig:
return HpoConfig(metric_name="val/accuracy")
return HpoConfig(metric_name="val/accuracy", callbacks_to_exclude="UselessCallback")


@pytest.fixture()
def mock_progress_update_callback() -> MagicMock:
return MagicMock()


class UsefullCallback:
pass


class UselessCallback:
pass


@pytest.fixture()
def mock_callback() -> list:
return [UsefullCallback(), UselessCallback()]


def test_execute_hpo(
mock_engine: MagicMock,
hpo_config: HpoConfig,
Expand All @@ -138,12 +151,14 @@ def test_execute_hpo(
mock_get_best_hpo_weight: MagicMock,
mock_find_trial_file: MagicMock,
mock_progress_update_callback: MagicMock,
mock_callback: list,
):
hpo_config.progress_update_callback = mock_progress_update_callback
best_config, best_hpo_weight = execute_hpo(
engine=mock_engine,
max_epochs=10,
hpo_config=hpo_config,
progress_update_callback=mock_progress_update_callback,
callbacks=mock_callback,
)

# check hpo workdir exists
Expand All @@ -152,12 +167,16 @@ def test_execute_hpo(
# check a case where progress_update_callback exists
mock_thread.assert_called_once()
assert mock_thread.call_args.kwargs["target"] == _update_hpo_progress
assert mock_thread.call_args.kwargs["args"][0] == mock_progress_update_callback
assert mock_thread.call_args.kwargs["daemon"] is True
mock_thread.return_value.start.assert_called_once()
# check whether run_hpo_loop is called well
mock_run_hpo_loop.assert_called_once()
assert mock_run_hpo_loop.call_args.args[0] == mock_hpo_algo
# check UselessCallback is excluded
for callback in mock_run_hpo_loop.call_args.args[1].keywords["callbacks"]:
assert not isinstance(callback, UselessCallback)
# check origincal callback lists isn't changed.
assert len(mock_callback) == 2
# print_result is called after HPO is done
mock_hpo_algo.print_result.assert_called_once()
# best_config and best_hpo_weight are returned well
Expand Down

0 comments on commit a3c9fc6

Please # to comment.