diff --git a/configs/distill/fgd/README.md b/configs/distill/fgd/README.md new file mode 100644 index 000000000..f500a16ab --- /dev/null +++ b/configs/distill/fgd/README.md @@ -0,0 +1,30 @@ +# FGD + +> [Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837) + + + +## Abstract + +Knowledge distillation has been applied to image classification successfully. However, object detection is much more sophisticated and most knowledge distillation methods have failed on it. In this paper, we point out that in object detection, the features of the teacher and student vary greatly in different areas, especially in the foreground and background. If we distill them equally, the uneven differences between feature maps will negatively affect the distillation. Thus, we propose Focal and Global Distillation (FGD). Focal distillation separates the foreground and background, forcing the student to focus on the teacher's critical pixels and channels. Global distillation rebuilds the relation between different pixels and transfers it from teachers to students, compensating for missing global information in focal distillation. As our method only needs to calculate the loss on the feature map, FGD can be applied to various detectors. We experiment on various detectors with different backbones and the results show that the student detector achieves excellent mAP improvement. For example, ResNet-50 based RetinaNet, Faster RCNN, RepPoints and Mask RCNN with our distillation method achieve 40.7%, 42.0%, 42.0% and 42.1% mAP on COCO2017, which are 3.3, 3.6, 3.4 and 2.9 higher than the baseline, respectively. + +![pipeline](https://user-images.githubusercontent.com/41630003/220037957-25a1440f-fcb3-413a-a350-97937bf6a042.png) + +## Results and models + +### Detection + +| Location | Dataset | Teacher | Student | mAP | mAP(T) | mAP(S) | Config | Download | +| :------: | :-----: | :---------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :--: | :----: | :----: | :-------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| FPN | COCO | [retina_x101_1x](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_64x4d_fpn_1x_coco.py) | [retina_r50_2x](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 40.5 | 41.0 | 37.4 | [config](./fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](https://download.openmmlab.com/mmrazor/v0.1/distill/fgd/fgd_retina_x101_retina_r50_2x_coco_20221216_114845-c4c7496d.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/distill/fgd/fgd_retina_x101_retina_r50_2x_coco_20221216_114845-c4c7496d.json) | + +## Citation + +```latex +@article{yang2021focal, + title={Focal and Global Knowledge Distillation for Detectors}, + author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun}, + journal={arXiv preprint arXiv:2111.11837}, + year={2021} +} +``` diff --git a/configs/distill/fgd/fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py b/configs/distill/fgd/fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py new file mode 100644 index 000000000..dcbeda181 --- /dev/null +++ b/configs/distill/fgd/fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py @@ -0,0 +1,228 @@ +_base_ = [ + '../../_base_/datasets/mmdet/coco_detection.py', + '../../_base_/schedules/mmdet/schedule_2x.py', + '../../_base_/mmdet_runtime.py' +] + +# model settings +t_weight = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501 + +student = dict( + type='mmdet.RetinaNet', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +teacher = dict( + type='mmdet.RetinaNet', + init_cfg=dict(type='Pretrained', checkpoint=t_weight), + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# algorithm setting +in_channels = 256 +temp = 0.5 +alpha_fgd = 0.001 +beta_fgd = 0.0005 +gamma_fgd = 0.0005 +lambda_fgd = 0.000005 +algorithm = dict( + type='GeneralDistill', + architecture=dict( + type='MMDetArchitecture', + model=student, + ), + distiller=dict( + type='SingleTeacherDistiller', + teacher=teacher, + teacher_trainable=False, + components=[ + dict( + student_module='neck.fpn_convs.0.conv', + teacher_module='neck.fpn_convs.0.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_0', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.1.conv', + teacher_module='neck.fpn_convs.1.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_1', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.2.conv', + teacher_module='neck.fpn_convs.2.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_2', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.3.conv', + teacher_module='neck.fpn_convs.3.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_3', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.4.conv', + teacher_module='neck.fpn_convs.4.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_4', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + ]), +) + +find_unused_parameters = True + +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict( + _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) diff --git a/configs/distill/fgd/metafile.yml b/configs/distill/fgd/metafile.yml new file mode 100644 index 000000000..4d96a6c82 --- /dev/null +++ b/configs/distill/fgd/metafile.yml @@ -0,0 +1,32 @@ +Collections: + - Name: FGD + Metadata: + Training Data: + - COCO + Paper: + URL: https://arxiv.org/abs/2111.11837 + Title: Focal and Global Knowledge Distillation for Detectors + README: configs/distill/fgd/README.md + Code: + URL: + Version: v0.1.0 + Converted From: + Code: + - https://github.com/yzd-v/FGD +Models: + - Name: fgd_retina_x101_fpn_retina_r50_fpn_2x_coco + In Collection: FGD + Metadata: + Location: FPN + Student: retinanet-r50 + Teacher: retinanet-x101 + Teacher Checkpoint: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 40.5 + box AP(S): 37.4 + box AP(T): 41.0 + Config: configs/distill/fgd/fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py + Weights: https://download.openmmlab.com/mmrazor/v0.1/distill/fgd/fgd_retina_x101_retina_r50_2x_coco_20221216_114845-c4c7496d.pth diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 3d3e97e52..c6f1a1f55 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cwd import ChannelWiseDivergence +from .fgd import FGDLoss from .kl_divergence import KLDivergence from .relational_kd import AngleWiseRKD, DistanceWiseRKD from .weighted_soft_label_distillation import WSLD __all__ = [ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', - 'WSLD' + 'WSLD', 'FGDLoss' ] diff --git a/mmrazor/models/losses/fgd.py b/mmrazor/models/losses/fgd.py new file mode 100644 index 000000000..817755113 --- /dev/null +++ b/mmrazor/models/losses/fgd.py @@ -0,0 +1,238 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import constant_init, kaiming_init + +from ..builder import LOSSES + + +@LOSSES.register_module() +class FGDLoss(nn.Module): + """PyTorch version of 'Focal and Global Knowledge Distillation for + Detectors'. + + + + Args: + in_channels (int): Channels of the input feature map. + temp (float, optional): Temperature coefficient. Defaults to 0.5. + alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001. + beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005. + gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001. + lambda_fgd (float, optional): Weight of relation_loss. + Defaults to 0.000005. + """ + + def __init__( + self, + in_channels, + temp=0.5, + alpha_fgd=0.001, + beta_fgd=0.0005, + gamma_fgd=0.001, + lambda_fgd=0.000005, + ): + super(FGDLoss, self).__init__() + self.temp = temp + self.alpha_fgd = alpha_fgd + self.beta_fgd = beta_fgd + self.gamma_fgd = gamma_fgd + self.lambda_fgd = lambda_fgd + + self.conv_mask_s = nn.Conv2d(in_channels, 1, kernel_size=1) + self.conv_mask_t = nn.Conv2d(in_channels, 1, kernel_size=1) + self.channel_add_conv_s = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 2, kernel_size=1), + nn.LayerNorm([in_channels // 2, 1, 1]), nn.ReLU(inplace=True), + nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)) + self.channel_add_conv_t = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 2, kernel_size=1), + nn.LayerNorm([in_channels // 2, 1, 1]), nn.ReLU(inplace=True), + nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)) + + self.reset_parameters() + + def forward(self, preds_S, preds_T): + """Forward function. + + Args: + preds_S(Tensor): Bs*C*H*W, student's feature map + preds_T(Tensor): Bs*C*H*W, teacher's feature map + """ + assert preds_S.shape[-2:] == preds_T.shape[-2:] + N, C, H, W = preds_S.shape + # Bs*[nt*4], (tl_x, tl_y, br_x, br_y) + gt_bboxes = self.current_data['gt_bboxes'] + # Meta information of each image, e.g., image size, scaling factor. + metas = self.current_data['img_metas'] # list[dict] + + spatial_attention_t, channel_attention_t = self.get_attention( + preds_T, self.temp) + spatial_attention_s, channel_attention_s = self.get_attention( + preds_S, self.temp) + + mask_fg = torch.zeros_like(spatial_attention_t) + mask_bg = torch.ones_like(spatial_attention_t) + wmin, wmax, hmin, hmax = [], [], [], [] + for i in range(N): + new_boxx = torch.ones_like(gt_bboxes[i]) + new_boxx[:, 0] = gt_bboxes[i][:, 0] / metas[i]['img_shape'][1] * W + new_boxx[:, 2] = gt_bboxes[i][:, 2] / metas[i]['img_shape'][1] * W + new_boxx[:, 1] = gt_bboxes[i][:, 1] / metas[i]['img_shape'][0] * H + new_boxx[:, 3] = gt_bboxes[i][:, 3] / metas[i]['img_shape'][0] * H + + wmin.append(torch.floor(new_boxx[:, 0]).int()) + wmax.append(torch.ceil(new_boxx[:, 2]).int()) + hmin.append(torch.floor(new_boxx[:, 1]).int()) + hmax.append(torch.ceil(new_boxx[:, 3]).int()) + + height = hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1) + width = wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1) + area = 1.0 / height.float() / width.float() + + for j in range(len(gt_bboxes[i])): + mask_fg[i][hmin[i][j]:hmax[i][j]+1, + wmin[i][j]:wmax[i][j]+1] = \ + torch.max(mask_fg[i][hmin[i][j]:hmax[i][j]+1, + wmin[i][j]:wmax[i][j]+1], area[0][j]) + + mask_bg[i] = torch.where(mask_fg[i] > 0, torch.zeros_like(mask_bg), + torch.ones_like(mask_bg)) + if torch.sum(mask_bg[i]): + mask_bg[i] /= torch.sum(mask_bg[i]) + + fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, mask_fg, + mask_bg, channel_attention_t, + spatial_attention_t) + mask_loss = self.get_mask_loss(channel_attention_s, + channel_attention_t, + spatial_attention_s, + spatial_attention_t) + rela_loss = self.get_rela_loss(preds_S, preds_T) + + loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss + + return loss + + def get_attention(self, preds, temp): + """Calculate spatial and channel attention. + + Args: + preds (Tensor): Model prediction with shape (N, C, H, W). + temp (float): Temperature coefficient. + """ + N, C, H, W = preds.shape + + value = torch.abs(preds) + # Bs*W*H + fea_map = value.mean(axis=1, keepdim=True) + spatial_attention = (H * W * F.softmax( + (fea_map / temp).view(N, -1), dim=1)).view(N, H, W) + + # Bs*C + channel_map = value.mean( + axis=2, keepdim=False).mean( + axis=2, keepdim=False) + channel_attention = C * F.softmax(channel_map / temp, dim=1) + + return spatial_attention, channel_attention + + def get_fea_loss(self, preds_S, preds_T, mask_fg, mask_bg, + channel_attention_t, spatial_attention_t): + loss_mse = nn.MSELoss(reduction='sum') + + mask_fg = mask_fg.unsqueeze(dim=1) + mask_bg = mask_bg.unsqueeze(dim=1) + + channel_attention_t = channel_attention_t.unsqueeze(dim=-1).unsqueeze( + dim=-1) + spatial_attention_t = spatial_attention_t.unsqueeze(dim=1) + + fea_t = torch.mul(preds_T, torch.sqrt(spatial_attention_t)) + fea_t = torch.mul(fea_t, torch.sqrt(channel_attention_t)) + fea_t_fg = torch.mul(fea_t, torch.sqrt(mask_fg)) + fea_t_bg = torch.mul(fea_t, torch.sqrt(mask_bg)) + + fea_s = torch.mul(preds_S, torch.sqrt(spatial_attention_t)) + fea_s = torch.mul(fea_s, torch.sqrt(channel_attention_t)) + fea_s_fg = torch.mul(fea_s, torch.sqrt(mask_fg)) + fea_s_bg = torch.mul(fea_s, torch.sqrt(mask_bg)) + + loss_fg = loss_mse(fea_s_fg, fea_t_fg) / len(mask_fg) + loss_bg = loss_mse(fea_s_bg, fea_t_bg) / len(mask_bg) + + return loss_fg, loss_bg + + def get_mask_loss(self, channel_attention_s, channel_attention_t, + spatial_attention_s, spatial_attention_t): + + mask_loss = torch.sum( + torch.abs( + (channel_attention_s - + channel_attention_t))) / len(channel_attention_s) + torch.sum( + torch.abs( + (spatial_attention_s - + spatial_attention_t))) / len(spatial_attention_s) + + return mask_loss + + def spatial_pool(self, x, is_student_input): + batch, channel, width, height = x.size() + input_x = x + # [N, C, H * W] + input_x = input_x.view(batch, channel, height * width) + # [N, 1, C, H * W] + input_x = input_x.unsqueeze(1) + # [N, 1, H, W] + if is_student_input: + context_mask = self.conv_mask_s(x) + else: + context_mask = self.conv_mask_t(x) + # [N, 1, H * W] + context_mask = context_mask.view(batch, 1, height * width) + # [N, 1, H * W] + context_mask = F.softmax(context_mask, dim=2) + # [N, 1, H * W, 1] + context_mask = context_mask.unsqueeze(-1) + # [N, 1, C, 1] + context = torch.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = context.view(batch, channel, 1, 1) + + return context + + def get_rela_loss(self, preds_S, preds_T): + loss_mse = nn.MSELoss(reduction='sum') + + context_s = self.spatial_pool(preds_S, is_student_input=True) + context_t = self.spatial_pool(preds_T, is_student_input=False) + + out_s = preds_S + out_t = preds_T + + channel_add_s = self.channel_add_conv_s(context_s) + out_s = out_s + channel_add_s + + channel_add_t = self.channel_add_conv_t(context_t) + out_t = out_t + channel_add_t + + rela_loss = loss_mse(out_s, out_t) / len(out_s) + + return rela_loss + + def last_zero_init(self, m): + if isinstance(m, nn.Sequential): + constant_init(m[-1], val=0) + else: + constant_init(m, val=0) + + def reset_parameters(self): + kaiming_init(self.conv_mask_s, mode='fan_in') + kaiming_init(self.conv_mask_t, mode='fan_in') + self.conv_mask_s.inited = True + self.conv_mask_t.inited = True + + self.last_zero_init(self.channel_add_conv_s) + self.last_zero_init(self.channel_add_conv_t) diff --git a/tests/data/fgd_retina.py b/tests/data/fgd_retina.py new file mode 100644 index 000000000..dcd76833e --- /dev/null +++ b/tests/data/fgd_retina.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# model settings +student = dict( + type='mmdet.RetinaNet', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +teacher = dict( + type='mmdet.RetinaNet', + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + +# algorithm setting +in_channels = 256 +temp = 0.5 +alpha_fgd = 0.001 +beta_fgd = 0.0005 +gamma_fgd = 0.0005 +lambda_fgd = 0.000005 +algorithm = dict( + type='GeneralDistill', + architecture=dict( + type='MMDetArchitecture', + model=student, + ), + distiller=dict( + type='SingleTeacherDistiller', + teacher=teacher, + teacher_trainable=False, + components=[ + dict( + student_module='neck.fpn_convs.0.conv', + teacher_module='neck.fpn_convs.0.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_0', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.1.conv', + teacher_module='neck.fpn_convs.1.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_1', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.2.conv', + teacher_module='neck.fpn_convs.2.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_2', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.3.conv', + teacher_module='neck.fpn_convs.3.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_3', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.4.conv', + teacher_module='neck.fpn_convs.4.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_4', + in_channels=in_channels, + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + ]), +) + +find_unused_parameters = True + +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict( + _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) diff --git a/tests/test_models/test_algorithms/test_algorithm.py b/tests/test_models/test_algorithms/test_algorithm.py index 1869c94e3..1ccbb5ad4 100644 --- a/tests/test_models/test_algorithms/test_algorithm.py +++ b/tests/test_models/test_algorithms/test_algorithm.py @@ -46,6 +46,84 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10): return mm_inputs +def _demo_mmdet_inputs(input_shape=(1, 3, 300, 300), + num_items=None, num_classes=10, + with_semantic=False): # yapf: disable + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + + num_items (None | List[int]): + specifies the number of boxes in each batch item + + num_classes (int): + number of different labels a box might have + """ + from mmdet.core import BitmapMasks + + (N, C, H, W) = input_shape + + rng = np.random.RandomState(0) + + imgs = rng.rand(*input_shape) + + img_metas = [{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'scale_factor': np.array([1.1, 1.2, 1.1, 1.2]), + 'flip': False, + 'flip_direction': None, + } for _ in range(N)] + + gt_bboxes = [] + gt_labels = [] + gt_masks = [] + + for batch_idx in range(N): + if num_items is None: + num_boxes = rng.randint(1, 10) + else: + num_boxes = num_items[batch_idx] + + cx, cy, bw, bh = rng.rand(num_boxes, 4).T + + tl_x = ((cx * W) - (W * bw / 2)).clip(0, W) + tl_y = ((cy * H) - (H * bh / 2)).clip(0, H) + br_x = ((cx * W) + (W * bw / 2)).clip(0, W) + br_y = ((cy * H) + (H * bh / 2)).clip(0, H) + + boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T + class_idxs = rng.randint(1, num_classes, size=num_boxes) + + gt_bboxes.append(torch.FloatTensor(boxes)) + gt_labels.append(torch.LongTensor(class_idxs)) + + mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8) + gt_masks.append(BitmapMasks(mask, H, W)) + + mm_inputs = { + 'img': torch.FloatTensor(imgs).requires_grad_(True), + 'img_metas': img_metas, + 'gt_bboxes': gt_bboxes, + 'gt_labels': gt_labels, + 'gt_bboxes_ignore': None, + 'gt_masks': gt_masks, + } + + if with_semantic: + # assume gt_semantic_seg using scale 1/8 of the img + gt_semantic_seg = np.random.randint( + 0, num_classes, (1, 1, H // 8, W // 8), dtype=np.uint8) + mm_inputs.update( + {'gt_semantic_seg': torch.ByteTensor(gt_semantic_seg)}) + + return mm_inputs + + def test_autoslim_pretrain(): model_cfg = dict( type='mmcls.ImageClassifier', @@ -546,3 +624,17 @@ def test_rkd(): losses = algorithm(imgs, return_loss=True, gt_label=label) assert losses['loss'].item() > 0 + + +def test_fgd(): + config_path = './tests/data/fgd_retina.py' + + config = Config.fromfile(config_path) + + mm_inputs = _demo_mmdet_inputs(num_classes=80) + mm_inputs.pop('gt_masks') + algorithm = ALGORITHMS.build(config.algorithm) + + # test algorithm train_step + losses = algorithm.train_step(mm_inputs, None) + assert losses['loss'].item() > 0