Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Feature] Add mask channel in MGD Loss #461

Merged
merged 3 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mmrazor/models/architectures/connectors/mgd_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ def __init__(
student_channels: int,
teacher_channels: int,
lambda_mgd: float = 0.65,
mask_on_channel: bool = False,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(init_cfg)
self.lambda_mgd = lambda_mgd
self.mask_on_channel = mask_on_channel
if student_channels != teacher_channels:
self.align = nn.Conv2d(
student_channels,
Expand All @@ -55,7 +57,11 @@ def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
N, C, H, W = feature.shape

device = feature.device
mat = torch.rand((N, 1, H, W)).to(device)
if not self.mask_on_channel:
mat = torch.rand((N, 1, H, W)).to(device)
else:
mat = torch.rand((N, C, 1, 1)).to(device)

mat = torch.where(mat > 1 - self.lambda_mgd,
torch.zeros(1).to(device),
torch.ones(1).to(device)).to(device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,19 @@ def test_mgd_connector(self):

assert s_output1.shape == torch.Size([1, 16, 8, 8])
assert s_output2.shape == torch.Size([1, 32, 8, 8])

mgd_connector1 = MGDConnector(
student_channels=16,
teacher_channels=16,
lambda_mgd=0.65,
mask_on_channel=True)
mgd_connector2 = MGDConnector(
student_channels=16,
teacher_channels=32,
lambda_mgd=0.65,
mask_on_channel=True)
s_output1 = mgd_connector1.forward_train(s_feat)
s_output2 = mgd_connector2.forward_train(s_feat)

assert s_output1.shape == torch.Size([1, 16, 8, 8])
assert s_output2.shape == torch.Size([1, 32, 8, 8])