From e4ad2ff5a86f2966cc98a5902fa798e0429cda56 Mon Sep 17 00:00:00 2001
From: "alec.tu"
Date: Tue, 21 Feb 2023 14:44:33 +0800
Subject: [PATCH 1/2] [Feature] Add mask channel in MGD Loss

---
 .../architectures/connectors/mgd_connector.py    |  8 +++++++-
 .../test_connectors/test_connectors.py           | 16 ++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/mmrazor/models/architectures/connectors/mgd_connector.py b/mmrazor/models/architectures/connectors/mgd_connector.py
index 41b77df4c..9b53fed1d 100644
--- a/mmrazor/models/architectures/connectors/mgd_connector.py
+++ b/mmrazor/models/architectures/connectors/mgd_connector.py
@@ -27,10 +27,12 @@ def __init__(
         student_channels: int,
         teacher_channels: int,
         lambda_mgd: float = 0.65,
+        mask_on_channel: bool = False,
         init_cfg: Optional[Dict] = None,
     ) -> None:
         super().__init__(init_cfg)
         self.lambda_mgd = lambda_mgd
+        self.mask_on_channel = mask_on_channel
         if student_channels != teacher_channels:
             self.align = nn.Conv2d(
                 student_channels,
@@ -55,7 +57,11 @@ def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
         N, C, H, W = feature.shape

         device = feature.device
-        mat = torch.rand((N, 1, H, W)).to(device)
+        if not self.mask_on_channel:
+            mat = torch.rand((N, 1, H, W)).to(device)
+        else:
+            mat = torch.rand((N, C, 1, 1)).to(device)
+
         mat = torch.where(mat > 1 - self.lambda_mgd,
                           torch.zeros(1).to(device),
                           torch.ones(1).to(device)).to(device)
diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py
index 1f7137022..4808ea216 100644
--- a/tests/test_models/test_architectures/test_connectors/test_connectors.py
+++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py
@@ -143,3 +143,19 @@ def test_mgd_connector(self):

         assert s_output1.shape == torch.Size([1, 16, 8, 8])
         assert s_output2.shape == torch.Size([1, 32, 8, 8])
+
+        mgd_connector1 = MGDConnector(
+            student_channels=16,
+            teacher_channels=16,
+            lambda_mgd=0.65,
+            mask_on_channel=True)
+        mgd_connector2 = MGDConnector(
+            student_channels=16,
+            teacher_channels=32,
+            lambda_mgd=0.65,
+            mask_on_channel=True)
+        s_output1 = mgd_connector1.forward_train(s_feat)
+        s_output2 = mgd_connector2.forward_train(s_feat)
+
+        assert s_output1.shape == torch.Size([1, 16, 8, 8])
+        assert s_output2.shape == torch.Size([1, 32, 8, 8])

From 73bbdb01e22089af64991beb97b32a2abcd5f5f4 Mon Sep 17 00:00:00 2001
From: pppppM
Date: Wed, 1 Mar 2023 17:42:09 +0800
Subject: [PATCH 2/2] fix lint

---
 .../test_architectures/test_connectors/test_connectors.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py
index c42b8f026..80b3f88b2 100644
--- a/tests/test_models/test_architectures/test_connectors/test_connectors.py
+++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py
@@ -167,4 +167,3 @@ def test_norm_connector(self):
         output = norm_connector.forward_train(s_feat)

         assert output.shape == torch.Size([2, 3, 2, 2])
-
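
For reference, the snippet below is a minimal standalone sketch (not part of the patch) of the masking logic this change adds to MGDConnector.forward_train: with mask_on_channel=False the random binary mask is spatial, shaped (N, 1, H, W); with mask_on_channel=True whole channels are dropped via an (N, C, 1, 1) mask. The helper name make_mgd_mask is illustrative only and does not exist in mmrazor.

# Standalone illustration of the mask construction introduced by this patch.
# Not the mmrazor implementation; `make_mgd_mask` is a hypothetical helper.
import torch


def make_mgd_mask(feature: torch.Tensor,
                  lambda_mgd: float = 0.65,
                  mask_on_channel: bool = False) -> torch.Tensor:
    """Return a random binary mask that zeroes roughly lambda_mgd of the feature.

    mask_on_channel=False -> spatial mask of shape (N, 1, H, W)
    mask_on_channel=True  -> channel mask of shape (N, C, 1, 1)
    """
    N, C, H, W = feature.shape
    device = feature.device

    if not mask_on_channel:
        mat = torch.rand((N, 1, H, W), device=device)
    else:
        mat = torch.rand((N, C, 1, 1), device=device)

    # Entries above the 1 - lambda_mgd threshold are masked out (set to 0).
    return torch.where(mat > 1 - lambda_mgd,
                       torch.zeros(1, device=device),
                       torch.ones(1, device=device))


if __name__ == '__main__':
    feat = torch.rand(1, 16, 8, 8)
    spatial_mask = make_mgd_mask(feat, mask_on_channel=False)
    channel_mask = make_mgd_mask(feat, mask_on_channel=True)
    assert spatial_mask.shape == (1, 1, 8, 8)
    assert channel_mask.shape == (1, 16, 1, 1)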