-
-
Notifications
You must be signed in to change notification settings - Fork 648
Improved max_iters handling #1565
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
b89771b
9b95c24
bd22c35
e7bf139
cc02908
e9a8767
25fbc4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,11 @@ | ||
import functools | ||
import logging | ||
import math | ||
import time | ||
import warnings | ||
import weakref | ||
from collections import OrderedDict, defaultdict | ||
from collections.abc import Mapping | ||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union | ||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Tuple, Union | ||
|
||
from torch.utils.data import DataLoader | ||
|
||
|
@@ -118,18 +117,18 @@ def compute_mean_std(engine, batch): | |
|
||
""" | ||
|
||
_state_dict_all_req_keys = ("epoch_length", "max_epochs") | ||
_state_dict_one_of_opt_keys = ("iteration", "epoch") | ||
_state_dict_all_req_keys = ("epoch_length",) | ||
_state_dict_one_of_opt_keys = (("iteration", "epoch",), ("max_epochs", "max_iters",)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can relax the constraint that the both |
||
|
||
def __init__(self, process_function: Callable): | ||
super(Engine, self).__init__() | ||
self._event_handlers = defaultdict(list) # type: Dict[Any, List] | ||
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) | ||
self._process_function = process_function | ||
self.last_event_name = None # type: Optional[Events] | ||
self.should_terminate = False | ||
self.should_terminate_single_epoch = False | ||
self.state = State() | ||
self._state_dict_user_keys = [] # type: List[str] | ||
self._allowed_events = [] # type: List[EventEnum] | ||
|
||
self._dataloader_iter = None # type: Optional[Iterator[Any]] | ||
|
@@ -469,13 +468,9 @@ def _handle_exception(self, e: BaseException) -> None: | |
else: | ||
raise e | ||
|
||
@property | ||
def state_dict_user_keys(self) -> List: | ||
return self._state_dict_user_keys | ||
|
||
def state_dict(self) -> OrderedDict: | ||
"""Returns a dictionary containing engine's state: "epoch_length", "max_epochs" and "iteration" and | ||
other state values defined by `engine.state_dict_user_keys` | ||
"""Returns a dictionary containing engine's state: "epoch_length", "iteration", "max_iters" or "max_epochs" | ||
and other state values defined by ``engine.state_dict_user_keys``. | ||
|
||
.. code-block:: python | ||
|
||
|
@@ -500,15 +495,20 @@ def save_engine(_): | |
a dictionary containing engine's state | ||
|
||
""" | ||
keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) # type: Tuple[str, ...] | ||
keys = self._state_dict_all_req_keys # type: Tuple[str, ...] | ||
keys += ("iteration",) | ||
if self.state.max_epochs is not None: | ||
keys += ("max_epochs",) | ||
else: | ||
keys += ("max_iters",) | ||
keys += tuple(self._state_dict_user_keys) | ||
return OrderedDict([(k, getattr(self.state, k)) for k in keys]) | ||
|
||
def load_state_dict(self, state_dict: Mapping) -> None: | ||
"""Setups engine from `state_dict`. | ||
|
||
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`. | ||
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. | ||
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters` and `epoch_length`. | ||
If ``engine.state_dict_user_keys`` contains keys, they should be also present in the state dictionary. | ||
Iteration and epoch values are 0-based: the first iteration or epoch is zero. | ||
|
||
This method does not remove any custom attributes added by user. | ||
|
@@ -530,13 +530,9 @@ def load_state_dict(self, state_dict: Mapping) -> None: | |
""" | ||
super(Engine, self).load_state_dict(state_dict) | ||
|
||
for k in self._state_dict_user_keys: | ||
if k not in state_dict: | ||
raise ValueError( | ||
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" | ||
) | ||
self.state.max_epochs = state_dict["max_epochs"] | ||
self.state.epoch_length = state_dict["epoch_length"] | ||
for k in self._state_dict_all_req_keys: | ||
setattr(self.state, k, state_dict[k]) | ||
|
||
for k in self._state_dict_user_keys: | ||
setattr(self.state, k, state_dict[k]) | ||
|
||
|
@@ -545,7 +541,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: | |
self.state.epoch = 0 | ||
if self.state.epoch_length is not None: | ||
self.state.epoch = self.state.iteration // self.state.epoch_length | ||
elif "epoch" in state_dict: | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to |
||
self.state.epoch = state_dict["epoch"] | ||
if self.state.epoch_length is None: | ||
raise ValueError( | ||
|
@@ -554,6 +550,9 @@ def load_state_dict(self, state_dict: Mapping) -> None: | |
) | ||
self.state.iteration = self.state.epoch_length * self.state.epoch | ||
|
||
self._check_and_set_max_epochs(state_dict.get("max_epochs", None)) | ||
self._check_and_set_max_iters(state_dict.get("max_iters", None)) | ||
|
||
@staticmethod | ||
def _is_done(state: State) -> bool: | ||
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters | ||
|
@@ -612,12 +611,12 @@ def run( | |
|
||
Engine has a state and the following logic is applied in this function: | ||
|
||
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided. | ||
- At the first call, new state is defined by `max_epochs` or `max_iters` and `epoch_length`, if provided. | ||
A timer for total and per-epoch time is initialized when Events.STARTED is handled. | ||
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments | ||
provided, state is kept and used in the function. | ||
- If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined. | ||
- If state is defined, engine is NOT "done", then input arguments if provided override defined state. | ||
- If state is defined such that there are iterations to run until `max_epochs` or `max_iters` | ||
and no input arguments provided, state is kept and used in the function. | ||
- If engine is "done" (no iterations to run until `max_epochs`), a new state is defined. | ||
- If engine is NOT "done", then input arguments if provided override defined state. | ||
|
||
Args: | ||
data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`). | ||
|
@@ -637,7 +636,8 @@ def run( | |
|
||
Note: | ||
User can dynamically preprocess input batch at :attr:`~ignite.engine.events.Events.ITERATION_STARTED` and | ||
store output batch in `engine.state.batch`. Latter is passed as usually to `process_function` as argument: | ||
store output batch in ``engine.state.batch``. Latter is passed as usually to ``process_function`` | ||
as argument: | ||
|
||
.. code-block:: python | ||
|
||
|
@@ -662,54 +662,39 @@ def switch_batch(engine): | |
if not isinstance(data, Iterable): | ||
raise TypeError("Argument data should be iterable") | ||
|
||
if self.state.max_epochs is not None: | ||
# Check and apply overridden parameters | ||
if max_epochs is not None: | ||
if max_epochs < self.state.epoch: | ||
raise ValueError( | ||
"Argument max_epochs should be larger than the start epoch " | ||
f"defined in the state: {max_epochs} vs {self.state.epoch}. " | ||
"Please, set engine.state.max_epochs = None " | ||
"before calling engine.run() in order to restart the training from the beginning." | ||
) | ||
self.state.max_epochs = max_epochs | ||
if epoch_length is not None: | ||
if epoch_length != self.state.epoch_length: | ||
raise ValueError( | ||
"Argument epoch_length should be same as in the state, " | ||
f"but given {epoch_length} vs {self.state.epoch_length}" | ||
) | ||
if max_epochs is not None and max_iters is not None: | ||
raise ValueError( | ||
"Arguments max_iters and max_epochs are mutually exclusive. " | ||
"Please provide only max_epochs or max_iters." | ||
) | ||
|
||
if self.state.max_epochs is None or self._is_done(self.state): | ||
# Create new state | ||
if epoch_length is None: | ||
epoch_length = self._get_data_length(data) | ||
if epoch_length is not None and epoch_length < 1: | ||
raise ValueError("Input data has zero size. Please provide non-empty data") | ||
|
||
if max_iters is None: | ||
if max_epochs is None: | ||
max_epochs = 1 | ||
else: | ||
if max_epochs is not None: | ||
raise ValueError( | ||
"Arguments max_iters and max_epochs are mutually exclusive." | ||
"Please provide only max_epochs or max_iters." | ||
) | ||
if epoch_length is not None: | ||
max_epochs = math.ceil(max_iters / epoch_length) | ||
self._check_and_set_max_epochs(max_epochs) | ||
self._check_and_set_max_iters(max_iters) | ||
self._check_and_set_epoch_length(data, epoch_length) | ||
|
||
if self.state.max_epochs is None and self.state.max_iters is None: | ||
self.state.max_epochs = 1 | ||
|
||
if self.state.max_epochs is not None and self.state.max_iters is not None: | ||
raise ValueError( | ||
"State attributes max_iters and max_epochs are mutually exclusive. " | ||
"Please set max_epochs or max_iters to None" | ||
) | ||
|
||
msg = "Engine run starting with {}." | ||
if self._is_done(self.state): | ||
# Reset iteration/epoch counters | ||
self.state.iteration = 0 | ||
self.state.epoch = 0 | ||
self.state.max_epochs = max_epochs | ||
self.state.max_iters = max_iters | ||
self.state.epoch_length = epoch_length | ||
self.logger.info(f"Engine run starting with max_epochs={max_epochs}.") | ||
else: | ||
self.logger.info( | ||
f"Engine run resuming from iteration {self.state.iteration}, " | ||
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" | ||
) | ||
elif self.state.iteration > 0: | ||
msg = f"Engine run resuming from iteration {self.state.iteration}, epoch {self.state.epoch} " + "until {}." | ||
|
||
if self.state.max_epochs is not None: | ||
msg = msg.format(f"max_epochs={self.state.max_epochs}") | ||
elif self.state.max_iters is not None: | ||
msg = msg.format(f"max_iters={self.state.max_iters}") | ||
|
||
self.logger.info(msg) | ||
|
||
self.state.dataloader = data | ||
return self._internal_run() | ||
|
@@ -719,7 +704,7 @@ def _init_timers(state: State) -> None: | |
state.times[Events.EPOCH_COMPLETED.name] = 0.0 | ||
state.times[Events.COMPLETED.name] = 0.0 | ||
|
||
def _get_data_length(self, data: Iterable) -> Optional[int]: | ||
def _get_data_length(self, data: Union[Iterable, Sized]) -> Optional[int]: | ||
try: | ||
if hasattr(data, "__len__"): | ||
return len(data) # type: ignore[arg-type] | ||
|
@@ -728,6 +713,53 @@ def _get_data_length(self, data: Iterable) -> Optional[int]: | |
pass | ||
return None | ||
|
||
def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None: | ||
if max_epochs is not None: | ||
if max_epochs < 1: | ||
raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value") | ||
if self.state.max_epochs is not None and max_epochs <= self.state.epoch: | ||
raise ValueError( | ||
"Argument max_epochs should be larger than the current epoch " | ||
f"defined in the state: {max_epochs} vs {self.state.epoch}. " | ||
"Please, set engine.state.max_epochs = None " | ||
"before calling engine.run() in order to restart the training from the beginning." | ||
) | ||
self.state.max_epochs = max_epochs | ||
|
||
def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None: | ||
if max_iters is not None: | ||
if max_iters < 1: | ||
raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") | ||
if (self.state.max_iters is not None) and max_iters <= self.state.iteration: | ||
raise ValueError( | ||
"Argument max_iters should be larger than the current iteration " | ||
f"defined in the state: {max_iters} vs {self.state.iteration}. " | ||
"Please, set engine.state.max_iters = None " | ||
"before calling engine.run() in order to restart the training from the beginning." | ||
) | ||
self.state.max_iters = max_iters | ||
|
||
def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None: | ||
# Can't we accept a redefinition ? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can't because the relationship between epoch and iteration becomes invalid then. By the way in the case |
||
if self.state.epoch_length is not None: | ||
if epoch_length is not None: | ||
if epoch_length != self.state.epoch_length: | ||
raise ValueError( | ||
"Argument epoch_length should be same as in the state, " | ||
f"but given {epoch_length} vs {self.state.epoch_length}" | ||
) | ||
else: | ||
if epoch_length is None: | ||
epoch_length = self._get_data_length(data) | ||
|
||
if epoch_length is not None and epoch_length < 1: | ||
raise ValueError( | ||
"Argument epoch_length is invalid. Please, either set a correct epoch_length value or " | ||
"check if input data has non-zero size." | ||
) | ||
|
||
self.state.epoch_length = epoch_length | ||
|
||
def _setup_engine(self) -> None: | ||
if self.state.dataloader is None: | ||
raise RuntimeError( | ||
|
@@ -805,7 +837,6 @@ def _run_once_on_dataset(self) -> float: | |
raise RuntimeError( | ||
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error." | ||
) | ||
|
||
while True: | ||
try: | ||
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted | ||
|
@@ -820,8 +851,6 @@ def _run_once_on_dataset(self) -> float: | |
if self.state.epoch_length is None: | ||
# Define epoch length and stop the epoch | ||
self.state.epoch_length = iter_counter | ||
if self.state.max_iters is not None: | ||
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length) | ||
break | ||
|
||
# Should exit while loop if we can not iterate | ||
|
@@ -838,6 +867,7 @@ def _run_once_on_dataset(self) -> float: | |
"iterations to run is not reached. " | ||
f"Current iteration: {self.state.iteration} vs Total iterations to run : {total_iters}" | ||
) | ||
self.should_terminate = True | ||
break | ||
|
||
self._fire_event(Events.DATALOADER_STOP_ITERATION) | ||
|
@@ -873,3 +903,14 @@ def _run_once_on_dataset(self) -> float: | |
self._handle_exception(e) | ||
|
||
return time.time() - start_time | ||
|
||
def debug(self, enabled: bool = True) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you still want to merge this method into the master? |
||
"""Enables/disables engine's logging debug mode | ||
""" | ||
from ignite.utils import setup_logger | ||
|
||
if enabled: | ||
setattr(self, "_stored_logger", self.logger) | ||
self.logger = setup_logger(level=logging.DEBUG) | ||
elif hasattr(self, "_stored_logger"): | ||
self.logger = getattr(self, "_stored_logger") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But if we interrupted the engine during the first epoch, we would not have `epoch_length`.