Skip to content

Commit

Permalink
[Enhancement] Refine mmcls rewriting (open-mmlab#106)
Browse files Browse the repository at this point in the history
* fix mmcls head

* fix lint
  • Loading branch information
AllentDan authored Sep 28, 2021
1 parent 8633993 commit 9b070a5
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 85 deletions.
15 changes: 3 additions & 12 deletions mmdeploy/mmcls/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
from .cls_head import simple_test_of_cls_head
from .linear_head import simple_test_of_linear_head
from .multi_label_head import simple_test_of_multi_label_head
from .multi_label_linear_head import simple_test_of_multi_label_linear_head
from .stacked_head import simple_test_of_stacked_head
from .vision_transformer_head import simple_test_of_vision_transformer_head
from .cls_head import post_process_of_cls_head
from .multi_label_head import post_process_of_multi_label_head

__all__ = [
'simple_test_of_multi_label_linear_head',
'simple_test_of_multi_label_head', 'simple_test_of_cls_head',
'simple_test_of_linear_head', 'simple_test_of_stacked_head',
'simple_test_of_vision_transformer_head'
]
__all__ = ['post_process_of_cls_head', 'post_process_of_multi_label_head']
10 changes: 2 additions & 8 deletions mmdeploy/mmcls/models/heads/cls_head.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
import torch.nn.functional as F

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.ClsHead.simple_test')
def simple_test_of_cls_head(ctx, self, cls_score, **kwargs):
"""Test without augmentation."""
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
func_name='mmcls.models.heads.ClsHead.post_process')
def post_process_of_cls_head(ctx, self, pred, **kwargs):
return pred
14 changes: 0 additions & 14 deletions mmdeploy/mmcls/models/heads/linear_head.py

This file was deleted.

9 changes: 2 additions & 7 deletions mmdeploy/mmcls/models/heads/multi_label_head.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import torch.nn.functional as F

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.MultiLabelClsHead.simple_test')
def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
func_name='mmcls.models.heads.MultiLabelClsHead.post_process')
def post_process_of_multi_label_head(ctx, self, pred, **kwargs):
return pred
14 changes: 0 additions & 14 deletions mmdeploy/mmcls/models/heads/multi_label_linear_head.py

This file was deleted.

16 changes: 0 additions & 16 deletions mmdeploy/mmcls/models/heads/stacked_head.py

This file was deleted.

14 changes: 0 additions & 14 deletions mmdeploy/mmcls/models/heads/vision_transformer_head.py

This file was deleted.

0 comments on commit 9b070a5

Please # to comment.