diff --git a/ignite/base/mixins.py b/ignite/base/mixins.py index 1e187243d99b..1d73ec43e8fb 100644 --- a/ignite/base/mixins.py +++ b/ignite/base/mixins.py @@ -1,11 +1,19 @@ from collections import OrderedDict from collections.abc import Mapping +from typing import List, Tuple class Serializable: - _state_dict_all_req_keys = () # type: tuple - _state_dict_one_of_opt_keys = () # type: tuple + _state_dict_all_req_keys = () # type: Tuple[str, ...] + _state_dict_one_of_opt_keys = ((),) # type: Tuple[Tuple[str, ...], ...] + + def __init__(self) -> None: + self._state_dict_user_keys = [] # type: List[str] + + @property + def state_dict_user_keys(self) -> List: + return self._state_dict_user_keys def state_dict(self) -> OrderedDict: pass @@ -19,6 +27,13 @@ def load_state_dict(self, state_dict: Mapping) -> None: raise ValueError( f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" ) - opts = [k in state_dict for k in self._state_dict_one_of_opt_keys] - if len(opts) > 0 and ((not any(opts)) or (all(opts))): - raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys") + for one_of_opt_keys in self._state_dict_one_of_opt_keys: + opts = [k in state_dict for k in one_of_opt_keys] + if len(opts) > 0 and ((not any(opts)) or all(opts)): + raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys") + + for k in self._state_dict_user_keys: + if k not in state_dict: + raise ValueError( + f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" + ) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index da666b9ddc4b..8bd1772576e4 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -1,12 +1,11 @@ import functools import logging -import math import time import warnings import weakref from collections import OrderedDict, defaultdict from collections.abc import Mapping -from typing import Any, Callable, Dict, Iterable, 
Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Tuple, Union from torch.utils.data import DataLoader @@ -118,10 +117,11 @@ def compute_mean_std(engine, batch): """ - _state_dict_all_req_keys = ("epoch_length", "max_epochs") - _state_dict_one_of_opt_keys = ("iteration", "epoch") + _state_dict_all_req_keys = ("epoch_length",) + _state_dict_one_of_opt_keys = (("iteration", "epoch",), ("max_epochs", "max_iters",)) def __init__(self, process_function: Callable): + super(Engine, self).__init__() self._event_handlers = defaultdict(list) # type: Dict[Any, List] self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) self._process_function = process_function @@ -129,7 +129,6 @@ def __init__(self, process_function: Callable): self.should_terminate = False self.should_terminate_single_epoch = False self.state = State() - self._state_dict_user_keys = [] # type: List[str] self._allowed_events = [] # type: List[EventEnum] self._dataloader_iter = None # type: Optional[Iterator[Any]] @@ -469,13 +468,9 @@ def _handle_exception(self, e: BaseException) -> None: else: raise e - @property - def state_dict_user_keys(self) -> List: - return self._state_dict_user_keys - def state_dict(self) -> OrderedDict: - """Returns a dictionary containing engine's state: "epoch_length", "max_epochs" and "iteration" and - other state values defined by `engine.state_dict_user_keys` + """Returns a dictionary containing engine's state: "epoch_length", "iteration", "max_iters" or "max_epoch" + and other state values defined by ``engine.state_dict_user_keys``. .. code-block:: python @@ -500,15 +495,20 @@ def save_engine(_): a dictionary containing engine's state """ - keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) # type: Tuple[str, ...] + keys = self._state_dict_all_req_keys # type: Tuple[str, ...] 
+ keys += ("iteration",) + if self.state.max_epochs is not None: + keys += ("max_epochs",) + else: + keys += ("max_iters",) keys += tuple(self._state_dict_user_keys) return OrderedDict([(k, getattr(self.state, k)) for k in keys]) def load_state_dict(self, state_dict: Mapping) -> None: """Setups engine from `state_dict`. - State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`. - If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. + State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters` and `epoch_length`. + If ``engine.state_dict_user_keys`` contains keys, they should be also present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero. This method does not remove any custom attributes added by user. @@ -530,13 +530,9 @@ def load_state_dict(self, state_dict: Mapping) -> None: """ super(Engine, self).load_state_dict(state_dict) - for k in self._state_dict_user_keys: - if k not in state_dict: - raise ValueError( - f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" - ) - self.state.max_epochs = state_dict["max_epochs"] - self.state.epoch_length = state_dict["epoch_length"] + for k in self._state_dict_all_req_keys: + setattr(self.state, k, state_dict[k]) + for k in self._state_dict_user_keys: setattr(self.state, k, state_dict[k]) @@ -545,7 +541,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: self.state.epoch = 0 if self.state.epoch_length is not None: self.state.epoch = self.state.iteration // self.state.epoch_length - elif "epoch" in state_dict: + else: self.state.epoch = state_dict["epoch"] if self.state.epoch_length is None: raise ValueError( @@ -554,6 +550,9 @@ def load_state_dict(self, state_dict: Mapping) -> None: ) self.state.iteration = self.state.epoch_length * self.state.epoch + 
self._check_and_set_max_epochs(state_dict.get("max_epochs", None)) + self._check_and_set_max_iters(state_dict.get("max_iters", None)) + @staticmethod def _is_done(state: State) -> bool: is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters @@ -612,12 +611,12 @@ def run( Engine has a state and the following logic is applied in this function: - - At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided. + - At the first call, new state is defined by `max_epochs` or `max_iters` and `epoch_length`, if provided. A timer for total and per-epoch time is initialized when Events.STARTED is handled. - - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments - provided, state is kept and used in the function. - - If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined. - - If state is defined, engine is NOT "done", then input arguments if provided override defined state. + - If state is defined such that there are iterations to run until `max_epochs` or `max_iters` + and no input arguments provided, state is kept and used in the function. + - If engine is "done" (no iterations to run until `max_epochs`), a new state is defined. + - If engine is NOT "done", then input arguments if provided override defined state. Args: data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`). @@ -637,7 +636,8 @@ def run( Note: User can dynamically preprocess input batch at :attr:`~ignite.engine.events.Events.ITERATION_STARTED` and - store output batch in `engine.state.batch`. Latter is passed as usually to `process_function` as argument: + store output batch in ``engine.state.batch``. Latter is passed as usually to ``process_function`` + as argument: .. 
code-block:: python @@ -662,54 +662,39 @@ def switch_batch(engine): if not isinstance(data, Iterable): raise TypeError("Argument data should be iterable") - if self.state.max_epochs is not None: - # Check and apply overridden parameters - if max_epochs is not None: - if max_epochs < self.state.epoch: - raise ValueError( - "Argument max_epochs should be larger than the start epoch " - f"defined in the state: {max_epochs} vs {self.state.epoch}. " - "Please, set engine.state.max_epochs = None " - "before calling engine.run() in order to restart the training from the beginning." - ) - self.state.max_epochs = max_epochs - if epoch_length is not None: - if epoch_length != self.state.epoch_length: - raise ValueError( - "Argument epoch_length should be same as in the state, " - f"but given {epoch_length} vs {self.state.epoch_length}" - ) + if max_epochs is not None and max_iters is not None: + raise ValueError( + "Arguments max_iters and max_epochs are mutually exclusive. " + "Please provide only max_epochs or max_iters." + ) - if self.state.max_epochs is None or self._is_done(self.state): - # Create new state - if epoch_length is None: - epoch_length = self._get_data_length(data) - if epoch_length is not None and epoch_length < 1: - raise ValueError("Input data has zero size. Please provide non-empty data") - - if max_iters is None: - if max_epochs is None: - max_epochs = 1 - else: - if max_epochs is not None: - raise ValueError( - "Arguments max_iters and max_epochs are mutually exclusive." - "Please provide only max_epochs or max_iters." 
- ) - if epoch_length is not None: - max_epochs = math.ceil(max_iters / epoch_length) + self._check_and_set_max_epochs(max_epochs) + self._check_and_set_max_iters(max_iters) + self._check_and_set_epoch_length(data, epoch_length) + if self.state.max_epochs is None and self.state.max_iters is None: + self.state.max_epochs = 1 + + if self.state.max_epochs is not None and self.state.max_iters is not None: + raise ValueError( + "State attributes max_iters and max_epochs are mutually exclusive. " + "Please set max_epochs or max_iters to None" + ) + + msg = "Engine run starting with {}." + if self._is_done(self.state): + # Reset iteration/epoch counters self.state.iteration = 0 self.state.epoch = 0 - self.state.max_epochs = max_epochs - self.state.max_iters = max_iters - self.state.epoch_length = epoch_length - self.logger.info(f"Engine run starting with max_epochs={max_epochs}.") - else: - self.logger.info( - f"Engine run resuming from iteration {self.state.iteration}, " - f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" - ) + elif self.state.iteration > 0: + msg = f"Engine run resuming from iteration {self.state.iteration}, epoch {self.state.epoch} " + "until {}." 
+ + if self.state.max_epochs is not None: + msg = msg.format(f"max_epochs={self.state.max_epochs}") + elif self.state.max_iters is not None: + msg = msg.format(f"max_iters={self.state.max_iters}") + + self.logger.info(msg) self.state.dataloader = data return self._internal_run() @@ -719,7 +704,7 @@ def _init_timers(state: State) -> None: state.times[Events.EPOCH_COMPLETED.name] = 0.0 state.times[Events.COMPLETED.name] = 0.0 - def _get_data_length(self, data: Iterable) -> Optional[int]: + def _get_data_length(self, data: Union[Iterable, Sized]) -> Optional[int]: try: if hasattr(data, "__len__"): return len(data) # type: ignore[arg-type] @@ -728,6 +713,53 @@ def _get_data_length(self, data: Iterable) -> Optional[int]: pass return None + def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None: + if max_epochs is not None: + if max_epochs < 1: + raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value") + if self.state.max_epochs is not None and max_epochs <= self.state.epoch: + raise ValueError( + "Argument max_epochs should be larger than the current epoch " + f"defined in the state: {max_epochs} vs {self.state.epoch}. " + "Please, set engine.state.max_epochs = None " + "before calling engine.run() in order to restart the training from the beginning." + ) + self.state.max_epochs = max_epochs + + def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None: + if max_iters is not None: + if max_iters < 1: + raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") + if (self.state.max_iters is not None) and max_iters <= self.state.iteration: + raise ValueError( + "Argument max_iters should be larger than the current iteration " + f"defined in the state: {max_iters} vs {self.state.iteration}. " + "Please, set engine.state.max_iters = None " + "before calling engine.run() in order to restart the training from the beginning." 
+ ) + self.state.max_iters = max_iters + + def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None: + # Can't we accept a redefinition ? + if self.state.epoch_length is not None: + if epoch_length is not None: + if epoch_length != self.state.epoch_length: + raise ValueError( + "Argument epoch_length should be same as in the state, " + f"but given {epoch_length} vs {self.state.epoch_length}" + ) + else: + if epoch_length is None: + epoch_length = self._get_data_length(data) + + if epoch_length is not None and epoch_length < 1: + raise ValueError( + "Argument epoch_length is invalid. Please, either set a correct epoch_length value or " + "check if input data has non-zero size." + ) + + self.state.epoch_length = epoch_length + def _setup_engine(self) -> None: if self.state.dataloader is None: raise RuntimeError( @@ -805,7 +837,6 @@ def _run_once_on_dataset(self) -> float: raise RuntimeError( "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error." ) - while True: try: # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted @@ -820,8 +851,6 @@ def _run_once_on_dataset(self) -> float: if self.state.epoch_length is None: # Define epoch length and stop the epoch self.state.epoch_length = iter_counter - if self.state.max_iters is not None: - self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length) break # Should exit while loop if we can not iterate @@ -838,6 +867,7 @@ def _run_once_on_dataset(self) -> float: "iterations to run is not reached. 
" f"Current iteration: {self.state.iteration} vs Total iterations to run : {total_iters}" ) + self.should_terminate = True break self._fire_event(Events.DATALOADER_STOP_ITERATION) @@ -873,3 +903,14 @@ def _run_once_on_dataset(self) -> float: self._handle_exception(e) return time.time() - start_time + + def debug(self, enabled: bool = True) -> None: + """Enables/disables engine's logging debug mode + """ + from ignite.utils import setup_logger + + if enabled: + setattr(self, "_stored_logger", self.logger) + self.logger = setup_logger(level=logging.DEBUG) + elif hasattr(self, "_stored_logger"): + self.logger = getattr(self, "_stored_logger") diff --git a/tests/ignite/base/test_mixins.py b/tests/ignite/base/test_mixins.py index 78c79c37f7c4..37d3906980c0 100644 --- a/tests/ignite/base/test_mixins.py +++ b/tests/ignite/base/test_mixins.py @@ -1,7 +1,34 @@ +import pytest + from ignite.base import Serializable +class ExampleSerializable(Serializable): + _state_dict_all_req_keys = ("a", "b") + _state_dict_one_of_opt_keys = (("c", "d"), ("e", "f")) + + def test_load_state_dict(): - s = Serializable() - s.load_state_dict({}) + s = ExampleSerializable() + with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"): + s.load_state_dict("abc") + + with pytest.raises(ValueError, match=r"is absent in provided state_dict"): + s.load_state_dict({}) + + with pytest.raises(ValueError, match=r"is absent in provided state_dict"): + s.load_state_dict({"a": 1}) + + with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + s.load_state_dict({"a": 1, "b": 2}) + + with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + + with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "f": 5}) + + s.state_dict_user_keys.append("alpha") + with pytest.raises(ValueError, match=r"Required user 
state attribute"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 4}) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 968486848024..0a881a5f0f76 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1,3 +1,4 @@ +import math import os import time from unittest.mock import MagicMock, Mock, call @@ -334,6 +335,15 @@ def test__is_done(): state = State(iteration=1000, max_epochs=10, epoch_length=100) assert Engine._is_done(state) + state = State(iteration=11, epoch=2, max_epochs=None, epoch_length=11, max_iters=22) + assert not Engine._is_done(state) + + state = State(iteration=100, epoch=1, max_epochs=None, epoch_length=100, max_iters=250) + assert not Engine._is_done(state) + + state = State(iteration=250, epoch=1, max_epochs=None, epoch_length=100, max_iters=250) + assert Engine._is_done(state) + def test__setup_engine(): engine = Engine(lambda e, b: 1) @@ -343,15 +353,28 @@ def test__setup_engine(): engine.state.dataloader = data engine._setup_engine() assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10 - # assert engine._dataloader_len == len(data) def test_run_asserts(): engine = Engine(lambda e, b: 1) - with pytest.raises(ValueError, match=r"Input data has zero size. Please provide non-empty data"): + with pytest.raises( + ValueError, + match=r"Argument epoch_length is invalid. 
Please, either set a correct epoch_length " + r"value or check if input data has non-zero size.", + ): engine.run([]) + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + engine.state.max_epochs = 5 + engine.state.epoch = 5 + engine.run([0, 1], max_epochs=3) + + with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + engine.state.max_iters = 100 + engine.state.iteration = 100 + engine.run([0, 1], max_iters=50) + def test_state_get_event_attrib_value(): state = State() @@ -412,7 +435,20 @@ def check_completed_time(): _test(list(range(200)), max_epochs=5, epoch_length=100) -def _test_check_triggered_events(data, max_epochs, epoch_length, exp_iter_stops=None): +def _test_check_triggered_events( + data, + max_epochs=None, + epoch_length=None, + max_iters=None, + n_epoch_started=None, + n_epoch_completed=None, + n_iter_started=None, + n_iter_completed=None, + n_batch_started=None, + n_batch_completed=None, + n_dl_stops=None, + n_terminate=None, +): engine = Engine(lambda e, b: 1) events = [ Events.STARTED, @@ -424,6 +460,8 @@ def _test_check_triggered_events(data, max_epochs, epoch_length, exp_iter_stops= Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED, Events.DATALOADER_STOP_ITERATION, + Events.TERMINATE, + Events.TERMINATE_SINGLE_EPOCH, ] handlers = {e: MagicMock() for e in events} @@ -431,18 +469,40 @@ def _test_check_triggered_events(data, max_epochs, epoch_length, exp_iter_stops= for e, handler in handlers.items(): engine.add_event_handler(e, handler) - engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) + engine.run(data, max_epochs=max_epochs, max_iters=max_iters, epoch_length=epoch_length) + + if epoch_length is None: + epoch_length = engine.state.epoch_length + + assert epoch_length is not None + + if n_iter_started is None: + n_iter_started = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_iter_completed is None: + n_iter_completed = max_epochs * 
epoch_length if max_epochs is not None else max_iters + + if n_batch_started is None: + n_batch_started = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_batch_completed is None: + n_batch_completed = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_terminate is None: + n_terminate = int(n_epoch_started != n_epoch_completed) if max_iters is not None else 0 expected_num_calls = { Events.STARTED: 1, Events.COMPLETED: 1, - Events.EPOCH_STARTED: max_epochs, - Events.EPOCH_COMPLETED: max_epochs, - Events.ITERATION_STARTED: max_epochs * epoch_length, - Events.ITERATION_COMPLETED: max_epochs * epoch_length, - Events.GET_BATCH_STARTED: max_epochs * epoch_length, - Events.GET_BATCH_COMPLETED: max_epochs * epoch_length, - Events.DATALOADER_STOP_ITERATION: (max_epochs - 1) if exp_iter_stops is None else exp_iter_stops, + Events.EPOCH_STARTED: n_epoch_started if n_epoch_started is not None else max_epochs, + Events.EPOCH_COMPLETED: n_epoch_completed if n_epoch_completed is not None else max_epochs, + Events.ITERATION_STARTED: n_iter_started, + Events.ITERATION_COMPLETED: n_iter_completed, + Events.GET_BATCH_STARTED: n_batch_started, + Events.GET_BATCH_COMPLETED: n_batch_completed, + Events.DATALOADER_STOP_ITERATION: n_dl_stops if n_dl_stops is not None else (max_epochs - 1), + Events.TERMINATE: n_terminate, + Events.TERMINATE_SINGLE_EPOCH: 0, } for n, handler in handlers.items(): @@ -451,10 +511,16 @@ def _test_check_triggered_events(data, max_epochs, epoch_length, exp_iter_stops= def _test_run_check_triggered_events(): # tests issue https://github.com/pytorch/ignite/issues/818 - _test_check_triggered_events(list(range(10)), max_epochs=4, epoch_length=10) - _test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=100) - _test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=50, exp_iter_stops=50 * 5 // 100) - _test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=150, 
exp_iter_stops=150 * 5 // 100) + _test_check_triggered_events(list(range(20)), max_epochs=5, epoch_length=20) + _test_check_triggered_events(list(range(20)), max_epochs=5, epoch_length=10, n_dl_stops=10 * 5 // 20) + _test_check_triggered_events(list(range(20)), max_epochs=5, epoch_length=25, n_dl_stops=25 * 5 // 20) + + kwargs = dict(n_dl_stops=4, n_epoch_started=5, n_epoch_completed=5) + _test_check_triggered_events(list(range(20)), max_iters=100, epoch_length=20, **kwargs) + kwargs = dict(n_dl_stops=2, n_epoch_started=5, n_epoch_completed=5) + _test_check_triggered_events(list(range(20)), max_iters=50, epoch_length=10, **kwargs) + kwargs = dict(n_dl_stops=2, n_epoch_started=3, n_epoch_completed=2) + _test_check_triggered_events(list(range(20)), max_iters=55, epoch_length=25, **kwargs) def test_run_check_triggered_events_list(): @@ -464,32 +530,93 @@ def test_run_check_triggered_events_list(): def _test_run_check_triggered_events_on_iterator(): def infinite_data_iterator(): while True: - for i in range(100): + for i in range(12): yield i - _test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=100, exp_iter_stops=0) - _test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=50, exp_iter_stops=0) - _test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=150, exp_iter_stops=0) + _test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=20, n_dl_stops=0) + + kwargs = dict(n_dl_stops=0, n_epoch_started=5, n_epoch_completed=5) + _test_check_triggered_events(infinite_data_iterator(), max_iters=100, epoch_length=20, **kwargs) + kwargs = dict(n_dl_stops=0, n_epoch_started=1, n_epoch_completed=0) + _test_check_triggered_events(infinite_data_iterator(), max_iters=10, epoch_length=20, **kwargs) + kwargs = dict(n_dl_stops=0, n_epoch_started=2, n_epoch_completed=1) + _test_check_triggered_events(infinite_data_iterator(), max_iters=30, epoch_length=20, **kwargs) def 
limited_data_iterator(): - for i in range(100): + for i in range(20): yield i - _test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=100, exp_iter_stops=0) - _test_check_triggered_events(limited_data_iterator(), max_epochs=10, epoch_length=10, exp_iter_stops=0) - - # These tests will fail - with pytest.raises(AssertionError): - with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): - _test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=100) - - with pytest.raises(AssertionError): - with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): - _test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=75) - - with pytest.raises(AssertionError): - with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): - _test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=101) + _test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=20, n_dl_stops=0) + _test_check_triggered_events(limited_data_iterator(), max_epochs=5, epoch_length=4, n_dl_stops=0) + + kwargs = dict(n_dl_stops=0, n_epoch_started=1, n_epoch_completed=1) + _test_check_triggered_events(limited_data_iterator(), max_iters=20, epoch_length=20, **kwargs) + kwargs = dict(n_dl_stops=0, n_epoch_started=2, n_epoch_completed=1) + _test_check_triggered_events(limited_data_iterator(), max_iters=19, epoch_length=10, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + n_dl_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + _test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=20, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + n_dl_stops=1, + 
n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=22, # 22 and not 21. GET_BATCH_STARTED is called once more to epoch_length + n_batch_completed=20, + n_terminate=1, + ) + _test_check_triggered_events(limited_data_iterator(), max_epochs=3, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + n_dl_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + _test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=15, **kwargs) + + kwargs = dict( + n_dl_stops=1, + n_epoch_started=1, + n_epoch_completed=0, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + _test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=21, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + n_dl_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + _test_check_triggered_events(limited_data_iterator(), max_iters=21, epoch_length=12, **kwargs) def test_run_check_triggered_events_on_iterator(): @@ -619,15 +746,25 @@ def test_engine_with_dataloader_no_auto_batching(): data, batch_size=None, sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True) ) - counter = [0] + def _test(**kwargs): + counter = [0] - def foo(e, b): - counter[0] += 1 + def foo(e, b): + counter[0] += 1 - engine = Engine(foo) - engine.run(data_loader, epoch_length=10, max_epochs=5) + engine = Engine(foo) + epoch_length = 10 + engine.run(data_loader, epoch_length=epoch_length, **kwargs) - assert counter[0] == 50 + max_epochs = kwargs.get("max_epochs", None) + max_iters = kwargs.get("max_iters", None) + if 
max_epochs: + assert counter[0] == epoch_length * max_epochs + else: + assert counter[0] == max_iters + + _test(max_epochs=5) + _test(max_iters=25) def test_run_once_finite_iterator_no_epoch_length(): @@ -639,19 +776,46 @@ def finite_unk_size_data_iter(): for i in range(unknown_size): yield i - bc = BatchChecker(data=list(range(unknown_size))) + def _test(**kwargs): + bc = BatchChecker(data=list(range(unknown_size))) - engine = Engine(lambda e, b: bc.check(b)) + def foo(e, b): + print(e.state.iteration, ":", b) + bc.check(b) - completed_handler = MagicMock() - engine.add_event_handler(Events.COMPLETED, completed_handler) + # engine = Engine(lambda e, b: bc.check(b)) + engine = Engine(foo) - data_iter = finite_unk_size_data_iter() - engine.run(data_iter) + epoch_completed_handler = MagicMock() + engine.add_event_handler(Events.EPOCH_COMPLETED, epoch_completed_handler) + + completed_handler = MagicMock() + engine.add_event_handler(Events.COMPLETED, completed_handler) - assert engine.state.epoch == 1 - assert engine.state.iteration == unknown_size - assert completed_handler.call_count == 1 + data_iter = finite_unk_size_data_iter() + engine.run(data_iter, **kwargs) + + assert bc.counter == engine.state.iteration + if len(kwargs) == 0: + assert engine.state.epoch == 1 + assert engine.state.iteration == unknown_size + assert epoch_completed_handler.call_count == 1 + else: + max_iters = kwargs["max_iters"] + if max_iters <= unknown_size: + assert engine.state.epoch == 1 + assert engine.state.iteration == max_iters + else: + assert engine.state.epoch == 2 + assert engine.state.iteration == unknown_size + + assert completed_handler.call_count == 1 + + _test() + _test(max_iters=unknown_size) + _test(max_iters=unknown_size // 2) + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + _test(max_iters=unknown_size * 2) def test_run_finite_iterator_no_epoch_length(): @@ -685,156 +849,187 @@ def finite_size_data_iter(size): for i in range(size): 
yield i - bc = BatchChecker(data=list(range(known_size))) + def _test(**kwargs): + bc = BatchChecker(data=list(range(known_size))) - engine = Engine(lambda e, b: bc.check(b)) + engine = Engine(lambda e, b: bc.check(b)) - @engine.on(Events.ITERATION_COMPLETED(every=known_size)) - def restart_iter(): - engine.state.dataloader = finite_size_data_iter(known_size) + @engine.on(Events.ITERATION_COMPLETED(every=known_size)) + def restart_iter(): + engine.state.dataloader = finite_size_data_iter(known_size) - data_iter = finite_size_data_iter(known_size) - engine.run(data_iter, max_epochs=5) + data_iter = finite_size_data_iter(known_size) + engine.run(data_iter, **kwargs) - assert engine.state.epoch == 5 - assert engine.state.iteration == known_size * 5 + assert bc.counter == engine.state.iteration + if "max_epochs" in kwargs: + assert engine.state.epoch == kwargs["max_epochs"] + assert engine.state.iteration == known_size * kwargs["max_epochs"] + else: + max_iters = kwargs["max_iters"] + if max_iters <= known_size: + assert engine.state.epoch == math.ceil(max_iters / known_size) + assert engine.state.iteration == max_iters + + _test(max_epochs=5) + _test(max_iters=known_size) + _test(max_iters=known_size // 2) def test_faq_inf_iterator_with_epoch_length(): - # Code snippet from FAQ + def _test(max_epochs, max_iters): + # Code snippet from FAQ - import torch + import torch - torch.manual_seed(12) + torch.manual_seed(12) - def infinite_iterator(batch_size): - while True: - batch = torch.rand(batch_size, 3, 32, 32) - yield batch + def infinite_iterator(batch_size): + while True: + batch = torch.rand(batch_size, 3, 32, 32) + yield batch - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + def train_step(trainer, batch): + # ... 
+ s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + + trainer = Engine(train_step) + # We need to specify epoch_length to define the epoch + trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=max_epochs, max_iters=max_iters) - trainer = Engine(train_step) - # We need to specify epoch_length to define the epoch - trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3) + assert trainer.state.epoch == 3 + assert trainer.state.iteration == 3 * 5 - assert trainer.state.epoch == 3 - assert trainer.state.iteration == 3 * 5 + _test(max_epochs=3, max_iters=None) + _test(max_epochs=None, max_iters=3 * 5) def test_faq_inf_iterator_no_epoch_length(): - # Code snippet from FAQ + def _test(max_epochs, max_iters): + # Code snippet from FAQ - import torch + import torch - torch.manual_seed(12) + torch.manual_seed(12) - def infinite_iterator(batch_size): - while True: - batch = torch.rand(batch_size, 3, 32, 32) - yield batch + def infinite_iterator(batch_size): + while True: + batch = torch.rand(batch_size, 3, 32, 32) + yield batch - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + def train_step(trainer, batch): + # ... 
+ s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + + trainer = Engine(train_step) - trainer = Engine(train_step) + @trainer.on(Events.ITERATION_COMPLETED(once=15)) + def stop_training(): + trainer.terminate() - @trainer.on(Events.ITERATION_COMPLETED(once=15)) - def stop_training(): - trainer.terminate() + trainer.run(infinite_iterator(4), max_epochs=max_epochs, max_iters=max_iters) - trainer.run(infinite_iterator(4)) + assert trainer.state.epoch == 1 + assert trainer.state.iteration == 15 - assert trainer.state.epoch == 1 - assert trainer.state.iteration == 15 + _test(max_epochs=None, max_iters=None) + _test(max_epochs=None, max_iters=100) def test_faq_fin_iterator_unknw_size(): - # Code snippet from FAQ + def _test(max_epochs, max_iters): + # Code snippet from FAQ - import torch + import torch - torch.manual_seed(12) + torch.manual_seed(12) - def finite_unk_size_data_iter(): - for i in range(11): - yield i + def finite_unk_size_data_iter(): + for i in range(11): + yield i - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") + def train_step(trainer, batch): + # ... 
+        s = trainer.state
+        print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")

-    trainer = Engine(train_step)
+        trainer = Engine(train_step)

-    @trainer.on(Events.DATALOADER_STOP_ITERATION)
-    def restart_iter():
-        trainer.state.dataloader = finite_unk_size_data_iter()
+        @trainer.on(Events.DATALOADER_STOP_ITERATION)
+        def restart_iter():
+            trainer.state.dataloader = finite_unk_size_data_iter()

-    data_iter = finite_unk_size_data_iter()
-    trainer.run(data_iter, max_epochs=5)
+        data_iter = finite_unk_size_data_iter()
+        trainer.run(data_iter, max_epochs=max_epochs, max_iters=max_iters)

-    assert trainer.state.epoch == 5
-    assert trainer.state.iteration == 5 * 11
+        assert trainer.state.epoch == (5 if max_iters is None else math.ceil(max_iters / 11))
+        assert trainer.state.iteration == (5 * 11 if max_iters is None else max_iters)
+
+    _test(max_epochs=5, max_iters=None)
+    _test(max_epochs=None, max_iters=60)

    # # # # #

-    import torch
+    def _test(max_epochs, max_iters):
+        import torch

-    torch.manual_seed(12)
+        torch.manual_seed(12)

-    def finite_unk_size_data_iter():
-        for i in range(11):
-            yield i
+        def finite_unk_size_data_iter():
+            for i in range(11):
+                yield i

-    def val_step(evaluator, batch):
-        # ...
-        s = evaluator.state
-        print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
+        def val_step(evaluator, batch):
+            # ...
+ s = evaluator.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") - evaluator = Engine(val_step) + evaluator = Engine(val_step) - data_iter = finite_unk_size_data_iter() - evaluator.run(data_iter) + data_iter = finite_unk_size_data_iter() + evaluator.run(data_iter, max_epochs=max_epochs, max_iters=max_iters) - assert evaluator.state.epoch == 1 - assert evaluator.state.iteration == 1 * 11 + assert evaluator.state.epoch == 1 + assert evaluator.state.iteration == 1 * 11 + + _test(max_epochs=None, max_iters=None) def test_faq_fin_iterator(): - # Code snippet from FAQ + def _test(max_epochs, max_iters): + # Code snippet from FAQ - import torch + import torch - torch.manual_seed(12) + torch.manual_seed(12) - size = 11 + size = 11 - def finite_size_data_iter(size): - for i in range(size): - yield i + def finite_size_data_iter(size): + for i in range(size): + yield i - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") + def train_step(trainer, batch): + # ... 
+ s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") - trainer = Engine(train_step) + trainer = Engine(train_step) - @trainer.on(Events.ITERATION_COMPLETED(every=size)) - def restart_iter(): - trainer.state.dataloader = finite_size_data_iter(size) + @trainer.on(Events.ITERATION_COMPLETED(every=size)) + def restart_iter(): + trainer.state.dataloader = finite_size_data_iter(size) - data_iter = finite_size_data_iter(size) - trainer.run(data_iter, max_epochs=5) + data_iter = finite_size_data_iter(size) + trainer.run(data_iter, max_epochs=max_epochs, max_iters=max_iters) + + assert trainer.state.epoch == 5 + assert trainer.state.iteration == 5 * size - assert trainer.state.epoch == 5 - assert trainer.state.iteration == 5 * size + _test(max_epochs=5, max_iters=None) + _test(max_epochs=None, max_iters=5 * 11) # # # # # @@ -929,9 +1124,79 @@ def fired_event(engine): engine.run([0] * 10, max_iters=max_iters) -def test_is_done_with_max_iters(): - state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250) - assert not Engine._is_done(state) +def test_restart_training(): + data = range(10) + engine = Engine(lambda e, b: 1) + state = engine.run(data, max_epochs=5) + with pytest.raises( + ValueError, + match=r"Argument max_epochs should be larger than the current epoch defined in the state: 2 vs 5. 
" + r"Please, .+ " + r"before calling engine.run\(\) in order to restart the training from the beginning.", + ): + engine.run(data, max_epochs=2) + state.max_epochs = None + engine.run(data, max_epochs=2) - state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250) - assert Engine._is_done(state) + +def test_engine_multiple_runs(): + engine = Engine(lambda e, b: 1) + engine.debug() + + init_epoch = 0 + init_iter = 0 + epoch_length = None + + @engine.on(Events.STARTED) + def assert_resume(): + assert engine.state.epoch == init_epoch + assert engine.state.iteration == init_iter + assert engine.state.epoch_length == epoch_length + + data = range(10) + epoch_length = len(data) + engine.run(data, max_epochs=2) + assert engine.state.epoch == 2 + assert engine.state.iteration == 2 * epoch_length + + engine.debug(False) + + # Continue run with max_epochs + data = range(15) + init_epoch = 2 + init_iter = 2 * epoch_length + engine.run(data, max_epochs=5) + + assert engine.state.epoch == 5 + assert engine.state.iteration == 5 * epoch_length + + # Continue run with max_iters + data = range(15) + init_epoch = 5 + init_iter = 5 * epoch_length + with pytest.raises(ValueError, match=r"State attributes max_iters and max_epochs are mutually exclusive"): + engine.run(data, max_iters=6 * epoch_length) + + engine.state.max_epochs = None + engine.run(data, max_iters=6 * epoch_length) + + assert engine.state.epoch == 6 + assert engine.state.iteration == 6 * epoch_length + + +def test_engine_multiple_runs_2(): + + e = Engine(lambda _, b: None) + data = iter(range(100)) + + e.run(data, max_iters=50) + assert e.state.iteration == 50 + assert e.state.epoch == 1 + e.run(data, max_iters=52) + assert e.state.iteration == 52 + # should be 1 and if 2 this is a bug : https://github.com/pytorch/ignite/issues/1386 + assert e.state.epoch == 2 + e.run(data, max_iters=100) + assert e.state.iteration == 100 + # should be 1 and if 3 this is a bug : 
https://github.com/pytorch/ignite/issues/1386 + assert e.state.epoch == 3 diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index ae01c53f3992..50ac569cd7fd 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -13,19 +13,24 @@ def test_state_dict(): sd = engine.state_dict() assert isinstance(sd, Mapping) and len(sd) == 3 assert "iteration" in sd and sd["iteration"] == 0 - assert "max_epochs" in sd and sd["max_epochs"] is None + assert "max_iters" in sd and sd["max_iters"] is None assert "epoch_length" in sd and sd["epoch_length"] is None def _test(state): engine.state = state sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + assert isinstance(sd, Mapping) + assert len(sd) == len(engine._state_dict_all_req_keys) + len(engine._state_dict_one_of_opt_keys) assert sd["iteration"] == engine.state.iteration assert sd["epoch_length"] == engine.state.epoch_length - assert sd["max_epochs"] == engine.state.max_epochs + if state.max_epochs is not None: + assert sd["max_epochs"] == engine.state.max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters _test(State(iteration=500, epoch_length=1000, max_epochs=100)) _test(State(epoch=5, epoch_length=1000, max_epochs=100)) + _test(State(epoch=5, epoch_length=1000, max_iters=500)) def test_state_dict_with_user_keys(): @@ -36,37 +41,49 @@ def test_state_dict_with_user_keys(): def _test(state): engine.state = state sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + len( - engine.state_dict_user_keys - ) + assert isinstance(sd, Mapping) + sd_size = len(engine._state_dict_all_req_keys) + len(engine._state_dict_one_of_opt_keys) + sd_size += len(engine._state_dict_user_keys) + assert len(sd) == sd_size assert sd["iteration"] == engine.state.iteration assert sd["epoch_length"] == 
engine.state.epoch_length - assert sd["max_epochs"] == engine.state.max_epochs + if state.max_epochs is not None: + assert sd["max_epochs"] == engine.state.max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters assert sd["alpha"] == engine.state.alpha assert sd["beta"] == engine.state.beta _test(State(iteration=500, epoch_length=1000, max_epochs=100, alpha=0.01, beta="Good")) + _test(State(iteration=500, epoch_length=1000, max_iters=2000, alpha=0.01, beta="Good")) def test_state_dict_integration(): - engine = Engine(lambda e, b: 1) - data = range(100) - engine.run(data, max_epochs=10) - sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 - assert sd["iteration"] == engine.state.iteration == 10 * 100 - assert sd["epoch_length"] == engine.state.epoch_length == 100 - assert sd["max_epochs"] == engine.state.max_epochs == 10 + def _test(max_epochs, max_iters): + engine = Engine(lambda e, b: 1) + data = range(100) + engine.run(data, max_epochs=max_epochs, max_iters=max_iters) + sd = engine.state_dict() + assert isinstance(sd, Mapping) + assert len(sd) == len(engine._state_dict_all_req_keys) + len(engine._state_dict_one_of_opt_keys) + if max_epochs is None and max_iters is None: + max_epochs = 1 + n_iters = max_iters if max_iters is not None else max_epochs * 100 + assert sd["iteration"] == engine.state.iteration == n_iters + assert sd["epoch_length"] == engine.state.epoch_length == 100 + if engine.state.max_epochs is not None: + assert sd["max_epochs"] == engine.state.max_epochs == max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters == max_iters -def test_load_state_dict_asserts(): - engine = Engine(lambda e, b: 1) + _test(max_epochs=10, max_iters=None) + _test(max_epochs=None, max_iters=None) + _test(max_epochs=None, max_iters=10 * 100) - with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"): - engine.load_state_dict("123") - with 
pytest.raises(ValueError, match=r"is absent in provided state_dict"): - engine.load_state_dict({}) +def test_load_state_dict_asserts(): + engine = Engine(lambda e, b: 1) with pytest.raises(ValueError, match=r"state_dict should contain only one of"): engine.load_state_dict({"max_epochs": 100, "epoch_length": 120}) @@ -94,11 +111,29 @@ def _test(sd): elif "epoch" in sd: assert sd["epoch"] == engine.state.epoch assert sd["epoch_length"] == engine.state.epoch_length - assert sd["max_epochs"] == engine.state.max_epochs + if "max_epochs" in sd: + assert sd["max_epochs"] == engine.state.max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters _test({"max_epochs": 100, "epoch_length": 120, "iteration": 123}) _test({"max_epochs": 100, "epoch_length": 120, "epoch": 5}) + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + _test({"max_epochs": 10, "epoch_length": 120, "epoch": 50}) + + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + _test({"max_epochs": 10, "epoch_length": 120, "iteration": 5000}) + + _test({"max_iters": 500, "epoch_length": 120, "iteration": 123}) + _test({"max_iters": 500, "epoch_length": 120, "epoch": 3}) + + with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + _test({"max_iters": 500, "epoch_length": 120, "epoch": 5}) + + with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + _test({"max_iters": 500, "epoch_length": 120, "iteration": 501}) + def test_load_state_dict_with_user_keys(): engine = Engine(lambda e, b: 1) @@ -145,7 +180,7 @@ def test_load_state_dict_with_params_overriding_integration(): assert state.iteration == state_dict["epoch_length"] * new_max_epochs assert state.epoch == new_max_epochs - with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than the start epoch"): + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than the current epoch"): 
engine.load_state_dict(state_dict) engine.run(data, max_epochs=3) @@ -263,18 +298,3 @@ def check_custom_attr(): _test() _test(with_load_state_dict=True) - - -def test_restart_training(): - data = range(10) - engine = Engine(lambda e, b: 1) - state = engine.run(data, max_epochs=5) - with pytest.raises( - ValueError, - match=r"Argument max_epochs should be larger than the start epoch defined in the state: 2 vs 5. " - r"Please, .+ " - r"before calling engine.run\(\) in order to restart the training from the beginning.", - ): - state = engine.run(data, max_epochs=2) - state.max_epochs = None - engine.run(data, max_epochs=2)