forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhancement] Refine mmcls rewriting (open-mmlab#106)
* fix mmcls head * fix lint
- Loading branch information
Showing
7 changed files
with
7 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,4 @@ | ||
from .cls_head import simple_test_of_cls_head | ||
from .linear_head import simple_test_of_linear_head | ||
from .multi_label_head import simple_test_of_multi_label_head | ||
from .multi_label_linear_head import simple_test_of_multi_label_linear_head | ||
from .stacked_head import simple_test_of_stacked_head | ||
from .vision_transformer_head import simple_test_of_vision_transformer_head | ||
from .cls_head import post_process_of_cls_head | ||
from .multi_label_head import post_process_of_multi_label_head | ||
|
||
__all__ = [ | ||
'simple_test_of_multi_label_linear_head', | ||
'simple_test_of_multi_label_head', 'simple_test_of_cls_head', | ||
'simple_test_of_linear_head', 'simple_test_of_stacked_head', | ||
'simple_test_of_vision_transformer_head' | ||
] | ||
__all__ = ['post_process_of_cls_head', 'post_process_of_multi_label_head'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,7 @@ | ||
import torch.nn.functional as F | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='mmcls.models.heads.ClsHead.simple_test') | ||
def simple_test_of_cls_head(ctx, self, cls_score, **kwargs): | ||
"""Test without augmentation.""" | ||
if isinstance(cls_score, list): | ||
cls_score = sum(cls_score) / float(len(cls_score)) | ||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None | ||
func_name='mmcls.models.heads.ClsHead.post_process') | ||
def post_process_of_cls_head(ctx, self, pred, **kwargs): | ||
return pred |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,7 @@ | ||
import torch.nn.functional as F | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='mmcls.models.heads.MultiLabelClsHead.simple_test') | ||
def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs): | ||
if isinstance(cls_score, list): | ||
cls_score = sum(cls_score) / float(len(cls_score)) | ||
pred = F.sigmoid(cls_score) if cls_score is not None else None | ||
func_name='mmcls.models.heads.MultiLabelClsHead.post_process') | ||
def post_process_of_multi_label_head(ctx, self, pred, **kwargs): | ||
return pred |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.