diff --git a/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_test.py b/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_test.py index 5915889cc4..cca88f1999 100644 --- a/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_test.py +++ b/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_test.py @@ -1,5 +1,5 @@ # model training and testing settings -train_cfg = dict( +train_cfg_ = dict( ssn=dict( assigner=dict( positive_iou_threshold=0.7, @@ -15,7 +15,7 @@ add_gt_as_proposals=True), loss_weight=dict(comp_loss_weight=0.1, reg_loss_weight=0.1), debug=False)) -test_cfg = dict( +test_cfg_ = dict( ssn=dict( sampler=dict(test_interval=6, batch_size=16), evaluater=dict( @@ -42,7 +42,7 @@ num_classes=20, consensus=dict(type='STPPTest', stpp_stage=(1, 1, 1)), use_regression=True), - test_cfg=test_cfg) + test_cfg=test_cfg_) # dataset settings dataset_type = 'SSNDataset' data_root = './data/thumos14/rawframes/' @@ -86,8 +86,8 @@ type=dataset_type, ann_file=ann_file_test, data_prefix=data_root, - train_cfg=train_cfg, - test_cfg=test_cfg, + train_cfg=train_cfg_, + test_cfg=test_cfg_, aug_ratio=0.5, test_mode=True, pipeline=test_pipeline)) diff --git a/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_train.py b/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_train.py index c64766cb9c..435ac635b3 100644 --- a/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_train.py +++ b/configs/localization/ssn/ssn_r50_450e_thumos14_rgb_train.py @@ -1,5 +1,5 @@ # model training and testing settings -train_cfg = dict( +train_cfg_ = dict( ssn=dict( assigner=dict( positive_iou_threshold=0.7, @@ -15,7 +15,7 @@ add_gt_as_proposals=True), loss_weight=dict(comp_loss_weight=0.1, reg_loss_weight=0.1), debug=False)) -test_cfg = dict( +test_cfg_ = dict( ssn=dict( sampler=dict(test_interval=6, batch_size=16), evaluater=dict( @@ -46,7 +46,7 @@ stpp_stage=(1, 1, 1), num_segments_list=(2, 5, 2)), use_regression=True), - train_cfg=train_cfg) + train_cfg=train_cfg_) # dataset settings dataset_type = 
'SSNDataset' data_root = './data/thumos14/rawframes/' @@ -116,8 +116,8 @@ type=dataset_type, ann_file=ann_file_train, data_prefix=data_root, - train_cfg=train_cfg, - test_cfg=test_cfg, + train_cfg=train_cfg_, + test_cfg=test_cfg_, body_segments=5, aug_segments=(2, 2), aug_ratio=0.5, @@ -128,8 +128,8 @@ type=dataset_type, ann_file=ann_file_val, data_prefix=data_root, - train_cfg=train_cfg, - test_cfg=test_cfg, + train_cfg=train_cfg_, + test_cfg=test_cfg_, body_segments=5, aug_segments=(2, 2), aug_ratio=0.5, diff --git a/mmaction/models/localizers/__init__.py b/mmaction/models/localizers/__init__.py index 0d50890994..523d3f20c2 100644 --- a/mmaction/models/localizers/__init__.py +++ b/mmaction/models/localizers/__init__.py @@ -1,6 +1,6 @@ -from .base import BaseLocalizer +from .base import BaseTAGClassifier, BaseTAPGenerator from .bmn import BMN from .bsn import PEM, TEM from .ssn import SSN -__all__ = ['PEM', 'TEM', 'BMN', 'SSN', 'BaseLocalizer'] +__all__ = ['PEM', 'TEM', 'BMN', 'SSN', 'BaseTAPGenerator', 'BaseTAGClassifier'] diff --git a/mmaction/models/localizers/base.py b/mmaction/models/localizers/base.py index abc715593d..893678f6bf 100644 --- a/mmaction/models/localizers/base.py +++ b/mmaction/models/localizers/base.py @@ -1,3 +1,4 @@ +import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict @@ -8,12 +9,119 @@ from .. import builder -class BaseLocalizer(nn.Module, metaclass=ABCMeta): - """Base class for localizers. +class BaseTAPGenerator(nn.Module, metaclass=ABCMeta): + """Base class for temporal action proposal generator. - All localizers should subclass it. All subclass should overwrite: - Methods:``forward_train``, supporting to forward when training. - Methods:``forward_test``, supporting to forward when testing. + All temporal action proposal generator should subclass it. All subclass + should overwrite: Methods:``forward_train``, supporting to forward when + training. 
Methods:``forward_test``, supporting to forward when testing. + """ + + @abstractmethod + def forward_train(self, *args, **kwargs): + """Defines the computation performed at training.""" + + @abstractmethod + def forward_test(self, *args): + """Defines the computation performed at testing.""" + + @abstractmethod + def forward(self, *args, **kwargs): + """Define the computation performed at every call.""" + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. 
+ + Args: + data_batch (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self.forward(**data_batch) + + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(next(iter(data_batch.values())))) + + return outputs + + def val_step(self, data_batch, optimizer, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + results = self.forward(return_loss=False, **data_batch) + + outputs = dict(results=results) + + return outputs + + +class BaseTAGClassifier(nn.Module, metaclass=ABCMeta): + """Base class for temporal action proposal classifier. + + All temporal action generation classifier should subclass it. All subclass + should overwrite: Methods:``forward_train``, supporting to forward when + training. Methods:``forward_test``, supporting to forward when testing. 
""" def __init__(self, backbone, cls_head, train_cfg=None, test_cfg=None): @@ -42,19 +150,19 @@ def extract_feat(self, imgs): return x @abstractmethod - def forward_train(self, imgs, labels): + def forward_train(self, *args, **kwargs): """Defines the computation performed at training.""" @abstractmethod - def forward_test(self, imgs): + def forward_test(self, *args, **kwargs): """Defines the computation performed at testing.""" - def forward(self, imgs, return_loss=True, **kwargs): + def forward(self, *args, return_loss=True, **kwargs): """Define the computation performed at every call.""" if return_loss: - return self.forward_train(imgs, **kwargs) + return self.forward_train(*args, **kwargs) - return self.forward_test(imgs, **kwargs) + return self.forward_test(*args, **kwargs) @staticmethod def _parse_losses(losses): @@ -141,3 +249,13 @@ def val_step(self, data_batch, optimizer, **kwargs): outputs = dict(results=results) return outputs + + +class BaseLocalizer(BaseTAGClassifier): + """Deprecated class for ``BaseTAPGenerator`` and ``BaseTAGClassifier``.""" + + def __init__(self, *args, **kwargs): + warnings.warn('``BaseLocalizer`` is deprecated, please switch to ' '``BaseTAPGenerator`` or ``BaseTAGClassifier``. Details ' 'see https://github.com/open-mmlab/mmaction2/pull/913') + super().__init__(*args, **kwargs) diff --git a/mmaction/models/localizers/bmn.py b/mmaction/models/localizers/bmn.py index a0bbece0cd..cb9bdc4477 100644 --- a/mmaction/models/localizers/bmn.py +++ b/mmaction/models/localizers/bmn.py @@ -6,12 +6,12 @@ from ...localization import temporal_iop, temporal_iou from ..builder import LOCALIZERS, build_loss -from .base import BaseLocalizer +from .base import BaseTAPGenerator from .utils import post_processing @LOCALIZERS.register_module() -class BMN(BaseLocalizer): +class BMN(BaseTAPGenerator): """Boundary Matching Network for temporal action proposal generation. 
Please refer `BMN: Boundary-Matching Network for Temporal Action Proposal @@ -52,7 +52,7 @@ def __init__(self, hidden_dim_1d=256, hidden_dim_2d=128, hidden_dim_3d=512): - super(BaseLocalizer, self).__init__() + super().__init__() self.tscale = temporal_dim self.boundary_ratio = boundary_ratio diff --git a/mmaction/models/localizers/bsn.py b/mmaction/models/localizers/bsn.py index 83843002ff..e65f7ecf8c 100644 --- a/mmaction/models/localizers/bsn.py +++ b/mmaction/models/localizers/bsn.py @@ -5,13 +5,13 @@ from ...localization import temporal_iop from ..builder import LOCALIZERS, build_loss -from .base import BaseLocalizer +from .base import BaseTAPGenerator from .utils import post_processing @LOCALIZERS.register_module() -class TEM(BaseLocalizer): - """Temporal Evaluation Model for Boundary Sensetive Network. +class TEM(BaseTAPGenerator): + """Temporal Evaluation Model for Boundary Sensitive Network. Please refer `BSN: Boundary Sensitive Network for Temporal Action Proposal Generation `_. @@ -44,7 +44,7 @@ def __init__(self, conv1_ratio=1, conv2_ratio=1, conv3_ratio=0.01): - super(BaseLocalizer, self).__init__() + super().__init__() self.temporal_dim = temporal_dim self.boundary_ratio = boundary_ratio @@ -225,8 +225,8 @@ def forward(self, @LOCALIZERS.register_module() -class PEM(BaseLocalizer): - """Proposals Evaluation Model for Boundary Sensetive Network. +class PEM(BaseTAPGenerator): + """Proposals Evaluation Model for Boundary Sensitive Network. Please refer `BSN: Boundary Sensitive Network for Temporal Action Proposal Generation `_. @@ -268,7 +268,7 @@ def __init__(self, fc1_ratio=0.1, fc2_ratio=0.1, output_dim=1): - super(BaseLocalizer, self).__init__() + super().__init__() self.feat_dim = pem_feat_dim self.hidden_dim = pem_hidden_dim diff --git a/mmaction/models/localizers/ssn.py b/mmaction/models/localizers/ssn.py index 1284f694dd..32c0dedbcc 100644 --- a/mmaction/models/localizers/ssn.py +++ b/mmaction/models/localizers/ssn.py @@ -3,11 +3,11 @@ from .. 
import builder from ..builder import LOCALIZERS -from .base import BaseLocalizer +from .base import BaseTAGClassifier @LOCALIZERS.register_module() -class SSN(BaseLocalizer): +class SSN(BaseTAGClassifier): """Temporal Action Detection with Structured Segment Networks. Args: