diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index 685059baab..2ce6d459cc 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -5,7 +5,7 @@
 
 from .dictionary import Dictionary, TruncatedDictionary
 
-from .fairseq_dataset import FairseqDataset
+from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
 
 from .base_wrapper_dataset import BaseWrapperDataset
 
@@ -65,6 +65,7 @@
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
+    'FairseqIterableDataset',
     'GroupedIterator',
     'IdDataset',
     'IndexedCachedDataset',
diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py
index ca6fd47dc1..fe5681be5a 100644
--- a/fairseq/data/fairseq_dataset.py
+++ b/fairseq/data/fairseq_dataset.py
@@ -7,7 +7,15 @@
 import torch.utils.data
 
 
-class FairseqDataset(torch.utils.data.Dataset):
+class EpochListening:
+    """Mixin for receiving updates whenever the epoch increments."""
+    def set_epoch(self, epoch):
+        """Will receive the updated epoch number at the beginning of the epoch.
+        """
+        pass
+
+
+class FairseqDataset(torch.utils.data.Dataset, EpochListening):
     """A dataset that provides helpers for batching."""
 
     def __getitem__(self, index):
@@ -54,5 +62,11 @@ def prefetch(self, indices):
         """Prefetch the data required for this epoch."""
         raise NotImplementedError
 
-    def set_epoch(self, epoch):
-        pass
+
+class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
+    """For datasets that need to be read sequentially, usually because the data
+    is being streamed or otherwise can't be manipulated on a single machine.
+    """
+
+    def __iter__(self):
+        raise NotImplementedError
diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 83760b3615..813d9cbaca 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -63,6 +63,15 @@ def __len__(self) -> int:
         raise NotImplementedError
 
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
+        """Return a new iterator over the dataset.
+
+        Args:
+            shuffle (bool, optional): shuffle batches before returning the
+                iterator (default: True).
+            fix_batches_to_gpus: ensure that batches are always
+                allocated to the same shards across epochs. Requires
+                that :attr:`dataset` supports prefetching (default: False).
+        """
         raise NotImplementedError
 
     def end_of_epoch(self) -> bool:
@@ -71,12 +80,15 @@ def end_of_epoch(self) -> bool:
 
     @property
     def iterations_in_epoch(self) -> int:
+        """The number of consumed batches in the current epoch."""
         raise NotImplementedError
 
     def state_dict(self):
+        """Returns a dictionary containing a whole state of the iterator."""
         raise NotImplementedError
 
     def load_state_dict(self, state_dict):
+        """Copies the state of the iterator from the given *state_dict*."""
         raise NotImplementedError
 
 
@@ -84,7 +96,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
     def __init__(
         self, dataset, epoch=0, num_shards=1, shard_id=0,
     ):
-        # assert isinstance(dataset, torch.utils.data.Dataset)
+        assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
         self.epoch = epoch
         self._current_epoch_iterator = None
@@ -93,6 +105,7 @@ def __init__(
 
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
         self.epoch += 1
+        self.dataset.set_epoch(self.epoch)
         self._current_epoch_iterator = CountingIterator(
             iterable=ShardedIterator(
                 iterable=self.dataset,
diff --git a/fairseq/data/list_dataset.py b/fairseq/data/list_dataset.py
index 4d3b01d7bf..b96bba3437 100644
--- a/fairseq/data/list_dataset.py
+++ b/fairseq/data/list_dataset.py
@@ -12,6 +12,10 @@
     def __init__(self, dataset, sizes=None):
         super().__init__(dataset)
         self._sizes = sizes
 
+    def __iter__(self):
+        for x in self.dataset:
+            yield x
+
     def collater(self, samples):
         return samples