Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Support multi-batch inference #184

Merged
merged 6 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

**Improvements**
- Support to run a demo with a video url ([#165](https://github.com/open-mmlab/mmaction2/pull/165))
- Support multi-batch inference when testing ([#184](https://github.com/open-mmlab/mmaction2/pull/184))
- Add tutorial for adding a new learning rate updater ([#181](https://github.com/open-mmlab/mmaction2/pull/181))
- Add config name in meta info ([#183](https://github.com/open-mmlab/mmaction2/pull/183))
- Remove git hash in `__version__` ([#189](https://github.com/open-mmlab/mmaction2/pull/189))
Expand Down
12 changes: 10 additions & 2 deletions docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ test_pipeline = [ # List of testing pipeline steps
data = dict( # Config of data
videos_per_gpu=8, # Batch size of each single GPU
workers_per_gpu=8, # Workers to pre-fetch data for each single GPU
train_dataloader=dict( # Addition config of train dataloader
train_dataloader=dict( # Additional config of train dataloader
drop_last=True), # Whether to drop out the last batch of data in training
val_dataloader=dict( # Addition config of validation dataloader
val_dataloader=dict( # Additional config of validation dataloader
videos_per_gpu=1), # Batch size of each single GPU during evaluation
test_dataloader=dict( # Additional config of test dataloader
videos_per_gpu=2), # Batch size of each single GPU during testing
test=dict( # Testing dataset config
type=dataset_type,
ann_file=ann_file_test,
Expand Down Expand Up @@ -335,6 +337,12 @@ test_pipeline = [ # List of testing pipeline steps
data = dict( # Config of data
videos_per_gpu=32, # Batch size of each single GPU
workers_per_gpu=4, # Workers to pre-fetch data for each single GPU
train_dataloader=dict( # Additional config of train dataloader
drop_last=True), # Whether to drop out the last batch of data in training
val_dataloader=dict( # Additional config of validation dataloader
videos_per_gpu=1), # Batch size of each single GPU during evaluation
test_dataloader=dict( # Additional config of test dataloader
videos_per_gpu=2), # Batch size of each single GPU during testing
train=dict( # Training dataset config
type=dataset_type,
ann_file=ann_file_train,
Expand Down
13 changes: 10 additions & 3 deletions mmaction/models/recognizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def extract_feat(self, imgs):
x = self.backbone(imgs)
return x

def average_clip(self, cls_score):
def average_clip(self, cls_score, num_segs=1):
"""Averaging class score over multiple clips.

Using different averaging types ('score' or 'prob' or None,
Expand All @@ -77,10 +77,17 @@ class score.
f'Currently supported ones are '
f'["score", "prob", None]')

if average_clips is None:
return cls_score

batch_size = cls_score.shape[0]
cls_score = cls_score.view(batch_size // num_segs, num_segs, -1)

if average_clips == 'prob':
cls_score = F.softmax(cls_score, dim=1).mean(dim=0, keepdim=True)
cls_score = F.softmax(cls_score, dim=2).mean(dim=1)
elif average_clips == 'score':
cls_score = cls_score.mean(dim=0, keepdim=True)
cls_score = cls_score.mean(dim=1)

return cls_score

@abstractmethod
Expand Down
3 changes: 2 additions & 1 deletion mmaction/models/recognizers/recognizer3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ def forward_train(self, imgs, labels):
def forward_test(self, imgs):
"""Defines the computation performed at every call when evaluation and
testing."""
num_segs = imgs.shape[1]
imgs = imgs.reshape((-1, ) + imgs.shape[2:])

x = self.extract_feat(imgs)
cls_score = self.cls_head(x)
cls_score = self.average_clip(cls_score)
cls_score = self.average_clip(cls_score, num_segs)

return cls_score.cpu().numpy()

Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_recognizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def test_base_recognizer():
# average_clips='score'
test_cfg = dict(average_clips='score')
recognizer = ExampleRecognizer(None, test_cfg)
score = recognizer.average_clip(cls_score)
score = recognizer.average_clip(cls_score, num_segs=5)
assert torch.equal(score, cls_score.mean(dim=0, keepdim=True))

# average_clips='prob'
test_cfg = dict(average_clips='prob')
recognizer = ExampleRecognizer(None, test_cfg)
score = recognizer.average_clip(cls_score)
score = recognizer.average_clip(cls_score, num_segs=5)
assert torch.equal(score,
F.softmax(cls_score, dim=1).mean(dim=0, keepdim=True))

Expand Down
14 changes: 8 additions & 6 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def parse_args():
parser.add_argument('--options', nargs='+', help='custom options')
parser.add_argument(
'--average-clips',
choices=['score', 'prob'],
default='score',
choices=['score', 'prob', None],
default=None,
help='average type when averaging test clips')
parser.add_argument(
'--launcher',
Expand Down Expand Up @@ -111,12 +111,14 @@ def main():
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# build the dataloader
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
data_loader = build_dataloader(
dataset,
videos_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dataloader_setting = dict(
videos_per_gpu=cfg.data.get('videos_per_gpu', {}),
workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
dist=distributed,
shuffle=False)
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('test_dataloader', {}))
data_loader = build_dataloader(dataset, **dataloader_setting)

# build the model and load checkpoint
model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
Expand Down