diff --git a/docs/changelog.md b/docs/changelog.md
index ee928ed5e7..899a5e6455 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -12,6 +12,7 @@
 
 - Support TSM-MobileNetV2 ([#415](https://github.com/open-mmlab/mmaction2/pull/415))
 - Support flip with label mapping ([#591](https://github.com/open-mmlab/mmaction2/pull/591))
+- Add seed option for sampler ([#642](https://github.com/open-mmlab/mmaction2/pull/642))
 - Support GPU Normalize ([#586](https://github.com/open-mmlab/mmaction2/pull/586))
 - Support TANet ([#595](https://github.com/open-mmlab/mmaction2/pull/595))
 
diff --git a/mmaction/datasets/builder.py b/mmaction/datasets/builder.py
index e7c64ab25e..1936586cc8 100644
--- a/mmaction/datasets/builder.py
+++ b/mmaction/datasets/builder.py
@@ -84,10 +84,11 @@ def build_dataloader(dataset,
     if dist:
         if sample_by_class:
             assert power is not None
-            sampler = DistributedPowerSampler(dataset, world_size, rank, power)
+            sampler = DistributedPowerSampler(
+                dataset, world_size, rank, power, seed=seed)
         else:
             sampler = DistributedSampler(
-                dataset, world_size, rank, shuffle=shuffle)
+                dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = videos_per_gpu
         num_workers = workers_per_gpu
diff --git a/mmaction/datasets/samplers/distributed_sampler.py b/mmaction/datasets/samplers/distributed_sampler.py
index 59fee0caee..5529a0d9cd 100644
--- a/mmaction/datasets/samplers/distributed_sampler.py
+++ b/mmaction/datasets/samplers/distributed_sampler.py
@@ -10,15 +10,22 @@ class DistributedSampler(_DistributedSampler):
     class will port one to DistributedSampler.
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
-        self.shuffle = shuffle
+    def __init__(self,
+                 dataset,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 seed=0):
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        # for compatibility with PyTorch 1.3+
+        self.seed = seed if seed is not None else 0
 
     def __iter__(self):
         # deterministically shuffle based on epoch
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            g.manual_seed(self.epoch + self.seed)
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
             indices = torch.arange(len(self.dataset)).tolist()
@@ -45,14 +52,15 @@ class DistributedPowerSampler(_DistributedSampler):
     from the entire dataset.
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, power=1):
+    def __init__(self, dataset, num_replicas=None, rank=None, power=1, seed=0):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
         self.power = power
+        self.seed = seed if seed is not None else 0
 
     def __iter__(self):
         # deterministically shuffle based on epoch
         g = torch.Generator()
-        g.manual_seed(self.epoch)
+        g.manual_seed(self.epoch + self.seed)
         video_infos_by_class = self.dataset.video_infos_by_class
         num_classes = self.dataset.num_classes
         # For simplicity, discontinuous labels are not permitted
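
A quick note on why the patch seeds the generator with `self.epoch + self.seed` rather than `self.epoch` alone: the sum keeps the shuffle deterministic per epoch, so every rank derives the identical permutation, while the user-supplied seed lets different runs diverge. Below is a minimal standalone sketch of that logic; it is not part of the PR, and `epoch_permutation` is a hypothetical helper that just mirrors the shuffle in `DistributedSampler.__iter__`:

```python
import torch


def epoch_permutation(num_samples, epoch, seed=0):
    """Mirror the shuffle logic in DistributedSampler.__iter__."""
    g = torch.Generator()
    # Same (epoch, seed) pair -> identical permutation on every rank;
    # a new epoch (or a new seed) -> a different permutation.
    g.manual_seed(epoch + seed)
    return torch.randperm(num_samples, generator=g).tolist()


# Two simulated ranks sharing a seed agree within an epoch ...
assert epoch_permutation(10, epoch=3, seed=42) == \
    epoch_permutation(10, epoch=3, seed=42)
# ... while consecutive epochs shuffle differently.
assert epoch_permutation(10, epoch=3, seed=42) != \
    epoch_permutation(10, epoch=4, seed=42)
```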
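For the per-epoch reshuffle to take effect, the training loop still has to call `set_epoch` (inherited from PyTorch's `DistributedSampler`) before each epoch. A hedged usage sketch follows, with a stand-in `TensorDataset` and explicit `num_replicas`/`rank` values so no process group needs to be initialized:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from mmaction.datasets.samplers.distributed_sampler import DistributedSampler

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(
    dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    # Without this call, every epoch replays the epoch-0 permutation.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step elided
```

In the actual mmaction2 training loop, this `set_epoch` call is handled by MMCV's `DistSamplerSeedHook` during distributed training; the sketch only makes the mechanism explicit.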