diff --git a/mmflow/models/losses/multilevel_charbonnier_loss.py b/mmflow/models/losses/multilevel_charbonnier_loss.py
index 941d1b4a..5d386422 100644
--- a/mmflow/models/losses/multilevel_charbonnier_loss.py
+++ b/mmflow/models/losses/multilevel_charbonnier_loss.py
@@ -62,6 +62,17 @@ class MultiLevelCharbonnierLoss(nn.Module):
         max_flow (float): maximum value of optical flow, if some pixel's
             flow of target is larger than it, this pixel is not valid.
             Default to float('inf').
+        resize_flow (str): flow resizing mode, 'downsample' or 'upsample',
+            as the multi-level predicted outputs don't match the ground
+            truth resolution. If set to 'downsample', the ground truth is
+            downsampled; if set to 'upsample', the predicted flow is
+            upsampled. 'upsample' should be used for sparse flow maps, as
+            no generic interpolation mode can resize a sparse ground truth
+            correctly. Default to 'downsample'.
+        scale_as_level (bool): Whether the flow at each level is kept at its
+            native spatial resolution. If True, the ground truth is scaled to
+            match each level's resolution; if False, the ground truth is not
+            scaled. Default to False.
         reduction (str): the reduction to apply to the output:'none', 'mean',
             'sum'. 'none': no reduction will be applied and will return a
             full-size epe map, 'mean': the mean of the epe map is taken, 'sum':
@@ -81,6 +92,7 @@ def __init__(self,
                      level2=0.005),
                  max_flow: float = float('inf'),
                  resize_flow: str = 'downsample',
+                 scale_as_level: bool = False,
                  reduction: str = 'sum') -> None:
 
         super().__init__()
@@ -102,6 +114,9 @@ def __init__(self,
         assert resize_flow in ('downsample', 'upsample')
         self.resize_flow = resize_flow
 
+        assert isinstance(scale_as_level, bool)
+        self.scale_as_level = scale_as_level
+
         assert reduction in ('mean', 'sum')
         self.reduction = reduction
 
@@ -133,6 +148,7 @@ def forward(self,
             flow_div=self.flow_div,
             max_flow=self.max_flow,
             resize_flow=self.resize_flow,
+            scale_as_level=self.scale_as_level,
             reduction=self.reduction,
             q=self.q,
             eps=self.eps,
@@ -141,7 +157,9 @@ def forward(self,
 
     def __repr__(self) -> str:
         repr_str = self.__class__.__name__
-        repr_str += (f'(flow_div={self.flow_div}, '
+        repr_str += (f'(resize_flow={self.resize_flow}, '
+                     f'scale_as_level={self.scale_as_level}, '
+                     f'flow_div={self.flow_div}, '
                      f'weights={self.weights}, '
                      f'q={self.q}, '
                      f'eps={self.eps}, '
diff --git a/mmflow/models/losses/multilevel_epe.py b/mmflow/models/losses/multilevel_epe.py
index 90ce7d10..c3410c9e 100644
--- a/mmflow/models/losses/multilevel_epe.py
+++ b/mmflow/models/losses/multilevel_epe.py
@@ -79,6 +79,10 @@ class MultiLevelEPE(nn.Module):
             'upsample' is used for sparse flow map as no generic interpolation
             mode can resize a ground truth of sparse flow correctly.
             Default to 'downsample'.
+        scale_as_level (bool): Whether the flow at each level is kept at its
+            native spatial resolution. If True, the ground truth is scaled to
+            match each level's resolution; if False, the ground truth is not
+            scaled. Default to False.
         reduction (str): the reduction to apply to the output:'none', 'mean',
             'sum'. 'none': no reduction will be applied and will return a
             full-size epe map, 'mean': the mean of the epe map is taken, 'sum':
@@ -99,6 +103,7 @@ def __init__(self,
                  flow_div: float = 20.,
                  max_flow: float = float('inf'),
                  resize_flow: str = 'downsample',
+                 scale_as_level: bool = False,
                  reduction: str = 'sum') -> None:
 
         super().__init__()
@@ -126,6 +131,9 @@ def __init__(self,
         assert resize_flow in ('downsample', 'upsample')
         self.resize_flow = resize_flow
 
+        assert isinstance(scale_as_level, bool)
+        self.scale_as_level = scale_as_level
+
         assert reduction in ('mean', 'sum')
         self.reduction = reduction
 
@@ -157,6 +165,7 @@ def forward(self,
             flow_div=self.flow_div,
             max_flow=self.max_flow,
             resize_flow=self.resize_flow,
+            scale_as_level=self.scale_as_level,
             reduction=self.reduction,
             p=self.p,
             q=self.q,
@@ -166,7 +175,9 @@ def forward(self,
 
     def __repr__(self) -> str:
        repr_str = self.__class__.__name__
-        repr_str += (f'(flow_div={self.flow_div}, '
+        repr_str += (f'(resize_flow={self.resize_flow}, '
+                     f'scale_as_level={self.scale_as_level}, '
+                     f'flow_div={self.flow_div}, '
                      f'weights={self.weights}, '
                      f'p={self.p}, '
                      f'q={self.q}, '
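A minimal usage sketch of the two arguments added above, assuming the dict-of-levels prediction format these losses already consume; the level names, shapes, and weight values are illustrative, not taken from this diff:

import torch

from mmflow.models.losses import MultiLevelEPE

# resize_flow='downsample' shrinks the ground truth to each level's size;
# scale_as_level=True additionally rescales the flow vectors themselves by
# (w_level / w_org, h_level / h_org).
loss_fn = MultiLevelEPE(
    weights=dict(level2=0.005, level3=0.01),
    resize_flow='downsample',
    scale_as_level=True)

preds = {
    'level2': torch.randn(2, 2, 12, 16),  # hypothetical pyramid outputs
    'level3': torch.randn(2, 2, 6, 8),
}
gt = torch.randn(2, 2, 24, 32)  # ground truth at the original resolution
print(loss_fn(preds, gt))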
diff --git a/mmflow/models/losses/multilevel_flow_loss.py b/mmflow/models/losses/multilevel_flow_loss.py
index d8d13701..d7202297 100644
--- a/mmflow/models/losses/multilevel_flow_loss.py
+++ b/mmflow/models/losses/multilevel_flow_loss.py
@@ -20,6 +20,7 @@ def multi_level_flow_loss(loss_function,
                           max_flow: float = float('inf'),
                           resize_flow: str = 'downsample',
                           reduction: str = 'sum',
+                          scale_as_level: bool = False,
                           **kwargs) -> torch.Tensor:
     """Multi-level endpoint error loss function.
 
@@ -45,6 +46,17 @@ def multi_level_flow_loss(loss_function,
             full-size epe map, 'mean': the mean of the epe map is taken, 'sum':
             the epe map will be summed but averaged by batch_size.
             Default: 'sum'.
+        resize_flow (str): flow resizing mode, 'downsample' or 'upsample',
+            as the multi-level predicted outputs don't match the ground
+            truth resolution. If set to 'downsample', the ground truth is
+            downsampled; if set to 'upsample', the predicted flow is
+            upsampled. 'upsample' should be used for sparse flow maps, as
+            no generic interpolation mode can resize a sparse ground truth
+            correctly. Default to 'downsample'.
+        scale_as_level (bool): Whether the flow at each level is kept at its
+            native spatial resolution. If True, the ground truth is scaled to
+            match each level's resolution; if False, the ground truth is not
+            scaled. Default to False.
         kwargs: arguments for loss_function.
 
     Returns:
@@ -65,6 +77,9 @@ def multi_level_flow_loss(loss_function,
 
     target_div = target / flow_div
 
+    c_org, h_org, w_org = target.shape[1:]
+    assert c_org == 2, f'channels of ground truth must be 2, but got {c_org}'
+
     loss = 0
 
     for level in weights.keys():
@@ -77,6 +92,10 @@ def multi_level_flow_loss(loss_function,
 
         b, _, h, w = cur_pred[0].shape
 
+        scale_factor = torch.Tensor([
+            float(w / w_org), float(h / h_org)
+        ]).to(target) if scale_as_level else torch.Tensor([1., 1.]).to(target)
+
         cur_weight = weights.get(level)
 
         if resize_flow == 'downsample':
@@ -104,6 +123,9 @@ def multi_level_flow_loss(loss_function,
                     mode='bilinear',
                     align_corners=False)
 
+            cur_target = torch.einsum('b c h w, c -> b c h w', cur_target,
+                                      scale_factor)
+
             loss_map += loss_function(i_pred, cur_target, **kwargs) * cur_valid
 
     if reduction == 'mean':
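The einsum introduced above is a per-channel multiply of the (u, v) flow components; a self-contained sketch of the scale_as_level semantics, with all names local to this example:

import torch

# Ground truth flow at the original resolution: (B, 2, H, W); channel 0 is
# the horizontal (u) and channel 1 the vertical (v) displacement in pixels.
target = torch.randn(2, 2, 24, 32)
h_org, w_org = target.shape[2:]

# A pyramid level whose prediction is 4x smaller in each dimension.
h, w = 6, 8

# A displacement of d pixels at the original resolution corresponds to
# d * (w / w_org) pixels horizontally (and analogously vertically) at this
# level, so the ground truth vectors are shrunk accordingly.
scale_factor = torch.tensor([w / w_org, h / h_org])

scaled = torch.einsum('b c h w, c -> b c h w', target, scale_factor)
# The einsum is equivalent to broadcasting over the channel dimension:
assert torch.allclose(scaled, target * scale_factor[None, :, None, None])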
diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py
index 2e5f0700..27589e6d 100644
--- a/tests/test_models/test_losses.py
+++ b/tests/test_models/test_losses.py
@@ -20,7 +20,7 @@ def test_multi_level_endpoint_error():
     gt = torch.randn(b, 2, h, w)
     weights = dict(level1=1.)
 
-    # test pred does not match gt
+    # test gt with a channel number other than 2
     with pytest.raises(AssertionError):
         multi_level_flow_loss(
             endpoint_error, pred, torch.randn(b, 1, 1, 1), weights=weights)
@@ -99,7 +99,7 @@ def test_multi_level_charbonnier_loss():
     gt = torch.randn(b, 2, h, w)
     weights = dict(level1=1.)
 
-    # test pred does not match gt
+    # test gt with a channel number other than 2
     with pytest.raises(AssertionError):
         multi_level_flow_loss(
             charbonnier_loss, pred, torch.randn(b, 1, 1, 1), weights=weights)
@@ -124,10 +124,10 @@ def test_multi_level_charbonnier_loss():
     assert torch.allclose(loss_gt, loss)
 
 
-@pytest.mark.parametrize(['reduction', 'resize_flow'],
-                         [['mean', 'upsample'], ['sum', 'upsample'],
-                          ['mean', 'downsample'], ['sum', 'downsample']])
-def test_multilevel_epe(reduction, resize_flow):
+@pytest.mark.parametrize('reduction', ('mean', 'sum'))
+@pytest.mark.parametrize('resize_flow', ['upsample', 'downsample'])
+@pytest.mark.parametrize('scale_as_level', [True, False])
+def test_multilevel_epe(reduction, resize_flow, scale_as_level):
 
     b = 8
 
@@ -157,24 +157,29 @@ def test_multilevel_epe(reduction, resize_flow):
     with pytest.raises(AssertionError):
         MultiLevelEPE(resize_flow='z')
 
+    # test invalid scale_as_level
+    with pytest.raises(AssertionError):
+        MultiLevelEPE(scale_as_level='a')
+
     def answer():
         loss = 0
         weights = [0.005, 0.01]
         scales = [2, 4]
-
+        scale_factor = [1 / 2, 1 / 4] if scale_as_level else [1., 1.]
         div_gt = gt / 20.
 
         for i in range(len(weights)):
             if resize_flow == 'downsample':
-                cur_gt = F.avg_pool2d(div_gt, scales[i])
+                cur_gt = F.avg_pool2d(div_gt, scales[i]) * scale_factor[i]
                 cur_pred = preds_list[i]
             else:
-                cur_gt = div_gt
+                cur_gt = div_gt * scale_factor[i]
                 cur_pred = F.interpolate(
                     preds_list[i],
                     size=(24, 32),
                     mode='bilinear',
                     align_corners=False)
+
             l2_loss = torch.norm(cur_pred - cur_gt, p=2, dim=1)
             if reduction == 'mean':
                 loss += l2_loss.mean() * weights[i]
@@ -187,7 +192,10 @@ def answer():
     answer_ = answer()
 
     # test accuracy of mean reduction
     loss_func = MultiLevelEPE(
-        weights=weights, reduction=reduction, resize_flow=resize_flow)
+        weights=weights,
+        reduction=reduction,
+        resize_flow=resize_flow,
+        scale_as_level=scale_as_level)
     loss = loss_func(preds, gt)
     assert torch.isclose(answer_, loss, atol=1e-4)
@@ -285,14 +293,16 @@ def test_sequence_loss():
 
 
 @pytest.mark.parametrize('reduction', ('mean', 'sum'))
-def test_multi_levels_charbonnier(reduction):
+@pytest.mark.parametrize('resize_flow', ['upsample', 'downsample'])
+@pytest.mark.parametrize('scale_as_level', [True, False])
+def test_multi_levels_charbonnier(reduction, resize_flow, scale_as_level):
 
     b = 2
 
-    flow2 = torch.randn(b, 2, 16, 16)
-    flow3 = torch.randn(b, 2, 8, 8)
+    flow2 = torch.randn(b, 2, 12, 16)
+    flow3 = torch.randn(b, 2, 6, 8)
 
-    gt = torch.randn(b, 2, 64, 64)
+    gt = torch.randn(b, 2, 24, 32)
 
     preds_list = [flow2, flow3]
     preds = {
@@ -307,17 +317,39 @@ def test_multi_levels_charbonnier(reduction):
 
     with pytest.raises(AssertionError):
         MultiLevelCharbonnierLoss(weights=[0.005, 0.01])
 
+    # test invalid reduction
+    with pytest.raises(AssertionError):
+        MultiLevelCharbonnierLoss(reduction=None)
+
+    # test invalid resize_flow
+    with pytest.raises(AssertionError):
+        MultiLevelCharbonnierLoss(resize_flow='z')
+
+    # test invalid scale_as_level
+    with pytest.raises(AssertionError):
+        MultiLevelCharbonnierLoss(scale_as_level='a')
+
     def answer():
         loss = 0
         weights = [0.005, 0.01]
-        scales = [4, 8]
-
+        scales = [2, 4]
+        scale_factor = [1 / 2, 1 / 4] if scale_as_level else [1., 1.]
         div_gt = gt / 20.
 
         for i in range(len(weights)):
-            cur_gt = F.avg_pool2d(div_gt, scales[i])
-            loss_square = torch.sum((preds_list[i] - cur_gt)**2, dim=1)
+            if resize_flow == 'downsample':
+                cur_gt = F.avg_pool2d(div_gt, scales[i]) * scale_factor[i]
+                cur_pred = preds_list[i]
+            else:
+                cur_gt = div_gt * scale_factor[i]
+                cur_pred = F.interpolate(
+                    preds_list[i],
+                    size=(24, 32),
+                    mode='bilinear',
+                    align_corners=False)
+
+            loss_square = torch.sum((cur_pred - cur_gt)**2, dim=1)
             if reduction == 'mean':
                 loss += ((loss_square + 0.01)**0.2).mean() * weights[i]
             else:
@@ -328,10 +360,14 @@ def answer():
     answer_ = answer()
 
     # test accuracy of mean reduction
-    loss_obj = MultiLevelCharbonnierLoss(weights=weights, reduction=reduction)
+    loss_obj = MultiLevelCharbonnierLoss(
+        weights=weights,
+        resize_flow=resize_flow,
+        scale_as_level=scale_as_level,
+        reduction=reduction)
     loss = loss_obj(preds, gt)
-    assert torch.isclose(answer_, loss, atol=1e-4)
+    assert torch.isclose(answer_, loss, rtol=1e-2)
 
     valid = torch.zeros_like(gt[:, 0, :, :])
     loss = loss_obj(preds, gt, valid)
-    assert torch.isclose(torch.Tensor([0.]), loss, atol=1e-4)
+    assert torch.isclose(torch.Tensor([0.]), loss, rtol=1e-2)
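The updated tests can be run in isolation with pytest's standard -k filter (command assumes the repository root as the working directory):

pytest tests/test_models/test_losses.py -k "multilevel_epe or multi_levels_charbonnier"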