Skip to content

Commit e1e3d8e

Browse files
Authored by: ericspod, pre-commit-ci[bot], and Eric Kerfoot
Modify Workflow to Allow IterableDataset Inputs (#8263)
### Description This modifies the behaviour of `Workflow` to permit `IterableDataset` to be used correctly. A check against the `epoch_length` value is removed, to allow that value to be `None`, and a test is added to verify this. The length of a data loader is not defined when using iterable datasets, so try/raise is added to allow that to be queried safely. This is related to my work on the streaming support, in my [prototype gist](https://gist.github.com/ericspod/1904713716b45631260784ac3fcd6fb3) I had to provide a bogus epoch length value in the then change it to `None` later once the evaluator object was created. This PR will remove the need for this hack. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> Signed-off-by: Eric Kerfoot <eric.kerfoot@gmail> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <eric.kerfoot@gmail>
1 parent 21920a3 commit e1e3d8e

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

monai/engines/workflow.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from __future__ import annotations
1313

1414
import warnings
15-
from collections.abc import Callable, Iterable, Sequence
15+
from collections.abc import Callable, Iterable, Sequence, Sized
1616
from typing import TYPE_CHECKING, Any
1717

1818
import torch
@@ -121,24 +121,24 @@ def __init__(
121121
to_kwargs: dict | None = None,
122122
amp_kwargs: dict | None = None,
123123
) -> None:
124-
if iteration_update is not None:
125-
super().__init__(iteration_update)
126-
else:
127-
super().__init__(self._iteration)
124+
super().__init__(self._iteration if iteration_update is None else iteration_update)
128125

129126
if isinstance(data_loader, DataLoader):
130-
sampler = data_loader.__dict__["sampler"]
127+
sampler = getattr(data_loader, "sampler", None)
128+
129+
# set the epoch value for DistributedSampler objects when an epoch starts
131130
if isinstance(sampler, DistributedSampler):
132131

133132
@self.on(Events.EPOCH_STARTED)
134133
def set_sampler_epoch(engine: Engine) -> None:
135134
sampler.set_epoch(engine.state.epoch)
136135

137-
if epoch_length is None:
136+
# if the epoch_length isn't given, attempt to get it from the length of the data loader
137+
if epoch_length is None and isinstance(data_loader, Sized):
138+
try:
138139
epoch_length = len(data_loader)
139-
else:
140-
if epoch_length is None:
141-
raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
140+
except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type
141+
pass # deliberately leave epoch_length as None
142142

143143
# set all sharable data for the workflow based on Ignite engine.state
144144
self.state: Any = State(
@@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
147147
iteration=0,
148148
epoch=0,
149149
max_epochs=max_epochs,
150-
epoch_length=epoch_length,
150+
epoch_length=epoch_length, # None when the dataset is iterable and so has no length
151151
output=None,
152152
batch=None,
153153
metrics={},

tests/test_iterable_dataset.py

+13
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919
import nibabel as nib
2020
import numpy as np
21+
import torch.nn as nn
2122

2223
from monai.data import DataLoader, Dataset, IterableDataset
24+
from monai.engines import SupervisedEvaluator
2325
from monai.transforms import Compose, LoadImaged, SimulateDelayd
2426

2527

@@ -59,6 +61,17 @@ def test_shape(self):
5961
for d in dataloader:
6062
self.assertTupleEqual(d["image"].shape[1:], expected_shape)
6163

64+
def test_supervisedevaluator(self):
65+
"""
66+
Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.
67+
"""
68+
data = list(range(10))
69+
dl = DataLoader(IterableDataset(data))
70+
evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity())
71+
evaluator.run() # fails if the epoch length or other internal setup is not done correctly
72+
73+
self.assertEqual(evaluator.state.iteration, len(data))
74+
6275

6376
if __name__ == "__main__":
6477
unittest.main()

0 commit comments

Comments (0)