More thorough support for iterable datasets
Summary: Use PyTorch's IterableDataset for streaming iterators, so that there is a clean interface distinction between datasets that stream data and those that support indexed access.
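
The distinction in plain PyTorch terms (a minimal sketch; the toy classes below are illustrative, not part of this commit): map-style datasets expose random access via __getitem__/__len__, while iterable-style datasets expose only sequential reads via __iter__.

import torch.utils.data


class IndexedToyDataset(torch.utils.data.Dataset):
    """Supports indexed access: both len(ds) and ds[i] work."""

    def __init__(self, items):
        self.items = items

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


class StreamingToyDataset(torch.utils.data.IterableDataset):
    """Supports only sequential access: iteration works, ds[i] does not."""

    def __init__(self, source):
        self.source = source  # any iterable, e.g. lines arriving over a socket

    def __iter__(self):
        return iter(self.source)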

Reviewed By: myleott

Differential Revision: D18438694

fbshipit-source-id: 482857d8357091ea2a6bf819535b09ba7f1a5b7d
Spencer Poff authored and facebook-github-bot committed Nov 12, 2019
1 parent b31849a commit 2a9b4ec
Showing 4 changed files with 37 additions and 5 deletions.
3 changes: 2 additions & 1 deletion fairseq/data/__init__.py
@@ -5,7 +5,7 @@
 
 from .dictionary import Dictionary, TruncatedDictionary
 
-from .fairseq_dataset import FairseqDataset
+from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
 
 from .base_wrapper_dataset import BaseWrapperDataset
 
@@ -65,6 +65,7 @@
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
+    'FairseqIterableDataset',
     'GroupedIterator',
     'IdDataset',
     'IndexedCachedDataset',
20 changes: 17 additions & 3 deletions fairseq/data/fairseq_dataset.py
@@ -7,7 +7,15 @@
 import torch.utils.data
 
 
-class FairseqDataset(torch.utils.data.Dataset):
+class EpochListening:
+    """Mixin for receiving updates whenever the epoch increments."""
+    def set_epoch(self, epoch):
+        """Will receive the updated epoch number at the beginning of the epoch.
+        """
+        pass
+
+
+class FairseqDataset(torch.utils.data.Dataset, EpochListening):
     """A dataset that provides helpers for batching."""
 
     def __getitem__(self, index):
@@ -54,5 +62,11 @@ def prefetch(self, indices):
         """Prefetch the data required for this epoch."""
         raise NotImplementedError
 
-    def set_epoch(self, epoch):
-        pass
+
+class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
+    """For datasets that need to be read sequentially, usually because the data
+    is being streamed or otherwise can't be manipulated on a single machine.
+    """
+
+    def __iter__(self):
+        raise NotImplementedError
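
A hypothetical subclass showing how the new base class is meant to be used (a sketch; LineStreamDataset and its path argument are invented for illustration). The dataset implements only __iter__, and the set_epoch() hook inherited from EpochListening lets it react when the epoch increments:

from fairseq.data import FairseqIterableDataset


class LineStreamDataset(FairseqIterableDataset):
    def __init__(self, path):
        self.path = path
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called with the new epoch number at the beginning of each epoch.
        self.epoch = epoch

    def __iter__(self):
        # Sequential read; no indexed access (__getitem__/__len__) is required.
        with open(self.path) as f:
            for line in f:
                yield line.rstrip("\n")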
15 changes: 14 additions & 1 deletion fairseq/data/iterators.py
@@ -63,6 +63,15 @@ def __len__(self) -> int:
         raise NotImplementedError
 
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
+        """Return a new iterator over the dataset.
+
+        Args:
+            shuffle (bool, optional): shuffle batches before returning the
+                iterator (default: True).
+            fix_batches_to_gpus: ensure that batches are always
+                allocated to the same shards across epochs. Requires
+                that :attr:`dataset` supports prefetching (default: False).
+        """
         raise NotImplementedError
 
     def end_of_epoch(self) -> bool:
@@ -71,20 +80,23 @@ def end_of_epoch(self) -> bool:
 
     @property
     def iterations_in_epoch(self) -> int:
+        """The number of consumed batches in the current epoch."""
         raise NotImplementedError
 
     def state_dict(self):
+        """Returns a dictionary containing a whole state of the iterator."""
         raise NotImplementedError
 
     def load_state_dict(self, state_dict):
+        """Copies the state of the iterator from the given *state_dict*."""
         raise NotImplementedError
 
 
 class StreamingEpochBatchIterator(EpochBatchIterating):
     def __init__(
         self, dataset, epoch=0, num_shards=1, shard_id=0,
     ):
-        # assert isinstance(dataset, torch.utils.data.Dataset)
+        assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
         self.epoch = epoch
         self._current_epoch_iterator = None
@@ -93,6 +105,7 @@ def __init__(
 
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
         self.epoch += 1
+        self.dataset.set_epoch(self.epoch)
         self._current_epoch_iterator = CountingIterator(
             iterable=ShardedIterator(
                 iterable=self.dataset,
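
A usage sketch tying the pieces together (assuming the hypothetical LineStreamDataset above): the constructor now asserts that it is given a torch IterableDataset, and each call to next_epoch_itr() increments the epoch and forwards it to the dataset via set_epoch() before wrapping the stream in ShardedIterator and CountingIterator.

from fairseq.data.iterators import StreamingEpochBatchIterator

dataset = LineStreamDataset("/path/to/stream.txt")  # hypothetical path
epoch_itr = StreamingEpochBatchIterator(dataset, epoch=0, num_shards=1, shard_id=0)

itr = epoch_itr.next_epoch_itr()  # epoch becomes 1 and dataset.set_epoch(1) runs
for sample in itr:
    pass  # consume the sharded, counted stream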
4 changes: 4 additions & 0 deletions fairseq/data/list_dataset.py
@@ -12,6 +12,10 @@ def __init__(self, dataset, sizes=None):
         super().__init__(dataset)
         self._sizes = sizes
 
+    def __iter__(self):
+        for x in self.dataset:
+            yield x
+
     def collater(self, samples):
         return samples
 
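
With __iter__ delegating to the wrapped object, ListDataset can now also be consumed as a stream. A rough illustration (assuming ListDataset is exported from fairseq.data, per the package's __init__):

from fairseq.data import ListDataset

ds = ListDataset([{"id": 0}, {"id": 1}], sizes=[1, 1])
for sample in ds:  # uses the new __iter__
    print(sample["id"])
print(ds.collater(list(ds)))  # collater returns the samples unchanged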
