From 35fcaef4a62ba0895609158118630627ba764a4b Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Sat, 26 Feb 2022 14:06:27 +0800 Subject: [PATCH 1/4] switch to original pycocotools on Windows --- requirements/runtime.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 5f73aec4f..b86343c3c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -3,8 +3,8 @@ matplotlib mmcv-full mmdet numpy -pycocotools; platform_system == "Linux" -pycocotools-windows; platform_system == "Windows" +opencv-python +pycocotools six terminaltables torch From 6e033d7a35d4aa6873aa6f3338153ffa62549b02 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Sat, 26 Feb 2022 14:07:13 +0800 Subject: [PATCH 2/4] change name --- requirements/build.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/build.txt b/requirements/build.txt index 815582985..05c9412e0 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -1,3 +1,3 @@ -# These must be installed before building mmdetection +# These must be installed before building mmrotate cython numpy From 492e425cf74eb621e0c89c86bb8d393839aca5d7 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Sat, 26 Feb 2022 16:39:14 +0800 Subject: [PATCH 3/4] add init_cfg for dense _heads --- mmrotate/datasets/dota.py | 1 - .../dense_heads/kfiou_odm_refine_head.py | 25 +++++++++---------- .../dense_heads/kfiou_rotate_retina_head.py | 11 ++++++++ .../kfiou_rotate_retina_refine_head.py | 11 ++++++++ .../models/dense_heads/odm_refine_head.py | 25 +++++++++---------- .../models/dense_heads/rotated_anchor_head.py | 12 +++------ .../models/dense_heads/rotated_retina_head.py | 22 ++++++++-------- .../dense_heads/rotated_retina_refine_head.py | 10 ++++++++ mmrotate/models/detectors/r3det.py | 23 +++++++++-------- .../models/detectors/rotated_reppoints.py | 1 - .../models/detectors/rotated_retinanet.py | 9 ++++--- mmrotate/models/detectors/s2anet.py | 2 -- 12 files changed, 88 insertions(+), 64 deletions(-) diff --git a/mmrotate/datasets/dota.py b/mmrotate/datasets/dota.py index 35e327ebf..d05fb3f17 100644 --- a/mmrotate/datasets/dota.py +++ b/mmrotate/datasets/dota.py @@ -202,7 +202,6 @@ def evaluate(self, scale_ranges=scale_ranges, iou_thr=iou_thr, dataset=self.CLASSES, - version=self.version, logger=logger, nproc=nproc) eval_results['mAP'] = mean_ap diff --git a/mmrotate/models/dense_heads/kfiou_odm_refine_head.py b/mmrotate/models/dense_heads/kfiou_odm_refine_head.py index 4045b9b72..3161f7247 100644 --- a/mmrotate/models/dense_heads/kfiou_odm_refine_head.py +++ b/mmrotate/models/dense_heads/kfiou_odm_refine_head.py @@ -1,6 +1,6 @@ # Copyright (c) SJTU. All rights reserved. import torch.nn as nn -from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule from mmcv.runner import force_fp32 from ..builder import ROTATED_HEADS @@ -28,6 +28,7 @@ class KFIoUODMRefineHead(KFIoURRetinaHead): loss_bbox (dict): Config of localization loss. train_cfg (dict): Training config of anchor head. test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. """ # noqa: W605 def __init__(self, @@ -39,6 +40,15 @@ def __init__(self, anchor_generator=dict( type='PseudoAnchorGenerator', strides=[8, 16, 32, 64, 128]), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='odm_cls', + std=0.01, + bias_prob=0.01)), **kwargs): self.bboxes_as_anchors = None self.stacked_convs = stacked_convs @@ -49,6 +59,7 @@ def __init__(self, in_channels, stacked_convs=2, anchor_generator=anchor_generator, + init_cfg=init_cfg, **kwargs) def _init_layers(self): @@ -91,18 +102,6 @@ def _init_layers(self): self.odm_reg = nn.Conv2d( self.feat_channels, self.num_anchors * 5, 3, padding=1) - def init_weights(self): - """Initialize weights of the head.""" - - normal_init(self.or_conv, std=0.01) - for m in self.cls_convs: - normal_init(m.conv, std=0.01) - for m in self.reg_convs: - normal_init(m.conv, std=0.01) - bias_cls = bias_init_with_prob(0.01) - normal_init(self.odm_cls, std=0.01, bias=bias_cls) - normal_init(self.odm_reg, std=0.01) - def forward_single(self, x): """Forward feature of a single scale level. diff --git a/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py b/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py index 5e79f6def..19c56f512 100644 --- a/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py +++ b/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py @@ -23,6 +23,7 @@ class KFIoURRetinaHead(RotatedRetinaHead): loss_bbox (dict): Config of localization loss. train_cfg (dict): Training config of anchor head. test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. """ # noqa: W605 def __init__(self, @@ -37,6 +38,15 @@ def __init__(self, scales_per_octave=3, ratios=[0.5, 1.0, 2.0], strides=[8, 16, 32, 64, 128]), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01)), **kwargs): self.bboxes_as_anchors = None super(KFIoURRetinaHead, self).__init__( @@ -46,6 +56,7 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, anchor_generator=anchor_generator, + init_cfg=init_cfg, **kwargs) def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, diff --git a/mmrotate/models/dense_heads/kfiou_rotate_retina_refine_head.py b/mmrotate/models/dense_heads/kfiou_rotate_retina_refine_head.py index ae92a7761..2d8f4428e 100644 --- a/mmrotate/models/dense_heads/kfiou_rotate_retina_refine_head.py +++ b/mmrotate/models/dense_heads/kfiou_rotate_retina_refine_head.py @@ -26,6 +26,7 @@ class KFIoURRetinaRefineHead(KFIoURRetinaHead): loss_bbox (dict): Config of localization loss. train_cfg (dict): Training config of anchor head. test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. """ # noqa: W605 def __init__(self, @@ -41,6 +42,15 @@ def __init__(self, type='DeltaXYWHABBoxCoder', target_means=(.0, .0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01)), **kwargs): self.bboxes_as_anchors = None @@ -52,6 +62,7 @@ def __init__(self, norm_cfg=norm_cfg, anchor_generator=anchor_generator, bbox_coder=bbox_coder, + init_cfg=init_cfg, **kwargs) @force_fp32(apply_to=('cls_scores', 'bbox_preds')) diff --git a/mmrotate/models/dense_heads/odm_refine_head.py b/mmrotate/models/dense_heads/odm_refine_head.py index b3c33443a..0190b9c02 100644 --- a/mmrotate/models/dense_heads/odm_refine_head.py +++ b/mmrotate/models/dense_heads/odm_refine_head.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch.nn as nn -from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule from mmcv.runner import force_fp32 from ..builder import ROTATED_HEADS @@ -28,6 +28,7 @@ class ODMRefineHead(RotatedRetinaHead): loss_bbox (dict): Config of localization loss. train_cfg (dict): Training config of anchor head. test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. """ # noqa: W605 def __init__(self, @@ -39,6 +40,15 @@ def __init__(self, anchor_generator=dict( type='PseudoAnchorGenerator', strides=[8, 16, 32, 64, 128]), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='odm_cls', + std=0.01, + bias_prob=0.01)), **kwargs): self.bboxes_as_anchors = None self.stacked_convs = stacked_convs @@ -49,6 +59,7 @@ def __init__(self, in_channels, stacked_convs=2, anchor_generator=anchor_generator, + init_cfg=init_cfg, **kwargs) def _init_layers(self): @@ -91,18 +102,6 @@ def _init_layers(self): self.odm_reg = nn.Conv2d( self.feat_channels, self.num_anchors * 5, 3, padding=1) - def init_weights(self): - """Initialize weights of the head.""" - - normal_init(self.or_conv, std=0.01) - for m in self.cls_convs: - normal_init(m.conv, std=0.01) - for m in self.reg_convs: - normal_init(m.conv, std=0.01) - bias_cls = bias_init_with_prob(0.01) - normal_init(self.odm_cls, std=0.01, bias=bias_cls) - normal_init(self.odm_reg, std=0.01) - def forward_single(self, x): """Forward feature of a single scale level. diff --git a/mmrotate/models/dense_heads/rotated_anchor_head.py b/mmrotate/models/dense_heads/rotated_anchor_head.py index 647be2871..3dad51205 100644 --- a/mmrotate/models/dense_heads/rotated_anchor_head.py +++ b/mmrotate/models/dense_heads/rotated_anchor_head.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn -from mmcv.cnn import normal_init from mmcv.runner import force_fp32 from mmdet.core import images_to_levels, multi_apply, unmap from mmdet.models.dense_heads.base_dense_head import BaseDenseHead @@ -33,6 +32,7 @@ class RotatedAnchorHead(BaseDenseHead): loss_bbox (dict): Config of localization loss. train_cfg (dict): Training config of anchor head. test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. """ # noqa: W605 def __init__(self, @@ -59,8 +59,9 @@ def __init__(self, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0), train_cfg=None, - test_cfg=None): - super(RotatedAnchorHead, self).__init__() + test_cfg=None, + init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)): + super(RotatedAnchorHead, self).__init__(init_cfg) self.in_channels = in_channels self.num_classes = num_classes self.feat_channels = feat_channels @@ -105,11 +106,6 @@ def _init_layers(self): self.num_anchors * self.cls_out_channels, 1) self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 5, 1) - def init_weights(self): - """Initialize weights of the head.""" - normal_init(self.conv_cls, std=0.01) - normal_init(self.conv_reg, std=0.01) - def forward_single(self, x): """Forward feature of a single scale level. diff --git a/mmrotate/models/dense_heads/rotated_retina_head.py b/mmrotate/models/dense_heads/rotated_retina_head.py index 5ca8b0f98..cc6bcaa90 100644 --- a/mmrotate/models/dense_heads/rotated_retina_head.py +++ b/mmrotate/models/dense_heads/rotated_retina_head.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch.nn as nn -from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule from mmcv.runner import force_fp32 from ..builder import ROTATED_HEADS @@ -39,6 +39,15 @@ def __init__(self, scales_per_octave=3, ratios=[0.5, 1.0, 2.0], strides=[8, 16, 32, 64, 128]), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01)), **kwargs): self.stacked_convs = stacked_convs self.conv_cfg = conv_cfg @@ -47,6 +56,7 @@ def __init__(self, num_classes, in_channels, anchor_generator=anchor_generator, + init_cfg=init_cfg, **kwargs) def _init_layers(self): @@ -82,16 +92,6 @@ def _init_layers(self): self.retina_reg = nn.Conv2d( self.feat_channels, self.num_anchors * 5, 3, padding=1) - def init_weights(self): - """Initialize weights of the head.""" - for m in self.cls_convs: - normal_init(m.conv, std=0.01) - for m in self.reg_convs: - normal_init(m.conv, std=0.01) - bias_cls = bias_init_with_prob(0.01) - normal_init(self.retina_cls, std=0.01, bias=bias_cls) - normal_init(self.retina_reg, std=0.01) - def forward_single(self, x): """Forward feature of a single scale level. diff --git a/mmrotate/models/dense_heads/rotated_retina_refine_head.py b/mmrotate/models/dense_heads/rotated_retina_refine_head.py index f31fe86cf..35a21b396 100644 --- a/mmrotate/models/dense_heads/rotated_retina_refine_head.py +++ b/mmrotate/models/dense_heads/rotated_retina_refine_head.py @@ -41,6 +41,15 @@ def __init__(self, type='DeltaXYWHABBoxCoder', target_means=(.0, .0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01)), **kwargs): self.bboxes_as_anchors = None @@ -52,6 +61,7 @@ def __init__(self, norm_cfg=norm_cfg, anchor_generator=anchor_generator, bbox_coder=bbox_coder, + init_cfg=init_cfg, **kwargs) @force_fp32(apply_to=('cls_scores', 'bbox_preds')) diff --git a/mmrotate/models/detectors/r3det.py b/mmrotate/models/detectors/r3det.py index 6b1cdc5a1..523d33234 100644 --- a/mmrotate/models/detectors/r3det.py +++ b/mmrotate/models/detectors/r3det.py @@ -1,5 +1,7 @@ # Copyright (c) SJTU. All rights reserved. -import torch.nn as nn +import warnings + +from mmcv.runner import ModuleList from mmrotate.core import rbbox2result from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck @@ -20,10 +22,13 @@ def __init__(self, refine_heads=None, train_cfg=None, test_cfg=None, - pretrained=None): - super(R3Det, self).__init__() - - backbone.pretrained = pretrained + pretrained=None, + init_cfg=None): + super(R3Det, self).__init__(init_cfg) + if pretrained: + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + backbone.pretrained = pretrained self.backbone = build_backbone(backbone) self.num_refine_stages = num_refine_stages if neck is not None: @@ -32,9 +37,8 @@ def __init__(self, bbox_head.update(train_cfg=train_cfg['s0']) bbox_head.update(test_cfg=test_cfg) self.bbox_head = build_head(bbox_head) - self.bbox_head.init_weights() - self.feat_refine_module = nn.ModuleList() - self.refine_head = nn.ModuleList() + self.feat_refine_module = ModuleList() + self.refine_head = ModuleList() for i, (frm_cfg, refine_head) in enumerate(zip(frm_cfgs, refine_heads)): self.feat_refine_module.append(FeatureRefineModule(**frm_cfg)) @@ -42,9 +46,6 @@ def __init__(self, refine_head.update(train_cfg=train_cfg['sr'][i]) refine_head.update(test_cfg=test_cfg) self.refine_head.append(build_head(refine_head)) - for i in range(self.num_refine_stages): - self.feat_refine_module[i].init_weights() - self.refine_head[i].init_weights() self.train_cfg = train_cfg self.test_cfg = test_cfg diff --git a/mmrotate/models/detectors/rotated_reppoints.py b/mmrotate/models/detectors/rotated_reppoints.py index 829db1579..7386a2755 100644 --- a/mmrotate/models/detectors/rotated_reppoints.py +++ b/mmrotate/models/detectors/rotated_reppoints.py @@ -16,4 +16,3 @@ def __init__(self, pretrained=None): super(RotatedRepPoints, self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, pretrained) - self.bbox_head.init_weights() diff --git a/mmrotate/models/detectors/rotated_retinanet.py b/mmrotate/models/detectors/rotated_retinanet.py index cd9530e8f..bb3963226 100644 --- a/mmrotate/models/detectors/rotated_retinanet.py +++ b/mmrotate/models/detectors/rotated_retinanet.py @@ -16,7 +16,8 @@ def __init__(self, bbox_head, train_cfg=None, test_cfg=None, - pretrained=None): - super(RotatedRetinaNet, self).__init__(backbone, neck, bbox_head, - train_cfg, test_cfg, pretrained) - self.bbox_head.init_weights() + pretrained=None, + init_cfg=None): + super(RotatedRetinaNet, + self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, + pretrained, init_cfg) diff --git a/mmrotate/models/detectors/s2anet.py b/mmrotate/models/detectors/s2anet.py index 4d51c3cd8..65dee95da 100644 --- a/mmrotate/models/detectors/s2anet.py +++ b/mmrotate/models/detectors/s2anet.py @@ -28,7 +28,6 @@ def __init__(self, fam_head.update(train_cfg=train_cfg['fam_cfg']) fam_head.update(test_cfg=test_cfg) self.fam_head = build_head(fam_head) - self.fam_head.init_weights() self.align_conv_type = align_cfgs['type'] self.align_conv_size = align_cfgs['kernel_size'] @@ -44,7 +43,6 @@ def __init__(self, odm_head.update(train_cfg=train_cfg['odm_cfg']) odm_head.update(test_cfg=test_cfg) self.odm_head = build_head(odm_head) - self.odm_head.init_weights() self.train_cfg = train_cfg self.test_cfg = test_cfg From 2f54e87f89f340b3acfd15dffc77887fff3dd509 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Sat, 26 Feb 2022 16:47:15 +0800 Subject: [PATCH 4/4] remove opencv --- requirements/runtime.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index b86343c3c..5dc0a7117 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -3,7 +3,6 @@ matplotlib mmcv-full mmdet numpy -opencv-python pycocotools six terminaltables