diff --git a/mmrazor/models/architectures/connectors/mgd_connector.py b/mmrazor/models/architectures/connectors/mgd_connector.py
index 41b77df4c..9b53fed1d 100644
--- a/mmrazor/models/architectures/connectors/mgd_connector.py
+++ b/mmrazor/models/architectures/connectors/mgd_connector.py
@@ -27,10 +27,12 @@ def __init__(
         student_channels: int,
         teacher_channels: int,
         lambda_mgd: float = 0.65,
+        mask_on_channel: bool = False,
         init_cfg: Optional[Dict] = None,
     ) -> None:
         super().__init__(init_cfg)
         self.lambda_mgd = lambda_mgd
+        self.mask_on_channel = mask_on_channel
         if student_channels != teacher_channels:
             self.align = nn.Conv2d(
                 student_channels,
@@ -55,7 +57,11 @@ def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
         N, C, H, W = feature.shape
 
         device = feature.device
-        mat = torch.rand((N, 1, H, W)).to(device)
+        if not self.mask_on_channel:
+            mat = torch.rand((N, 1, H, W)).to(device)
+        else:
+            mat = torch.rand((N, C, 1, 1)).to(device)
+
         mat = torch.where(mat > 1 - self.lambda_mgd,
                           torch.zeros(1).to(device),
                           torch.ones(1).to(device)).to(device)
diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py
index a2a2dcadc..80b3f88b2 100644
--- a/tests/test_models/test_architectures/test_connectors/test_connectors.py
+++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py
@@ -144,6 +144,22 @@ def test_mgd_connector(self):
         assert s_output1.shape == torch.Size([1, 16, 8, 8])
         assert s_output2.shape == torch.Size([1, 32, 8, 8])
 
+        mgd_connector1 = MGDConnector(
+            student_channels=16,
+            teacher_channels=16,
+            lambda_mgd=0.65,
+            mask_on_channel=True)
+        mgd_connector2 = MGDConnector(
+            student_channels=16,
+            teacher_channels=32,
+            lambda_mgd=0.65,
+            mask_on_channel=True)
+        s_output1 = mgd_connector1.forward_train(s_feat)
+        s_output2 = mgd_connector2.forward_train(s_feat)
+
+        assert s_output1.shape == torch.Size([1, 16, 8, 8])
+        assert s_output2.shape == torch.Size([1, 32, 8, 8])
+
     def test_norm_connector(self):
         s_feat = torch.randn(2, 3, 2, 2)
         norm_cfg = dict(type='BN', affine=False, track_running_stats=False)