[Feature] Add seed option for Sampler #642

Merged · 2 commits · Feb 24, 2021

1 change: 1 addition & 0 deletions docs/changelog.md
@@ -12,6 +12,7 @@

 - Support TSM-MobileNetV2 ([#415](https://github.com/open-mmlab/mmaction2/pull/415))
 - Support flip with label mapping ([#591](https://github.com/open-mmlab/mmaction2/pull/591))
+- Add seed option for sampler ([#642](https://github.com/open-mmlab/mmaction2/pull/642))
 - Support GPU Normalize ([#586](https://github.com/open-mmlab/mmaction2/pull/586))
 - Support TANet ([#595](https://github.com/open-mmlab/mmaction2/pull/595))

5 changes: 3 additions & 2 deletions mmaction/datasets/builder.py
@@ -84,10 +84,11 @@ def build_dataloader(dataset,
     if dist:
         if sample_by_class:
             assert power is not None
-            sampler = DistributedPowerSampler(dataset, world_size, rank, power)
+            sampler = DistributedPowerSampler(
+                dataset, world_size, rank, power, seed=seed)
         else:
             sampler = DistributedSampler(
-                dataset, world_size, rank, shuffle=shuffle)
+                dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = videos_per_gpu
         num_workers = workers_per_gpu
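
For context, a minimal sketch of how the seed reaches the distributed path touched above. The argument names mirror the diff; `cfg`, the concrete values, and anything beyond the parameters shown are illustrative assumptions, since the full build_dataloader signature may differ across mmaction2 versions.

    from mmaction.datasets import build_dataset, build_dataloader

    # `cfg` is assumed to be an already-loaded mmcv Config (illustrative).
    dataset = build_dataset(cfg.data.train)

    data_loader = build_dataloader(
        dataset,
        videos_per_gpu=8,    # batch size per GPU
        workers_per_gpu=4,   # dataloader workers per GPU
        dist=True,           # takes the DistributedSampler branch above
        shuffle=True,
        seed=42)             # with this change, forwarded to the sampler as well
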
20 changes: 14 additions & 6 deletions mmaction/datasets/samplers/distributed_sampler.py
@@ -10,15 +10,22 @@ class DistributedSampler(_DistributedSampler):
     class will port one to DistributedSampler.
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
-        self.shuffle = shuffle
+    def __init__(self,
+                 dataset,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 seed=0):
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        # for the compatibility from PyTorch 1.3+
+        self.seed = seed if seed is not None else 0

     def __iter__(self):
         # deterministically shuffle based on epoch
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            g.manual_seed(self.epoch + self.seed)
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
             indices = torch.arange(len(self.dataset)).tolist()

@@ -45,14 +52,15 @@ class DistributedPowerSampler(_DistributedSampler):
     from the entire dataset.
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, power=1):
+    def __init__(self, dataset, num_replicas=None, rank=None, power=1, seed=0):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
         self.power = power
+        self.seed = seed if seed is not None else 0

     def __iter__(self):
         # deterministically shuffle based on epoch
         g = torch.Generator()
-        g.manual_seed(self.epoch)
+        g.manual_seed(self.epoch + self.seed)
         video_infos_by_class = self.dataset.video_infos_by_class
         num_classes = self.dataset.num_classes
         # For simplicity, discontinuous labels are not permitted
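
To see what seeding with `self.epoch + self.seed` buys, here is a standalone sketch (plain PyTorch, not mmaction2 code): the same (seed, epoch) pair reproduces the same permutation on every rank and every rerun, while successive epochs still shuffle differently.

    import torch

    def epoch_permutation(num_samples, epoch, seed=0):
        # Mirrors the generator seeding used by both samplers above.
        g = torch.Generator()
        g.manual_seed(epoch + seed)
        return torch.randperm(num_samples, generator=g).tolist()

    # Same seed and epoch -> identical order, which keeps ranks in sync.
    assert epoch_permutation(10, epoch=3, seed=42) == epoch_permutation(10, epoch=3, seed=42)

    # Different epochs -> a fresh shuffle (a collision is possible in
    # principle but astronomically unlikely).
    assert epoch_permutation(10, epoch=3, seed=42) != epoch_permutation(10, epoch=4, seed=42)

One side effect of summing the two values: runs whose seeds differ by n replay each other's permutations with an n-epoch offset. Upstream torch.utils.data.DistributedSampler seeds its generator the same way (`self.seed + self.epoch`), so this matches PyTorch's own behavior.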