Tutorial 3: Customize Models

We categorize model components into 5 types, the same as in MMDetection (mmdet):

  • backbone: usually an FCN network to extract feature maps, e.g., ResNet, MobileNet.
  • neck: the component between backbones and heads, e.g., FPN, PAFPN.
  • head: the component for specific tasks, e.g., bbox prediction and mask prediction.
  • roi extractor: the part for extracting RoI features from feature maps, e.g., RoI Align.
  • loss: the component in head for calculating losses, e.g., FocalLoss, L1Loss, and GHMLoss.

Develop new components

Add a new detector

Here we show how to develop new components with an example.

Add a new backbone

Here we show how to develop a new backbone, using MobileNet as an example.

1. Define a new backbone (e.g. MobileNet)

Create a new file mmfewshot/detection/models/backbones/mobilenet.py.

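import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        super().__init__()
        # build the layers from arg1 and arg2 here (omitted)

    def forward(self, x):  # should return a tuple
        # implementation omitted; return one feature map per output stage
        pass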
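The @BACKBONES.register_module() decorator is what lets a config refer to the class by name through type='MobileNet'. Below is a simplified, self-contained sketch of that mechanism, loosely modeled on a name-to-class registry. Registry, ToyBackbone and build are hypothetical stand-ins written for illustration. They are not the MMCV implementation, which does considerably more.

```python
class Registry:
    """Hypothetical, minimal name -> class registry (illustration only)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Returns a decorator that records the class under its own name.
        def _register(cls):
            if cls.__name__ in self._module_dict:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        # Copy the config so the caller's dict is not mutated, then pop 'type'
        # and pass the remaining keys as constructor keyword arguments.
        args = dict(cfg)
        obj_type = args.pop('type')
        if obj_type not in self._module_dict:
            raise KeyError(f'{obj_type} is not registered in {self.name}')
        return self._module_dict[obj_type](**args)


BACKBONES = Registry('backbone')


@BACKBONES.register_module()
class ToyBackbone:
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2

    def forward(self, x):
        # Backbones return a tuple of feature maps, one per output stage.
        return (x, x)


cfg = dict(type='ToyBackbone', arg1=1, arg2=2)
backbone = BACKBONES.build(cfg)
```

The decorator runs when the defining module is imported. A class that is never imported is therefore never registered, which is why the next step matters.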

2. Import the module

You can either add the following line to mmfewshot/detection/models/backbones/__init__.py

from .mobilenet import MobileNet

or alternatively add

custom_imports = dict(
    imports=['mmfewshot.detection.models.backbones.mobilenet'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.
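Both options work for the same reason: they make sure the module file is imported, so its register_module() decorator runs. The sketch below approximates what a custom_imports entry asks for. It imports each dotted module path, which must be a module name such as 'mmfewshot.detection.models.backbones.mobilenet' and not a file path ending in .py. When allow_failed_imports is True, it warns instead of raising. import_custom_modules is a hypothetical helper built on the standard library, not the function MMCV actually uses.

```python
import importlib
import warnings


def import_custom_modules(imports, allow_failed_imports=False):
    """Import each dotted module path so its registration decorators run."""
    if isinstance(imports, str):
        imports = [imports]
    imported = []
    for name in imports:
        try:
            module = importlib.import_module(name)
        except ImportError:
            if not allow_failed_imports:
                raise
            # Tolerated failure: warn and keep a placeholder for this entry.
            warnings.warn(f'{name} failed to import and is ignored.')
            module = None
        imported.append(module)
    return imported


modules = import_custom_modules(['json', 'collections'])
```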

3. Use the backbone in your config file

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...)

Add new necks

1. Define a neck (e.g. PAFPN)

Create a new file mmfewshot/detection/models/necks/pafpn.py.

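import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False):
        super().__init__()
        # build the lateral, top-down and bottom-up layers here (omitted)

    def forward(self, inputs):
        # implementation is ignored; should return a tuple of num_outs feature maps
        pass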

2. Import the module

You can either add the following line to mmfewshot/detection/models/necks/__init__.py

from .pafpn import PAFPN

or alternatively add

custom_imports = dict(
    imports=['mmfewshot.detection.models.necks.pafpn'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Modify the config file

neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
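In this config, the neck receives four backbone levels (one per in_channels entry) but is asked for num_outs=5 outputs. Assume the neck follows MMDetection's FPN convention with end_level=-1. In that case, outputs beyond the backbone levels are extra levels stacked on top of the coarsest one, produced by extra convs or pooling depending on add_extra_convs. The hypothetical helper below only does that bookkeeping arithmetic.

```python
def fpn_level_split(in_channels, num_outs, start_level=0):
    """Return (levels taken from the backbone, extra levels added on top).

    Assumes end_level=-1, i.e. every backbone level from start_level is used.
    """
    num_ins = len(in_channels)
    num_backbone_levels = num_ins - start_level
    if num_outs < num_backbone_levels:
        raise ValueError('num_outs must cover every used backbone level')
    extra_levels = num_outs - num_backbone_levels
    return num_backbone_levels, extra_levels


# The config above: four backbone levels, five outputs requested.
split = fpn_level_split([256, 512, 1024, 2048], num_outs=5)
```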

Add new heads

Here we show how to develop a new head, using Double Head R-CNN as an example.

First, add a new bbox head in mmfewshot/detection/models/roi_heads/bbox_heads/double_bbox_head.py. Double Head R-CNN implements a new bbox head for object detection. To implement a bbox head, we mainly need to implement the __init__ and forward functions of the new module, as follows.

from mmdet.models.builder import HEADS
from mmdet.models.roi_heads.bbox_heads import BBoxHead


@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN

                                      /-> cls
                  /-> shared convs ->
                                      \-> reg
    roi features
                                      /-> cls
                  \-> shared fc    ->
                                      \-> reg
    """  # noqa: W605

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
        # build the conv branch and the fc branch from the arguments above (omitted)

    def forward(self, x_cls, x_reg):
        # x_cls feeds the fc branch and x_reg feeds the conv branch;
        # implementation omitted, should return (cls_score, bbox_pred)
        pass

Second, implement a new RoI head if necessary. We make the new DoubleHeadRoIHead inherit from MMDetection's StandardRoIHead, which already implements the following functions (bodies omitted):

import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head.
    """
    # method bodies are omitted; only the signatures are shown

    def init_assigner_sampler(self):
        ...

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        ...

    def init_mask_head(self, mask_roi_extractor, mask_head):
        ...

    def forward_dummy(self, x, proposals):
        ...

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        ...

    def _bbox_forward(self, x, rois):
        ...

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        ...

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        ...

    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
        ...

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation."""

Double Head mainly modifies the _bbox_forward logic and inherits all other logic from StandardRoIHead. In mmfewshot/detection/models/roi_heads/double_roi_head.py, we implement the new RoI head as follows:

from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import StandardRoIHead


@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN

    https://arxiv.org/abs/1904.06493
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        bbox_cls_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        bbox_reg_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs],
            rois,
            roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)
        return bbox_results
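reg_roi_scale_factor enlarges each RoI before the regression branch pools its features, so that branch sees more context around the object. The classification branch pools from the original RoI. The sketch below shows center-preserving rescaling in plain Python. It assumes RoIs use the (batch_idx, x1, y1, x2, y2) layout produced by bbox2roi. rescale_roi is a hypothetical helper, and the RoI extractor performs the equivalent operation on tensors.

```python
def rescale_roi(roi, scale_factor):
    """Scale an RoI (batch_idx, x1, y1, x2, y2) about its center."""
    batch_idx, x1, y1, x2, y2 = roi
    cx = (x1 + x2) * 0.5
    cy = (y1 + y2) * 0.5
    new_w = (x2 - x1) * scale_factor
    new_h = (y2 - y1) * scale_factor
    # Keep the center fixed and grow (or shrink) the width and height.
    return (batch_idx,
            cx - new_w * 0.5, cy - new_h * 0.5,
            cx + new_w * 0.5, cy + new_h * 0.5)


# With reg_roi_scale_factor=1.3, a 100x50 box grows to about 130x65 around the same center.
grown = rescale_roi((0, 100.0, 100.0, 200.0, 150.0), 1.3)
```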

Finally, add the modules to mmfewshot/detection/models/roi_heads/bbox_heads/__init__.py and mmfewshot/detection/models/roi_heads/__init__.py so that the corresponding registries can find and load them.

Alternatively, you can add

custom_imports=dict(
    imports=['mmfewshot.detection.models.roi_heads.double_roi_head',
             'mmfewshot.detection.models.roi_heads.bbox_heads.double_bbox_head'])

to the config file and achieve the same goal.

The config file of Double Head R-CNN is as follows:

_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))

Since MMDetection 2.0, the config system supports config inheritance, so users can focus on the modifications. Double Head R-CNN mainly uses a new DoubleHeadRoIHead and a new DoubleConvFCBBoxHead, and their arguments are set according to the __init__ function of each module. Setting _delete_=True in bbox_head replaces the inherited bbox_head dict entirely instead of merging into it.
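The sketch below shows how an inherited config is combined with the child's overrides. Child keys override base keys, nested dicts are merged recursively, and a dict marked _delete_=True replaces its inherited counterpart instead of being merged. merge_cfg is a hypothetical, simplified helper written for illustration. MMCV's Config class also loads _base_ files and handles further cases that this sketch does not.

```python
import copy


def merge_cfg(base, child):
    """Merge a child config dict into a copy of its base config."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)  # shallow copy so the marker can be popped
            delete = value.pop('_delete_', False)
            if not delete and isinstance(merged.get(key), dict):
                merged[key] = merge_cfg(merged[key], value)
            else:
                # _delete_=True (or nothing to merge into): use the child's dict only.
                merged[key] = merge_cfg({}, value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Simplified stand-in for the inherited Faster R-CNN settings.
base = dict(roi_head=dict(
    type='StandardRoIHead',
    bbox_head=dict(type='Shared2FCBBoxHead', fc_out_channels=1024, num_classes=80)))
child = dict(roi_head=dict(
    type='DoubleHeadRoIHead',
    reg_roi_scale_factor=1.3,
    bbox_head=dict(_delete_=True, type='DoubleConvFCBBoxHead', num_convs=4)))
merged = merge_cfg(base, child)
```

Without _delete_=True, keys of the inherited bbox_head, such as fc_out_channels, would survive the merge even though DoubleConvFCBBoxHead is meant to receive its own argument set.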

Add new loss

Assume you want to add a new loss, MyLoss, for bounding box regression. To add a new loss function, implement it in mmfewshot/detection/models/losses/my_loss.py. The weighted_loss decorator enables the loss to be weighted element-wise.

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
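my_loss only computes the element-wise |pred - target|. The weighted_loss decorator then handles the weight, reduction and avg_factor arguments. The plain-Python sketch below shows that weighting-and-reduction step on lists. It assumes the MMDetection convention: multiply by the per-element weight, then apply 'none', 'sum' or 'mean', where a given avg_factor replaces the element count for 'mean'. weight_reduce is a hypothetical name, and details such as numerical epsilons may differ between library versions.

```python
def weight_reduce(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weights, then reduce a list of per-element losses."""
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f'unsupported reduction: {reduction}')
    if weight is not None:
        loss = [x * w for x, w in zip(loss, weight)]
    if reduction == 'none':
        return loss
    if avg_factor is None:
        return sum(loss) / len(loss) if reduction == 'mean' else sum(loss)
    if reduction == 'mean':
        # avg_factor (e.g. the number of positive samples) replaces len(loss).
        return sum(loss) / avg_factor
    raise ValueError('avg_factor cannot be used with reduction="sum"')


# Element-wise L1 loss, as my_loss computes it before weighting.
pred, target = [1.0, 2.0, 4.0], [1.5, 2.0, 1.0]
per_elem = [abs(p - t) for p, t in zip(pred, target)]  # [0.5, 0.0, 3.0]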

Then add it to mmfewshot/detection/models/losses/__init__.py:

from .my_loss import MyLoss, my_loss

Alternatively, you can add

custom_imports=dict(
    imports=['mmfewshot.detection.models.losses.my_loss'])

to the config file and achieve the same goal.

To use it, modify the loss_xxx field. Since MyLoss is for regression, you need to modify the loss_bbox field in the head.

loss_bbox=dict(type='MyLoss', loss_weight=1.0)

Customize frozen parameters

We support frozen_parameters, which freezes parameters during training by their name prefix. For example, to freeze only the shared_fcs of the bbox_head in roi_head, add roi_head.bbox_head.shared_fcs to the frozen_parameters list. The config below additionally freezes the backbone, the neck and the RPN head:

model = dict(
    frozen_parameters=[
        'backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'
    ])
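The sketch below shows the matching rule described above: a parameter is frozen when its name starts with any listed prefix. The parameter names are hypothetical examples. The library's actual matching rule may differ, for example substring matching instead of prefix matching. With a real PyTorch model you would loop over model.named_parameters() and set requires_grad = False on the frozen ones.

```python
def split_frozen(param_names, frozen_prefixes):
    """Partition parameter names into (frozen, trainable) by name prefix."""
    frozen, trainable = [], []
    for name in param_names:
        if any(name.startswith(prefix) for prefix in frozen_prefixes):
            frozen.append(name)
        else:
            trainable.append(name)
    return frozen, trainable


# Hypothetical parameter names of a two-stage detector.
names = [
    'backbone.conv1.weight',
    'neck.lateral_convs.0.conv.weight',
    'rpn_head.rpn_conv.weight',
    'roi_head.bbox_head.shared_fcs.0.weight',
    'roi_head.bbox_head.fc_cls.weight',
    'roi_head.bbox_head.fc_reg.weight',
]
frozen, trainable = split_frozen(
    names, ['backbone', 'neck', 'rpn_head', 'roi_head.bbox_head.shared_fcs'])
```

With this config, only the final classification and regression layers stay trainable, which is a common setup when fine-tuning on few-shot data.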

Customize a query-support based detector

Here we show how to develop a new query-support based detector with an example.

1. Define a new detector

Create a new file mmfewshot/detection/models/detectors/my_detector.py.

from mmdet.models.builder import DETECTORS

from .query_support_detector import QuerySupportDetector


@DETECTORS.register_module()
class MyDetector(QuerySupportDetector):
    # customize the input data
    def forward(self,
                query_data=None,
                support_data=None,
                img=None,
                img_metas=None,
                mode='train',
                **kwargs):
        if mode == 'train':
            return self.forward_train(query_data, support_data, **kwargs)
        elif mode == 'model_init':
            return self.forward_model_init(img, img_metas, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, img_metas, **kwargs)
        else:
            raise ValueError(f'invalid forward mode: {mode}')

    def forward_train(self, query_data, support_data, proposals=None, **kwargs):
        ...

    # before testing, the model forwards the whole support set;
    # customize the forward logic and save all the needed information
    def forward_model_init(self, img, img_metas, gt_bboxes=None, gt_labels=None, **kwargs):
        ...

    # customize how the saved support information is processed
    def model_init(self, **kwargs):
        ...
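The test-time flow above has two stages. forward_model_init only stores information from each support batch, and model_init turns everything stored into what forward_test needs. The class below is a hypothetical, framework-free sketch of that pattern. We assume the stored information is one feature vector per support instance and that model_init averages them per class, which is one common choice.

```python
from collections import defaultdict


class SupportFeatureCache:
    """Hypothetical helper: collect support features, then average them per class."""

    def __init__(self):
        self._feats = defaultdict(list)
        self.class_prototypes = {}

    def forward_model_init(self, feats, labels):
        # Called once per support batch before testing; only stores features.
        for feat, label in zip(feats, labels):
            self._feats[label].append(feat)

    def model_init(self):
        # Called after the whole support set has been seen: reduce the stored
        # features to one averaged vector per class for test-time reuse.
        for label, feats in self._feats.items():
            dim = len(feats[0])
            self.class_prototypes[label] = [
                sum(f[i] for f in feats) / len(feats) for i in range(dim)]
        self._feats.clear()


cache = SupportFeatureCache()
cache.forward_model_init([[1.0, 2.0], [3.0, 4.0]], [0, 0])
cache.forward_model_init([[5.0, 5.0]], [1])
cache.model_init()
```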

2. Import the module

You can either add the following line to mmfewshot/detection/models/detectors/__init__.py

from .my_detector import MyDetector

or alternatively add

custom_imports = dict(
    imports=['mmfewshot.detection.models.detectors.my_detector'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Use the detector in your config file

model = dict(
    type='MyDetector',
    ...)

Customize an aggregation layer

We also support reusing the feature-fusion code that query-support based methods commonly use to combine features from different inputs (e.g., query and support images). Here we show how to develop a new aggregator with an example.

1. Define a new aggregator

Add your custom code to mmfewshot/detection/models/utils/aggregation_layer.py.

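@AGGREGATORS.register_module()
class MyAggregator(BaseModule):

    def __init__(self, init_cfg=None):
        # add your own arguments here
        super().__init__(init_cfg=init_cfg)

    def forward(self, query_feat, support_feat):
        # fuse query_feat with support_feat (implementation omitted)
        feat = ...
        return feat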
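As an illustration of what MyAggregator.forward might compute, the sketch below uses channel-wise multiplication, which is one common way to fuse query and support features. Each query channel is reweighted by the matching value of a pooled support vector. channel_mul_aggregate is hypothetical and uses nested lists in place of tensors. With tensors, the equivalent operation would be a broadcast multiplication of query_feat by support_feat.

```python
def channel_mul_aggregate(query_feat, support_feat):
    """Reweight each query channel by the matching support value.

    query_feat:   C x H x W nested lists (one query feature map)
    support_feat: length-C list (one pooled support vector)
    """
    if len(query_feat) != len(support_feat):
        raise ValueError('query and support must have the same number of channels')
    return [
        [[value * weight for value in row] for row in channel]
        for channel, weight in zip(query_feat, support_feat)
    ]


query = [[[1.0, 2.0]], [[3.0, 4.0]]]  # 2 channels, 1 x 2 spatial
support = [2.0, 0.5]
fused = channel_mul_aggregate(query, support)
```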

2. Use the aggregator in your config file

The aggregation_layer can be built from a single aggregator:

aggregation_layer = dict(type='MyAggregator',...)

or from multiple aggregators wrapped by an AggregationLayer:

aggregation_layer = dict(
    type='AggregationLayer',
    aggregator_cfgs=[
        dict(type='MyAggregator', ...),
        ...
    ])
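The wrapper applies every configured aggregator to the same query and support features. The sketch below assumes it returns one output per aggregator, in config order, and leaves it to the caller to decide how to combine them. ToyAggregationLayer and the two lambda aggregators are hypothetical stand-ins on plain lists.

```python
class ToyAggregationLayer:
    """Apply several aggregators to the same inputs and collect their outputs."""

    def __init__(self, aggregators):
        self.aggregators = list(aggregators)

    def __call__(self, query_feat, support_feat):
        # One output per aggregator, in the order they were configured.
        return [agg(query_feat, support_feat) for agg in self.aggregators]


layer = ToyAggregationLayer([
    lambda q, s: [a * b for a, b in zip(q, s)],  # element-wise product
    lambda q, s: [a - b for a, b in zip(q, s)],  # element-wise difference
])
outputs = layer([1.0, 2.0], [3.0, 4.0])
```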

3. Use the aggregator in your model

import copy

from mmdet.models.builder import HEADS

from mmfewshot.detection.models.utils import build_aggregator


@HEADS.register_module()
class MyHead(...):
    def __init__(self, ..., aggregation_layer):
        # deepcopy so that building does not mutate the shared config dict
        self.aggregation_layer = build_aggregator(copy.deepcopy(aggregation_layer))

    def forward_train(self, ...):
        ...
        self.aggregation_layer(query_feat=..., support_feat=...)
        ...