Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Update HPO interface #4035

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4011>)
- Prevent using too low confidence thresholds in detection
(<https://github.com/openvinotoolkit/training_extensions/pull/4018>)
- Update HPO interface
(<https://github.com/openvinotoolkit/training_extensions/pull/4035>)

### Bug fixes

Expand Down
11 changes: 9 additions & 2 deletions src/otx/core/config/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dataclasses import dataclass
from pathlib import Path # noqa: TCH003
from typing import Any, Literal
from typing import Any, Callable, Literal

import torch

Expand All @@ -23,7 +23,12 @@

@dataclass
class HpoConfig:
"""DTO for HPO configuration."""
"""DTO for HPO configuration.

progress_update_callback (Callable[[int | float], None] | None):
callback to update progress. If it's given, it's called with progress every second.
callbacks_to_exclude (list[str] | str | None): List of name of callbacks to exclude during HPO.
"""

search_space: dict[str, dict[str, Any]] | str | Path | None = None
save_path: str | None = None
Expand All @@ -40,3 +45,5 @@ class HpoConfig:
asynchronous_sha: bool = num_workers > 1
metric_name: str | None = None
adapt_bs_search_space_max_val: Literal["None", "Safe", "Full"] = "None"
progress_update_callback: Callable[[int | float], None] | None = None
callbacks_to_exclude: list[str] | str | None = None
25 changes: 19 additions & 6 deletions src/otx/engine/hpo/hpo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import json
import logging
import time
from copy import copy
from functools import partial
from pathlib import Path
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Literal

import torch
import yaml
from lightning import Callback

from otx.core.config.hpo import HpoConfig
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
Expand All @@ -35,7 +37,6 @@
from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric

if TYPE_CHECKING:
from lightning import Callback
from lightning.pytorch.cli import OptimizerCallable

from otx.engine.engine import Engine
Expand All @@ -48,7 +49,6 @@
engine: Engine,
max_epochs: int,
hpo_config: HpoConfig,
progress_update_callback: Callable[[int | float], None] | None = None,
callbacks: list[Callback] | Callback | None = None,
**train_args,
) -> tuple[dict[str, Any] | None, Path | None]:
Expand All @@ -58,8 +58,6 @@
engine (Engine): engine instnace.
max_epochs (int): max epochs to train.
hpo_config (HpoConfig): Configuration for HPO.
progress_update_callback (Callable[[int | float], None] | None, optional):
callback to update progress. If it's given, it's called with progress every second. Defaults to None.
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.

Returns:
Expand Down Expand Up @@ -97,8 +95,23 @@
logger.warning("HPO is skipped.")
return None, None

if progress_update_callback is not None:
Thread(target=_update_hpo_progress, args=[progress_update_callback, hpo_algo], daemon=True).start()
if hpo_config.progress_update_callback is not None:
Thread(target=_update_hpo_progress, args=[hpo_config.progress_update_callback, hpo_algo], daemon=True).start()

if hpo_config.callbacks_to_exclude is not None and callbacks is not None:
if isinstance(hpo_config.callbacks_to_exclude, str):
hpo_config.callbacks_to_exclude = [hpo_config.callbacks_to_exclude]
if isinstance(callbacks, Callback):
callbacks = [callbacks]

Check warning on line 105 in src/otx/engine/hpo/hpo_api.py

View check run for this annotation

Codecov / codecov/patch

src/otx/engine/hpo/hpo_api.py#L105

Added line #L105 was not covered by tests

callbacks = copy(callbacks)
callback_names = [callback.__class__.__name__ for callback in callbacks]
callback_idx_to_exclude = [
callback_names.index(cb_name) for cb_name in hpo_config.callbacks_to_exclude if cb_name in callback_names
]
sorted(callback_idx_to_exclude, reverse=True)
for idx in callback_idx_to_exclude:
callbacks.pop(idx)

run_hpo_loop(
hpo_algo,
Expand Down
25 changes: 22 additions & 3 deletions tests/unit/engine/hpo/test_hpo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,27 @@ def mock_find_trial_file(mocker) -> MagicMock:

@pytest.fixture()
def hpo_config() -> HpoConfig:
return HpoConfig(metric_name="val/accuracy")
return HpoConfig(metric_name="val/accuracy", callbacks_to_exclude="UselessCallback")


@pytest.fixture()
def mock_progress_update_callback() -> MagicMock:
return MagicMock()


class UsefullCallback:
pass


class UselessCallback:
pass


@pytest.fixture()
def mock_callback() -> list:
return [UsefullCallback(), UselessCallback()]


def test_execute_hpo(
mock_engine: MagicMock,
hpo_config: HpoConfig,
Expand All @@ -138,12 +151,14 @@ def test_execute_hpo(
mock_get_best_hpo_weight: MagicMock,
mock_find_trial_file: MagicMock,
mock_progress_update_callback: MagicMock,
mock_callback: list,
):
hpo_config.progress_update_callback = mock_progress_update_callback
best_config, best_hpo_weight = execute_hpo(
engine=mock_engine,
max_epochs=10,
hpo_config=hpo_config,
progress_update_callback=mock_progress_update_callback,
callbacks=mock_callback,
)

# check hpo workdir exists
Expand All @@ -152,12 +167,16 @@ def test_execute_hpo(
# check a case where progress_update_callback exists
mock_thread.assert_called_once()
assert mock_thread.call_args.kwargs["target"] == _update_hpo_progress
assert mock_thread.call_args.kwargs["args"][0] == mock_progress_update_callback
assert mock_thread.call_args.kwargs["daemon"] is True
mock_thread.return_value.start.assert_called_once()
# check whether run_hpo_loop is called well
mock_run_hpo_loop.assert_called_once()
assert mock_run_hpo_loop.call_args.args[0] == mock_hpo_algo
# check UselessCallback is excluded
for callback in mock_run_hpo_loop.call_args.args[1].keywords["callbacks"]:
assert not isinstance(callback, UselessCallback)
# check origincal callback lists isn't changed.
assert len(mock_callback) == 2
# print_result is called after HPO is done
mock_hpo_algo.print_result.assert_called_once()
# best_config and best_hpo_weight are returned well
Expand Down
Loading