
[Bug] distill for yolox_s failed #184

Closed
tanghy2016 opened this issue Jun 20, 2022 · 8 comments · May be fixed by #192
@tanghy2016

Describe the bug

The original backbone for yolox_s is CSPDarknet; I use ResNet-18 instead. The detailed configuration is as follows:

algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMDetArchitecture',
        model=dict(
            type='mmdet.YOLOX',
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(
                type='ResNet',
                depth=18,
                num_stages=4,
                out_indices=(1, 2, 3),
                norm_cfg=dict(type='BN', requires_grad=True),
                norm_eval=True,
                style='pytorch',
                init_cfg=dict(
                    type='Pretrained', checkpoint='torchvision://resnet18')),
            neck=dict(
                type='YOLOXPAFPN',
                in_channels=[128, 256, 512],
                out_channels=128,
                num_csp_blocks=1),
            bbox_head=dict(
                type='YOLOXHead',
                num_classes=1,
                in_channels=128,
                feat_channels=128),
            train_cfg=dict(
                assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(
                score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))),
    with_student_loss=True,
    with_teacher_loss=False,
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=dict(
            type='mmdet.YOLOX',
            init_cfg=dict(
                type='Pretrained',
                checkpoint='/root/minio_model/1518897383465840642/epoch_19.pth'
            ),
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(
                type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
            neck=dict(
                type='YOLOXPAFPN',
                in_channels=[128, 256, 512],
                out_channels=128,
                num_csp_blocks=1),
            bbox_head=dict(
                type='YOLOXHead',
                num_classes=1,
                in_channels=128,
                feat_channels=128),
            train_cfg=dict(
                assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(
                score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))),
        teacher_trainable=False,
        components=[
            dict(
                student_module='neck.out_convs.0.conv',
                teacher_module='neck.out_convs.0.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ]),
            dict(
                student_module='neck.out_convs.1.conv',
                teacher_module='neck.out_convs.1.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ]),
            dict(
                student_module='neck.out_convs.2.conv',
                teacher_module='neck.out_convs.2.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ])
        ]))

When it reaches epoch 2, the following error occurs:

2022-06-20 07:27:45,022 - mmdet - INFO - Epoch [2][2/8]	lr: 6.250e-04, eta: 0:05:27, time: 0.702, data_time: 0.108, memory: 6444, student.loss_cls: 0.4989, student.loss_bbox: 4.8508, student.loss_obj: 19.1068, distiller.loss_cwd_logits.0: 0.2968, loss: 24.7533
Traceback (most recent call last):
  File "MMCV/train.py", line 52, in <module>
    main()
  File "MMCV/train.py", line 46, in main
    train_distill(args)
  File "/root/project/xbrain-ai/XbrainAiModels/MMCV/tools/train_distill.py", line 39, in train_distill
    train_mmdet(args, cfg)
  File "/root/project/xbrain-ai/XbrainAiModels/MMCV/tools/train_mmdet.py", line 43, in train_mmdet
    train_mmdet_model(
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/apis/mmdet/train.py", line 206, in train_mmdet_model
    runner.run(data_loader, cfg.workflow)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/opt/conda/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/models/algorithms/general_distill.py", line 49, in train_step
    distill_losses = self.distiller.compute_distill_loss(data)
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/models/distillers/single_teacher.py", line 240, in compute_distill_loss
    losses[loss_name] = loss_module(s_out, t_out)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/models/losses/cwd.py", line 41, in forward
    assert preds_S.shape[-2:] == preds_T.shape[-2:]
AssertionError

In the correct situation, both preds_S and preds_T should have shape (batch_size, 3, 80, 80), (batch_size, 3, 40, 40), or (batch_size, 3, 20, 20). But when the above error occurs, in my debugging, various other sizes appeared.

@pppppM
Collaborator

pppppM commented Jun 21, 2022

This should not be a bug. YOLOX performs its random multi-scale resize inside forward, which causes the teacher's and student's input image sizes to become inconsistent.

https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/yolox.py#L93

This issue #17 may help you.
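To illustrate the point above: YOLOX picks a random input size (a multiple of 32) inside forward, so if the teacher's and student's forwards resize independently, their neck feature maps end up with different spatial sizes and the CWD assert `preds_S.shape[-2:] == preds_T.shape[-2:]` fires. A minimal sketch (the helper `feat_hw` is hypothetical, not an mmdet API):

```python
# Sketch of why the CWD shape assert can fail: feature-map spatial size
# is input size divided by the level's stride, so different input sizes
# produce different feature shapes.

def feat_hw(input_hw, stride):
    """Spatial (H, W) of a feature map for a given input size and stride."""
    h, w = input_hw
    return (h // stride, w // stride)

strides = (8, 16, 32)  # YOLOXPAFPN output strides

# Teacher and student both see 640x640 -> shapes match at every level.
same = [feat_hw((640, 640), s) for s in strides]
assert same == [(80, 80), (40, 40), (20, 20)]

# Student's batch was randomly resized to 608x608 while the teacher's
# forward kept 640x640 (random_size_range=(15, 25) means 15*32..25*32).
student = [feat_hw((608, 608), s) for s in strides]
teacher = [feat_hw((640, 640), s) for s in strides]
assert student != teacher  # this mismatch is what trips the assert
```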

@pppppM added the usage label and removed the bug label Jun 21, 2022
@tanghy2016
Author

tanghy2016 commented Jun 21, 2022

Referring to your reply, I added:

    align_methods=[
        dict(method='YOLOX._preprocess', import_module='mmdet.models')
    ],

but it didn't take effect.

There is a note in BaseDistiller.__init__:

        if align_methods is None:
            self.context_manager = None
        else:
            # To obtain the python functions' outputs, a specific context
            # manager is built. When entering it, the functions are
            # rewritten so that their inputs or outputs can be recorded
            # and passed from teacher to student. When exiting, the
            # rewritten functions are restored.
            self.context_manager = ConversionContext(align_methods)

But in SingleTeacherDistiller.exec_student_forward, the student's input is **data.
Could that be the problem?
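For readers unfamiliar with the mechanism quoted above, here is a minimal, self-contained stand-in (class and method names are hypothetical, not mmrazor's actual implementation) for what such a context manager does: temporarily rewrite a method so its outputs are recorded, then restore the original on exit.

```python
# Simplified stand-in for a ConversionContext-style manager: while the
# context is active, a class method is replaced by a wrapper that records
# its outputs; on exit the original method is restored.

class RecordingContext:
    def __init__(self, cls, method_name):
        self.cls = cls
        self.method_name = method_name
        self.records = []

    def __enter__(self):
        self._orig = getattr(self.cls, self.method_name)
        orig, records = self._orig, self.records

        def wrapper(obj, *args, **kwargs):
            out = orig(obj, *args, **kwargs)
            records.append(out)  # record so it can be replayed elsewhere
            return out

        setattr(self.cls, self.method_name, wrapper)
        return self

    def __exit__(self, *exc):
        setattr(self.cls, self.method_name, self._orig)  # restore original


class Model:
    def _preprocess(self, x):
        return x * 2


m = Model()
with RecordingContext(Model, '_preprocess') as ctx:
    m._preprocess(3)

assert ctx.records == [6]       # the output was captured
assert m._preprocess(3) == 6    # original behavior restored after exit
```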

@pppppM self-assigned this Jun 22, 2022
@pppppM
Collaborator

pppppM commented Jun 22, 2022

Is it the same error as before adding align_methods?

@tanghy2016
Author

Yes. The config after adding align_methods is as follows:

algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMDetArchitecture',
        model=dict(
            type='mmdet.YOLOX',
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(
                type='ResNet',
                depth=18,
                num_stages=4,
                out_indices=(1, 2, 3),
                norm_cfg=dict(type='BN', requires_grad=True),
                norm_eval=True,
                style='pytorch',
                init_cfg=dict(
                    type='Pretrained', checkpoint='torchvision://resnet18')),
            neck=dict(
                type='YOLOXPAFPN',
                in_channels=[128, 256, 512],
                out_channels=128,
                num_csp_blocks=1),
            bbox_head=dict(
                type='YOLOXHead',
                num_classes=1,
                in_channels=128,
                feat_channels=128),
            train_cfg=dict(
                assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(
                score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))),
    with_student_loss=True,
    with_teacher_loss=False,
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=dict(
            type='mmdet.YOLOX',
            init_cfg=dict(
                type='Pretrained',
                checkpoint='/root/minio_model/1518897383465840642/epoch_19.pth'
            ),
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(
                type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
            neck=dict(
                type='YOLOXPAFPN',
                in_channels=[128, 256, 512],
                out_channels=128,
                num_csp_blocks=1),
            bbox_head=dict(
                type='YOLOXHead',
                num_classes=1,
                in_channels=128,
                feat_channels=128),
            train_cfg=dict(
                assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(
                score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))),
        teacher_trainable=False,
        align_methods=[
            dict(method='YOLOX._preprocess', import_module='mmdet.models')
        ],
        components=[
            dict(
                student_module='neck.out_convs.0.conv',
                teacher_module='neck.out_convs.0.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ]),
            dict(
                student_module='neck.out_convs.1.conv',
                teacher_module='neck.out_convs.1.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ]),
            dict(
                student_module='neck.out_convs.2.conv',
                teacher_module='neck.out_convs.2.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_logits',
                        tau=1,
                        loss_weight=5)
                ])
        ]))

@pppppM linked a pull request Jun 27, 2022 that will close this issue
@pppppM added the Bug:P2 label Jun 27, 2022
@pppppM
Collaborator

pppppM commented Jun 27, 2022

I just submitted a PR and added an example for YOLOX.

@tanghy2016
Author

Thanks, solved.

@tanghy2016
Author

Hi, I have a new problem. It does not occur when the total number of epochs is small (<10), but it does occur when the setting is larger (20, 50, etc.). Is it related to lr_config.num_last_epochs?

2022-09-26 15:09:49,822 - mmdet - INFO - Exp name: cwd_neck-yolox_s-csp-r18.py
2022-09-26 15:09:49,822 - mmdet - INFO - Epoch(val) [4][49]	AP50: 0.0000, mAP: 0.0004
2022-09-26 15:09:49,834 - mmdet - INFO - No mosaic and mixup aug now!
2022-09-26 15:09:49,940 - mmdet - INFO - Add additional L1 loss now!
Traceback (most recent call last):
  File "MMCV/train.py", line 55, in <module>
    main()
  File "MMCV/train.py", line 49, in main
    train_distill(args)
  File "/root/project/xbrain-ai/XbrainAiModels/MMCV/tools/train_distill.py", line 44, in train_distill
    train_mmdet(args, cfg)
  File "/root/project/xbrain-ai/XbrainAiModels/MMCV/tools/train_mmdet.py", line 43, in train_mmdet
    train_mmdet_model(
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmrazor/mmrazor/apis/mmdet/train.py", line 206, in train_mmdet_model
    runner.run(data_loader, cfg.workflow)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 45, in train
    self.call_hook('before_train_epoch')
  File "/opt/conda/lib/python3.8/site-packages/mmcv/runner/base_runner.py", line 309, in call_hook
    getattr(hook, fn_name)(self)
  File "/root/project/xbrain-ai/ThirdParty/MMCV/mmdetection/mmdet/core/hook/yolox_mode_switch_hook.py", line 47, in before_train_epoch
    model.bbox_head.use_l1 = True
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'AlignMethodDistill' object has no attribute 'bbox_head'

lr_config is as follows:

max_epochs = 20
num_last_epochs = 15
resume_from = None
interval = 10

# learning policy
lr_config = dict(
    _delete_=True,
    policy='YOLOX',
    warmup='exp',
    by_epoch=False,
    warmup_by_epoch=True,
    warmup_ratio=1,
    warmup_iters=3,  # 5 epoch
    num_last_epochs=num_last_epochs,
    min_lr_ratio=0.05)

@Linker-Stars

> (quoting the traceback and lr_config from @tanghy2016's previous comment in full)

I think this is caused by YOLOXModeSwitchHook of the YOLOX model. You can override this hook, changing `model.bbox_head.use_l1 = True` to `model.architecture.model.bbox_head.use_l1 = True`.
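The suggested fix can be illustrated with a self-contained sketch (the mmdet/mmrazor classes are replaced with simple stand-ins): the distillation algorithm wraps the detector, so the hook must reach through `architecture.model` instead of accessing `bbox_head` directly on the wrapper.

```python
# Stand-in for the wrapped model hierarchy that mmrazor builds:
# AlignMethodDistill -> architecture -> model (YOLOX) -> bbox_head
from types import SimpleNamespace

detector = SimpleNamespace(bbox_head=SimpleNamespace(use_l1=False))
algorithm = SimpleNamespace(architecture=SimpleNamespace(model=detector))


def before_train_epoch_original(model):
    # What YOLOXModeSwitchHook does: fails on the distill wrapper,
    # which has no bbox_head attribute of its own.
    model.bbox_head.use_l1 = True


def before_train_epoch_fixed(model):
    # The suggested rewrite: reach through the wrapper to the detector.
    model.architecture.model.bbox_head.use_l1 = True


try:
    before_train_epoch_original(algorithm)
except AttributeError:
    pass  # reproduces the reported AttributeError

before_train_epoch_fixed(algorithm)
assert detector.bbox_head.use_l1 is True
```

In practice this would be done by registering a subclass of YOLOXModeSwitchHook with the overridden attribute path and using it in place of the original hook in the config.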
