Add init_cfg for dense heads. #37

Merged · 5 commits · Mar 1, 2022
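This PR replaces the hand-written `init_weights()` overrides in the dense heads with MMCV's declarative `init_cfg` mechanism: each head now describes its initialization as a config dict and lets `BaseModule.init_weights()` apply it. A minimal sketch of the before/after styles (the toy head below is illustrative only, assuming mmcv 1.x `BaseModule` semantics; it is not code from this PR):

```python
import torch.nn as nn
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import BaseModule


class ToyHeadOld(nn.Module):
    """Old style: imperative init, duplicated in every head."""

    def __init__(self):
        super().__init__()
        self.cls_conv = nn.Conv2d(256, 15, 3, padding=1)

    def init_weights(self):
        # Each head repeats calls like this and must be invoked by hand.
        normal_init(self.cls_conv, std=0.01, bias=bias_init_with_prob(0.01))


class ToyHeadNew(BaseModule):
    """New style: declarative init_cfg, applied by BaseModule.init_weights()."""

    def __init__(self,
                 init_cfg=dict(
                     type='Normal', layer='Conv2d', std=0.01,
                     override=dict(type='Normal', name='cls_conv',
                                   std=0.01, bias_prob=0.01))):
        super().__init__(init_cfg)
        self.cls_conv = nn.Conv2d(256, 15, 3, padding=1)


head = ToyHeadNew()
head.init_weights()  # BaseModule reads init_cfg; no per-head override needed
```

The practical win is that initialization becomes overridable from config files and is applied recursively by a single top-level call, instead of being duplicated and manually invoked per module.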
1 change: 0 additions & 1 deletion mmrotate/datasets/dota.py
@@ -202,7 +202,6 @@ def evaluate(self,
                 scale_ranges=scale_ranges,
                 iou_thr=iou_thr,
                 dataset=self.CLASSES,
-                version=self.version,
                 logger=logger,
                 nproc=nproc)
             eval_results['mAP'] = mean_ap
25 changes: 12 additions & 13 deletions mmrotate/models/dense_heads/kfiou_odm_refine_head.py
@@ -1,6 +1,6 @@
 # Copyright (c) SJTU. All rights reserved.
 import torch.nn as nn
-from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.cnn import ConvModule
 from mmcv.runner import force_fp32
 
 from ..builder import ROTATED_HEADS
@@ -28,6 +28,7 @@ class KFIoUODMRefineHead(KFIoURRetinaHead):
         loss_bbox (dict): Config of localization loss.
         train_cfg (dict): Training config of anchor head.
         test_cfg (dict): Testing config of anchor head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """  # noqa: W605
 
     def __init__(self,
@@ -39,6 +40,15 @@ def __init__(self,
                  anchor_generator=dict(
                      type='PseudoAnchorGenerator',
                      strides=[8, 16, 32, 64, 128]),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='odm_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
                  **kwargs):
         self.bboxes_as_anchors = None
         self.stacked_convs = stacked_convs
@@ -49,6 +59,7 @@ def __init__(self,
             in_channels,
             stacked_convs=2,
             anchor_generator=anchor_generator,
+            init_cfg=init_cfg,
             **kwargs)
 
     def _init_layers(self):
@@ -91,18 +102,6 @@ def _init_layers(self):
         self.odm_reg = nn.Conv2d(
             self.feat_channels, self.num_anchors * 5, 3, padding=1)
 
-    def init_weights(self):
-        """Initialize weights of the head."""
-
-        normal_init(self.or_conv, std=0.01)
-        for m in self.cls_convs:
-            normal_init(m.conv, std=0.01)
-        for m in self.reg_convs:
-            normal_init(m.conv, std=0.01)
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.odm_cls, std=0.01, bias=bias_cls)
-        normal_init(self.odm_reg, std=0.01)
-
     def forward_single(self, x):
         """Forward feature of a single scale level.
 
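The deleted method set the classification bias via `bias_init_with_prob(0.01)` so the head starts with a roughly 1% positive prior; the new `override=dict(name='odm_cls', ..., bias_prob=0.01)` entry expresses the same intent declaratively. A quick sanity check of that equivalence (the helper below mirrors what `mmcv.cnn.bias_init_with_prob` computes, to my understanding):

```python
import math

def bias_init_with_prob(prior_prob):
    """Bias b such that sigmoid(b) == prior_prob."""
    return float(-math.log((1 - prior_prob) / prior_prob))

b = bias_init_with_prob(0.01)
print(b)                       # ~ -4.595
print(1 / (1 + math.exp(-b)))  # ~ 0.01, the intended positive prior
```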
11 changes: 11 additions & 0 deletions mmrotate/models/dense_heads/kfiou_rotate_retina_head.py
@@ -23,6 +23,7 @@ class KFIoURRetinaHead(RotatedRetinaHead):
         loss_bbox (dict): Config of localization loss.
         train_cfg (dict): Training config of anchor head.
         test_cfg (dict): Testing config of anchor head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """  # noqa: W605
 
     def __init__(self,
@@ -37,6 +38,15 @@ def __init__(self,
                      scales_per_octave=3,
                      ratios=[0.5, 1.0, 2.0],
                      strides=[8, 16, 32, 64, 128]),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='retina_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
                  **kwargs):
         self.bboxes_as_anchors = None
         super(KFIoURRetinaHead, self).__init__(
@@ -46,6 +56,7 @@ def __init__(self,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
             anchor_generator=anchor_generator,
+            init_cfg=init_cfg,
             **kwargs)
 
     def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
11 changes: 11 additions & 0 deletions mmrotate/models/dense_heads/kfiou_rotate_retina_refine_head.py
@@ -26,6 +26,7 @@ class KFIoURRetinaRefineHead(KFIoURRetinaHead):
         loss_bbox (dict): Config of localization loss.
         train_cfg (dict): Training config of anchor head.
         test_cfg (dict): Testing config of anchor head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """  # noqa: W605
 
     def __init__(self,
@@ -41,6 +42,15 @@ def __init__(self,
                      type='DeltaXYWHABBoxCoder',
                      target_means=(.0, .0, .0, .0, .0),
                      target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='retina_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
                  **kwargs):
 
         self.bboxes_as_anchors = None
@@ -52,6 +62,7 @@ def __init__(self,
             norm_cfg=norm_cfg,
             anchor_generator=anchor_generator,
             bbox_coder=bbox_coder,
+            init_cfg=init_cfg,
             **kwargs)
 
     @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
25 changes: 12 additions & 13 deletions mmrotate/models/dense_heads/odm_refine_head.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn
-from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.cnn import ConvModule
 from mmcv.runner import force_fp32
 
 from ..builder import ROTATED_HEADS
@@ -28,6 +28,7 @@ class ODMRefineHead(RotatedRetinaHead):
         loss_bbox (dict): Config of localization loss.
         train_cfg (dict): Training config of anchor head.
         test_cfg (dict): Testing config of anchor head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """  # noqa: W605
 
     def __init__(self,
@@ -39,6 +40,15 @@ def __init__(self,
                  anchor_generator=dict(
                      type='PseudoAnchorGenerator',
                      strides=[8, 16, 32, 64, 128]),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='odm_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
                  **kwargs):
         self.bboxes_as_anchors = None
         self.stacked_convs = stacked_convs
@@ -49,6 +59,7 @@ def __init__(self,
             in_channels,
             stacked_convs=2,
             anchor_generator=anchor_generator,
+            init_cfg=init_cfg,
             **kwargs)
 
     def _init_layers(self):
@@ -91,18 +102,6 @@ def _init_layers(self):
         self.odm_reg = nn.Conv2d(
             self.feat_channels, self.num_anchors * 5, 3, padding=1)
 
-    def init_weights(self):
-        """Initialize weights of the head."""
-
-        normal_init(self.or_conv, std=0.01)
-        for m in self.cls_convs:
-            normal_init(m.conv, std=0.01)
-        for m in self.reg_convs:
-            normal_init(m.conv, std=0.01)
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.odm_cls, std=0.01, bias=bias_cls)
-        normal_init(self.odm_reg, std=0.01)
-
     def forward_single(self, x):
         """Forward feature of a single scale level.
 
12 changes: 4 additions & 8 deletions mmrotate/models/dense_heads/rotated_anchor_head.py
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import normal_init
 from mmcv.runner import force_fp32
 from mmdet.core import images_to_levels, multi_apply, unmap
 from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
@@ -33,6 +32,7 @@ class RotatedAnchorHead(BaseDenseHead):
         loss_bbox (dict): Config of localization loss.
         train_cfg (dict): Training config of anchor head.
         test_cfg (dict): Testing config of anchor head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """  # noqa: W605
 
     def __init__(self,
@@ -59,8 +59,9 @@ def __init__(self,
                      loss_weight=1.0),
                  loss_bbox=dict(type='L1Loss', loss_weight=1.0),
                  train_cfg=None,
-                 test_cfg=None):
-        super(RotatedAnchorHead, self).__init__()
+                 test_cfg=None,
+                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
+        super(RotatedAnchorHead, self).__init__(init_cfg)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.feat_channels = feat_channels
@@ -105,11 +106,6 @@ def _init_layers(self):
                                   self.num_anchors * self.cls_out_channels, 1)
         self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 5, 1)
 
-    def init_weights(self):
-        """Initialize weights of the head."""
-        normal_init(self.conv_cls, std=0.01)
-        normal_init(self.conv_reg, std=0.01)
-
     def forward_single(self, x):
         """Forward feature of a single scale level.
 
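Since `RotatedAnchorHead` now forwards `init_cfg` to `BaseDenseHead` (an MMCV `BaseModule`), the plain conv head gets the simple `Normal, layer='Conv2d', std=0.01` default and subclasses override it with their own dicts. A hedged usage sketch (assumes mmrotate exposes `build_head` from `mmrotate.models` and that the two required arguments below suffice; the values are illustrative):

```python
from mmrotate.models import build_head

# Build a registered head from config; init_cfg comes from the defaults
# added in this PR unless the config overrides it.
head = build_head(dict(
    type='RotatedRetinaHead',
    num_classes=15,
    in_channels=256))

# One call initializes every child module according to init_cfg;
# the deleted per-head normal_init() calls are no longer needed.
head.init_weights()
```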
22 changes: 11 additions & 11 deletions mmrotate/models/dense_heads/rotated_retina_head.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn
-from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.cnn import ConvModule
 from mmcv.runner import force_fp32
 
 from ..builder import ROTATED_HEADS
@@ -39,6 +39,15 @@ def __init__(self,
                      scales_per_octave=3,
                      ratios=[0.5, 1.0, 2.0],
                      strides=[8, 16, 32, 64, 128]),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='retina_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
                  **kwargs):
         self.stacked_convs = stacked_convs
         self.conv_cfg = conv_cfg
@@ -47,6 +56,7 @@ def __init__(self,
             num_classes,
             in_channels,
             anchor_generator=anchor_generator,
+            init_cfg=init_cfg,
             **kwargs)
 
     def _init_layers(self):
@@ -82,16 +92,6 @@ def _init_layers(self):
         self.retina_reg = nn.Conv2d(
             self.feat_channels, self.num_anchors * 5, 3, padding=1)
 
-    def init_weights(self):
-        """Initialize weights of the head."""
-        for m in self.cls_convs:
-            normal_init(m.conv, std=0.01)
-        for m in self.reg_convs:
-            normal_init(m.conv, std=0.01)
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
-        normal_init(self.retina_reg, std=0.01)
-
     def forward_single(self, x):
         """Forward feature of a single scale level.
 
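For this head the default `init_cfg` covers in one dict what the deleted method did imperatively: the outer entry reaches every `Conv2d` (the stacked `cls_convs`/`reg_convs` and `retina_reg`), while `override` singles out `retina_cls` for the extra bias term. A rough mapping, with the old calls as comments (my reading of the semantics, not text from the PR):

```python
# Declarative init (names from this diff) vs. the removed imperative code:
init_cfg = dict(
    type='Normal',
    layer='Conv2d',  # was: normal_init(m.conv, std=0.01) over cls/reg convs
    std=0.01,        #      and normal_init(self.retina_reg, std=0.01)
    override=dict(
        type='Normal',
        name='retina_cls',  # was: normal_init(self.retina_cls, std=0.01,
        std=0.01,           #          bias=bias_init_with_prob(0.01))
        bias_prob=0.01))
```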
10 changes: 10 additions & 0 deletions mmrotate/models/dense_heads/rotated_retina_refine_head.py
@@ -41,6 +41,15 @@ def __init__(self,
                      type='DeltaXYWHABBoxCoder',
                      target_means=(.0, .0, .0, .0, .0),
                      target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='retina_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
                  **kwargs):
 
         self.bboxes_as_anchors = None
@@ -52,6 +61,7 @@ def __init__(self,
             norm_cfg=norm_cfg,
             anchor_generator=anchor_generator,
             bbox_coder=bbox_coder,
+            init_cfg=init_cfg,
             **kwargs)
 
     @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
23 changes: 12 additions & 11 deletions mmrotate/models/detectors/r3det.py
@@ -1,5 +1,7 @@
 # Copyright (c) SJTU. All rights reserved.
-import torch.nn as nn
+import warnings
+
+from mmcv.runner import ModuleList
 
 from mmrotate.core import rbbox2result
 from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck
@@ -20,10 +22,13 @@ def __init__(self,
                  refine_heads=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
-        super(R3Det, self).__init__()
-
-        backbone.pretrained = pretrained
+                 pretrained=None,
+                 init_cfg=None):
+        super(R3Det, self).__init__(init_cfg)
+        if pretrained:
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            backbone.pretrained = pretrained
         self.backbone = build_backbone(backbone)
         self.num_refine_stages = num_refine_stages
         if neck is not None:
@@ -32,19 +37,15 @@ def __init__(self,
             bbox_head.update(train_cfg=train_cfg['s0'])
             bbox_head.update(test_cfg=test_cfg)
         self.bbox_head = build_head(bbox_head)
-        self.bbox_head.init_weights()
-        self.feat_refine_module = nn.ModuleList()
-        self.refine_head = nn.ModuleList()
+        self.feat_refine_module = ModuleList()
+        self.refine_head = ModuleList()
         for i, (frm_cfg,
                 refine_head) in enumerate(zip(frm_cfgs, refine_heads)):
            self.feat_refine_module.append(FeatureRefineModule(**frm_cfg))
            if train_cfg is not None:
                refine_head.update(train_cfg=train_cfg['sr'][i])
            refine_head.update(test_cfg=test_cfg)
            self.refine_head.append(build_head(refine_head))
-        for i in range(self.num_refine_stages):
-            self.feat_refine_module[i].init_weights()
-            self.refine_head[i].init_weights()
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
 
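`R3Det` now accepts `init_cfg`, only warns about the legacy `pretrained` argument, and switches to `mmcv.runner.ModuleList` so the refine modules participate in recursive initialization. Under the MMCV convention this migrates toward, backbone pretraining is also expressed through `init_cfg`; a sketch (the checkpoint string is a placeholder, not taken from this PR):

```python
# Old style (deprecated by this PR): pretrained='torchvision://resnet50'
# passed to the detector. New style, assuming the standard MMCV
# 'Pretrained' initializer:
backbone = dict(
    type='ResNet',
    depth=50,
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
```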
1 change: 0 additions & 1 deletion mmrotate/models/detectors/rotated_reppoints.py
@@ -16,4 +16,3 @@ def __init__(self,
                  pretrained=None):
         super(RotatedRepPoints, self).__init__(backbone, neck, bbox_head,
                                                train_cfg, test_cfg, pretrained)
-        self.bbox_head.init_weights()
9 changes: 5 additions & 4 deletions mmrotate/models/detectors/rotated_retinanet.py
@@ -16,7 +16,8 @@ def __init__(self,
                  bbox_head,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
-        super(RotatedRetinaNet, self).__init__(backbone, neck, bbox_head,
-                                               train_cfg, test_cfg, pretrained)
-        self.bbox_head.init_weights()
+                 pretrained=None,
+                 init_cfg=None):
+        super(RotatedRetinaNet,
+              self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+                             pretrained, init_cfg)
2 changes: 0 additions & 2 deletions mmrotate/models/detectors/s2anet.py
@@ -28,7 +28,6 @@ def __init__(self,
             fam_head.update(train_cfg=train_cfg['fam_cfg'])
         fam_head.update(test_cfg=test_cfg)
         self.fam_head = build_head(fam_head)
-        self.fam_head.init_weights()
 
         self.align_conv_type = align_cfgs['type']
         self.align_conv_size = align_cfgs['kernel_size']
@@ -44,7 +43,6 @@ def __init__(self,
             odm_head.update(train_cfg=train_cfg['odm_cfg'])
         odm_head.update(test_cfg=test_cfg)
         self.odm_head = build_head(odm_head)
-        self.odm_head.init_weights()
 
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
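The explicit `fam_head`/`odm_head` `init_weights()` calls can go because, once every component is a `BaseModule`, one `init_weights()` call on the detector recurses into all children. A sketch of the intended call pattern (assumes `build_detector` is exposed from `mmrotate.models`; the config path is a hypothetical placeholder):

```python
from mmcv import Config
from mmrotate.models import build_detector

cfg = Config.fromfile('configs/s2anet/my_s2anet_config.py')  # placeholder path
model = build_detector(cfg.model)
model.init_weights()  # single top-level call; BaseModule walks into
                      # fam_head, odm_head and the other submodules
```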
2 changes: 1 addition & 1 deletion requirements/build.txt
@@ -1,3 +1,3 @@
-# These must be installed before building mmdetection
+# These must be installed before building mmrotate
 cython
 numpy