Skip to content

Improved max_iters handling #1565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions ignite/base/mixins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from collections import OrderedDict
from collections.abc import Mapping
from typing import List, Tuple


class Serializable:

_state_dict_all_req_keys = () # type: tuple
_state_dict_one_of_opt_keys = () # type: tuple
_state_dict_all_req_keys = () # type: Tuple[str, ...]
_state_dict_one_of_opt_keys = ((),) # type: Tuple[Tuple[str, ...], ...]

def __init__(self) -> None:
self._state_dict_user_keys = [] # type: List[str]

@property
def state_dict_user_keys(self) -> List:
return self._state_dict_user_keys

def state_dict(self) -> OrderedDict:
pass
Expand All @@ -19,6 +27,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
raise ValueError(
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
for one_of_opt_keys in self._state_dict_one_of_opt_keys:
    # One membership flag per key of this mutually exclusive group.
    opts = [k in state_dict for k in one_of_opt_keys]
    # The parentheses matter: without them the test parses as
    # ``(len(opts) > 0 and not any(opts)) or all(opts)``, and since ``all([])`` is True an
    # empty group (such as the default ``((),)``) would always raise. An empty group imposes
    # no constraint; a non-empty group is rejected when none or all of its keys are present.
    if opts and (not any(opts) or all(opts)):
        raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")

for k in self._state_dict_user_keys:
if k not in state_dict:
raise ValueError(
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
193 changes: 117 additions & 76 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import functools
import logging
import math
import time
import warnings
import weakref
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Tuple, Union

from torch.utils.data import DataLoader

Expand Down Expand Up @@ -118,18 +117,18 @@ def compute_mean_std(engine, batch):

"""

_state_dict_all_req_keys = ("epoch_length", "max_epochs")
_state_dict_one_of_opt_keys = ("iteration", "epoch")
_state_dict_all_req_keys = ("epoch_length",)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if we interrupted the engine during the first epoch, we would not have epoch_length.

_state_dict_one_of_opt_keys = (("iteration", "epoch",), ("max_epochs", "max_iters",))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can relax the constraint that max_epochs and max_iters cannot both have values.


def __init__(self, process_function: Callable):
super(Engine, self).__init__()
self._event_handlers = defaultdict(list) # type: Dict[Any, List]
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._process_function = process_function
self.last_event_name = None # type: Optional[Events]
self.should_terminate = False
self.should_terminate_single_epoch = False
self.state = State()
self._state_dict_user_keys = [] # type: List[str]
self._allowed_events = [] # type: List[EventEnum]

self._dataloader_iter = None # type: Optional[Iterator[Any]]
Expand Down Expand Up @@ -469,13 +468,9 @@ def _handle_exception(self, e: BaseException) -> None:
else:
raise e

@property
def state_dict_user_keys(self) -> List:
return self._state_dict_user_keys

def state_dict(self) -> OrderedDict:
"""Returns a dictionary containing engine's state: "epoch_length", "max_epochs" and "iteration" and
other state values defined by `engine.state_dict_user_keys`
"""Returns a dictionary containing engine's state: "epoch_length", "iteration", "max_iters" or "max_epoch"
and other state values defined by ``engine.state_dict_user_keys``.

.. code-block:: python

Expand All @@ -500,15 +495,20 @@ def save_engine(_):
a dictionary containing engine's state

"""
keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) # type: Tuple[str, ...]
keys = self._state_dict_all_req_keys # type: Tuple[str, ...]
keys += ("iteration",)
if self.state.max_epochs is not None:
keys += ("max_epochs",)
else:
keys += ("max_iters",)
keys += tuple(self._state_dict_user_keys)
return OrderedDict([(k, getattr(self.state, k)) for k in keys])

def load_state_dict(self, state_dict: Mapping) -> None:
"""Setups engine from `state_dict`.

State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters` and `epoch_length`.
If ``engine.state_dict_user_keys`` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.

This method does not remove any custom attributes added by user.
Expand All @@ -530,13 +530,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
"""
super(Engine, self).load_state_dict(state_dict)

for k in self._state_dict_user_keys:
if k not in state_dict:
raise ValueError(
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
self.state.max_epochs = state_dict["max_epochs"]
self.state.epoch_length = state_dict["epoch_length"]
for k in self._state_dict_all_req_keys:
setattr(self.state, k, state_dict[k])

for k in self._state_dict_user_keys:
setattr(self.state, k, state_dict[k])

Expand All @@ -545,7 +541,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
self.state.epoch = 0
if self.state.epoch_length is not None:
self.state.epoch = self.state.iteration // self.state.epoch_length
elif "epoch" in state_dict:
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to state_dict this should never happen.

self.state.epoch = state_dict["epoch"]
if self.state.epoch_length is None:
raise ValueError(
Expand All @@ -554,6 +550,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
)
self.state.iteration = self.state.epoch_length * self.state.epoch

self._check_and_set_max_epochs(state_dict.get("max_epochs", None))
self._check_and_set_max_iters(state_dict.get("max_iters", None))

@staticmethod
def _is_done(state: State) -> bool:
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
Expand Down Expand Up @@ -612,12 +611,12 @@ def run(

Engine has a state and the following logic is applied in this function:

- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided.
- At the first call, new state is defined by `max_epochs` or `max_iters` and `epoch_length`, if provided.
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
provided, state is kept and used in the function.
- If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined.
- If state is defined, engine is NOT "done", then input arguments if provided override defined state.
- If state is defined such that there are iterations to run until `max_epochs` or `max_iters`
and no input arguments provided, state is kept and used in the function.
- If engine is "done" (no iterations to run until `max_epochs` or `max_iters`), a new state is defined.
- If engine is NOT "done", then input arguments if provided override defined state.

Args:
data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).
Expand All @@ -637,7 +636,8 @@ def run(

Note:
User can dynamically preprocess input batch at :attr:`~ignite.engine.events.Events.ITERATION_STARTED` and
store output batch in `engine.state.batch`. Latter is passed as usually to `process_function` as argument:
store output batch in ``engine.state.batch``. Latter is passed as usually to ``process_function``
as argument:

.. code-block:: python

Expand All @@ -662,54 +662,39 @@ def switch_batch(engine):
if not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

if self.state.max_epochs is not None:
# Check and apply overridden parameters
if max_epochs is not None:
if max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be larger than the start epoch "
f"defined in the state: {max_epochs} vs {self.state.epoch}. "
"Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_epochs = max_epochs
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)
if max_epochs is not None and max_iters is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive. "
"Please provide only max_epochs or max_iters."
)

if self.state.max_epochs is None or self._is_done(self.state):
# Create new state
if epoch_length is None:
epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

if max_iters is None:
if max_epochs is None:
max_epochs = 1
else:
if max_epochs is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive."
"Please provide only max_epochs or max_iters."
)
if epoch_length is not None:
max_epochs = math.ceil(max_iters / epoch_length)
self._check_and_set_max_epochs(max_epochs)
self._check_and_set_max_iters(max_iters)
self._check_and_set_epoch_length(data, epoch_length)

if self.state.max_epochs is None and self.state.max_iters is None:
self.state.max_epochs = 1

if self.state.max_epochs is not None and self.state.max_iters is not None:
raise ValueError(
"State attributes max_iters and max_epochs are mutually exclusive. "
"Please set max_epochs or max_iters to None"
)

msg = "Engine run starting with {}."
if self._is_done(self.state):
# Reset iteration/epoch counters
self.state.iteration = 0
self.state.epoch = 0
self.state.max_epochs = max_epochs
self.state.max_iters = max_iters
self.state.epoch_length = epoch_length
self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")
else:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)
elif self.state.iteration > 0:
msg = f"Engine run resuming from iteration {self.state.iteration}, epoch {self.state.epoch} " + "until {}."

if self.state.max_epochs is not None:
msg = msg.format(f"max_epochs={self.state.max_epochs}")
elif self.state.max_iters is not None:
msg = msg.format(f"max_iters={self.state.max_iters}")

self.logger.info(msg)

self.state.dataloader = data
return self._internal_run()
Expand All @@ -719,7 +704,7 @@ def _init_timers(state: State) -> None:
state.times[Events.EPOCH_COMPLETED.name] = 0.0
state.times[Events.COMPLETED.name] = 0.0

def _get_data_length(self, data: Iterable) -> Optional[int]:
def _get_data_length(self, data: Union[Iterable, Sized]) -> Optional[int]:
try:
if hasattr(data, "__len__"):
return len(data) # type: ignore[arg-type]
Expand All @@ -728,6 +713,53 @@ def _get_data_length(self, data: Iterable) -> Optional[int]:
pass
return None

# Validate a requested ``max_epochs`` value and store it on ``self.state.max_epochs``.
# Raises ValueError when ``max_epochs`` is below 1, or when ``self.state.max_epochs`` is already
# set and the requested value does not exceed ``self.state.epoch``.
# NOTE(review): indentation was lost in this diff rendering; the final assignment is presumably
# nested under ``if max_epochs is not None`` so that ``None`` leaves the state untouched — confirm.
def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
# ``None`` means "no value requested": the checks below apply only to a concrete value.
if max_epochs is not None:
# Reject zero or negative values.
if max_epochs < 1:
raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
# The comparison with the current epoch is made only when the state already has a max_epochs.
if self.state.max_epochs is not None and max_epochs <= self.state.epoch:
raise ValueError(
"Argument max_epochs should be larger than the current epoch "
f"defined in the state: {max_epochs} vs {self.state.epoch}. "
"Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_epochs = max_epochs

# Validate a requested ``max_iters`` value and store it on ``self.state.max_iters``.
# Raises ValueError when ``max_iters`` is below 1, or when ``self.state.max_iters`` is already
# set and the requested value does not exceed ``self.state.iteration``.
# NOTE(review): indentation was lost in this diff rendering; the final assignment is presumably
# nested under ``if max_iters is not None`` so that ``None`` leaves the state untouched — confirm.
def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
# ``None`` means "no value requested": the checks below apply only to a concrete value.
if max_iters is not None:
# Reject zero or negative values.
if max_iters < 1:
raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
# The comparison with the current iteration is made only when the state already has a max_iters.
if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
raise ValueError(
"Argument max_iters should be larger than the current iteration "
f"defined in the state: {max_iters} vs {self.state.iteration}. "
"Please, set engine.state.max_iters = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_iters = max_iters

def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
# Can't we accept a redefinition ?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can't, because the relationship between epoch and iteration would then become invalid. Alternatively, if epoch_length is changed, we could reinitialize the state (resetting epoch and iteration to zero, along with a message to the user).

if self.state.epoch_length is not None:
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)
else:
if epoch_length is None:
epoch_length = self._get_data_length(data)

if epoch_length is not None and epoch_length < 1:
raise ValueError(
"Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
"check if input data has non-zero size."
)

self.state.epoch_length = epoch_length

def _setup_engine(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
Expand Down Expand Up @@ -805,7 +837,6 @@ def _run_once_on_dataset(self) -> float:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)

while True:
try:
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
Expand All @@ -820,8 +851,6 @@ def _run_once_on_dataset(self) -> float:
if self.state.epoch_length is None:
# Define epoch length and stop the epoch
self.state.epoch_length = iter_counter
if self.state.max_iters is not None:
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
break

# Should exit while loop if we can not iterate
Expand All @@ -838,6 +867,7 @@ def _run_once_on_dataset(self) -> float:
"iterations to run is not reached. "
f"Current iteration: {self.state.iteration} vs Total iterations to run : {total_iters}"
)
self.should_terminate = True
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
Expand Down Expand Up @@ -873,3 +903,14 @@ def _run_once_on_dataset(self) -> float:
self._handle_exception(e)

return time.time() - start_time

def debug(self, enabled: bool = True) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you still want to merge this method into master?

"""Enables/disables engine's logging debug mode
"""
from ignite.utils import setup_logger

if enabled:
setattr(self, "_stored_logger", self.logger)
self.logger = setup_logger(level=logging.DEBUG)
elif hasattr(self, "_stored_logger"):
self.logger = getattr(self, "_stored_logger")
Loading