diff --git a/mmflow/models/utils/__init__.py b/mmflow/models/utils/__init__.py index 95c1596..3a4882d 100644 --- a/mmflow/models/utils/__init__.py +++ b/mmflow/models/utils/__init__.py @@ -3,10 +3,11 @@ from .correlation_block import CorrBlock from .densenet import BasicDenseBlock, DenseLayer from .estimators_link import BasicLink, LinkOutput +from .occlusion_estimation import occlusion_estimation from .res_layer import BasicBlock, Bottleneck, ResLayer __all__ = [ 'ResLayer', 'BasicBlock', 'Bottleneck', 'BasicLink', 'LinkOutput', 'DenseLayer', 'BasicDenseBlock', 'BasicEncoder', 'BasicConvBlock', - 'CorrBlock' + 'CorrBlock', 'occlusion_estimation' ] diff --git a/mmflow/models/utils/occlusion_estimation.py b/mmflow/models/utils/occlusion_estimation.py new file mode 100644 index 0000000..2d83203 --- /dev/null +++ b/mmflow/models/utils/occlusion_estimation.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +import torch +from torch import Tensor + +from mmflow.ops import build_operators + + +def flow_to_coords(flow: Tensor) -> Tensor: + """Generate shifted coordinate grid based based input flow. + + Args: + flow (Tensor): Estimated optical flow. + + Returns: + Tensor: Coordinate that shifted by input flow with shape (B, 2, H, W). + """ + B, _, H, W = flow.shape + xx = torch.arange(0, W, device=flow.device, requires_grad=False) + yy = torch.arange(0, H, device=flow.device, requires_grad=False) + coords = torch.meshgrid(yy, xx) + coords = torch.stack(coords[::-1], dim=0).float() + coords = coords[None].repeat(B, 1, 1, 1) + flow + return coords + + +def compute_range_map(flow: Tensor) -> Tensor: + """Compute range map. + + Args: + flow (Tensor): The backward flow with shape (N, 2, H, W) + + Return: + Tensor: The forward-to-backward occlusion mask with shape (N, 1, H, W) + """ + + N, _, H, W = flow.shape + + coords = flow_to_coords(flow) + + # Split coordinates into an integer part and + # a float offset for interpolation. + coords_floor = torch.floor(coords) + coords_offset = coords - coords_floor + coords_floor = coords_floor.to(torch.int32) + + # Define a batch offset for flattened indexes into all pixels. + batch_range = torch.arange(N).view(N, 1, 1) + idx_batch_offset = batch_range.repeat(1, H, W) * H * W + + # Flatten everything. + coords_floor_flattened = coords_floor.permute(0, 2, 3, 1).reshape(-1, 2) + coords_offset_flattened = coords_offset.permute(0, 2, 3, 1).reshape(-1, 2) + idx_batch_offset_flattened = idx_batch_offset.reshape(-1) + + # Initialize results. + idxs_list = [] + weights_list = [] + + # Loop over differences di and dj to the four neighboring pixels. + for di in range(2): + for dj in range(2): + # Compute the neighboring pixel coordinates. + idxs_j = coords_floor_flattened[..., 0] + dj + idxs_i = coords_floor_flattened[..., 1] + di + # Compute the flat index into all pixels. + idxs = idx_batch_offset_flattened + idxs_i * W + idxs_j + + # Only count valid pixels. + mask = torch.logical_and( + torch.logical_and(idxs_j >= 0, idxs_j < W), + torch.logical_and(idxs_i >= 0, idxs_i < H)) + valid_idxs = idxs[mask] + valid_offsets = coords_offset_flattened[mask] + + # Compute weights according to bilinear interpolation. + weights_j = (1. - dj) - (-1)**dj * valid_offsets[:, 0] + weights_i = (1. - di) - (-1)**di * valid_offsets[:, 1] + weights = weights_i * weights_j + + # Append indices and weights to the corresponding list. + idxs_list.append(valid_idxs) + weights_list.append(weights) + # Concatenate everything. + idxs = torch.cat(idxs_list, dim=0) + weights = torch.cat(weights_list, dim=0) + + # Sum up weights for each pixel and reshape the result. + count_image = torch.zeros(N * H * W) + count_image = count_image.index_add_( + dim=0, index=idxs, source=weights).reshape(N, H, W) + occ = (count_image >= 1).to(flow)[:, None, ...] + return occ + + +def forward_backward_consistency( + flow_fw: Tensor, + flow_bw: Tensor, + warp_cfg: dict = dict(type='Warp', align_corners=True), +) -> Tensor: + """Occlusion mask from forward-backward consistency. + + Args: + flow_fw (Tensor): The forward flow with shape (N, 2, H, W) + flow_bw (Tensor): The backward flow with shape (N, 2, H, W) + + Returns: + Tensor: The forward-to-backward occlusion mask with shape (N, 1, H, W) + """ + + warp = build_operators(warp_cfg) + + warped_flow_bw = warp(flow_bw, flow_fw) + + forward_backward_sq_diff = torch.sum( + (flow_fw + warped_flow_bw)**2, dim=1, keepdim=True) + forward_backward_sum_sq = torch.sum( + flow_fw * 2 + warped_flow_bw**2, dim=1, keepdim=True) + + occ = (forward_backward_sq_diff < + forward_backward_sum_sq * 0.01 + 0.5).to(flow_fw) + return occ + + +def forward_backward_absdiff(flow_fw: Tensor, + flow_bw: Tensor, + warp_cfg: dict = dict( + type='Warp', align_corners=True), + diff: int = 1.5) -> Tensor: + """Occlusion mask from forward-backward consistency. + + Args: + flow_fw (Tensor): The forward flow with shape (N, 2, H, W) + flow_bw (Tensor): The backward flow with shape (N, 2, H, W) + + Returns: + Tensor: The forward-to-backward occlusion mask with shape (N, 1, H, W) + """ + + warp = build_operators(warp_cfg) + + warped_flow_bw = warp(flow_bw, flow_fw) + + forward_backward_sq_diff = torch.sum( + (flow_fw + warped_flow_bw)**2, dim=1, keepdim=True) + + occ = (forward_backward_sq_diff**0.5 < diff).to(flow_fw) + + return occ + + +def occlusion_estimation(flow_fw: Tensor, + flow_bw: Tensor, + mode: str = 'consistency', + **kwarg) -> Dict[str, Tensor]: + """Occlusion estimation. + + Args: + flow_fw (Tensor): The forward flow with shape (N, 2, H, W) + flow_bw (Tensor): The backward flow with shape (N, 2, H, W) + mode (str): The method for occlusion estimation, which can be + ``'consistency'``, ``'range_map'`` or ``'fb_abs'``. + warp_cfg (dict, optional): _description_. Defaults to None. + + Returns: + Dict[str,Tensor]: 1 denote non-occluded and 0 denote occluded + """ + assert mode in ('consistency', 'range_map', 'fb_abs'), \ + 'mode must be \'consistency\', \'range_map\' or \'fb_abs\', ' \ + f'but got {mode}' + + if mode == 'consistency': + occ_fw = forward_backward_consistency(flow_fw, flow_bw, **kwarg) + occ_bw = forward_backward_consistency(flow_bw, flow_fw, **kwarg) + + elif mode == 'range_map': + occ_fw = compute_range_map(flow_bw) + occ_bw = compute_range_map(flow_fw) + + elif mode == 'fb_abs': + occ_fw = forward_backward_absdiff(flow_fw, flow_bw, **kwarg) + occ_bw = forward_backward_absdiff(flow_bw, flow_fw, **kwarg) + + return dict(occ_fw=occ_fw, occ_bw=occ_bw) diff --git a/tests/test_models/test_utils/test_occlusion_esimation.py b/tests/test_models/test_utils/test_occlusion_esimation.py new file mode 100644 index 0000000..d8ecd23 --- /dev/null +++ b/tests/test_models/test_utils/test_occlusion_esimation.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmflow.models.utils import occlusion_estimation + + +def test_occlusion_estimation(): + """Test occ estimation.""" + """ + img1 img2 + + | A | B | E | | G | A | B | + ------------- ------------- + | C | D | F | | H | C | D | + + flow_fw flow_bw + |(1, 0)|(1, 0)|(1, 0)| |(-1, 0)|(-1, 0)|(-1, 0)| + ---------------------- ------------------------- + |(1, 0)|(1, 0)|(1, 0)| |(-1, 0)|(-1, 0)|(-1, 0)| + + occ_fw occ_bw + | 1 | 1 | 0 | | 0 | 1 | 1 | + ------------- ------------- + | 1 | 1 | 0 | | 0 | 1 | 1 | + """ + H = 2 + W = 3 + flow_fw = torch.zeros(4, 2, H, W) + flow_fw[:, 0, ...] = 1 + flow_bw = -flow_fw.clone() + + occ_fw = torch.ones(4, 1, H, W) + occ_fw[..., -1] = 0. + occ_bw = torch.ones(4, 1, H, W) + occ_bw[..., 0] = 0. + + # test invalid mode + with pytest.raises(AssertionError): + occlusion_estimation(flow_fw, flow_bw, mode='a') + + # test forward-backward consistency + occ = occlusion_estimation( + flow_fw, + flow_bw, + mode='consistency', + warp_cfg=dict(type='Warp', align_corners=True)) + assert torch.all(occ['occ_fw'] == occ_fw) + assert torch.all(occ['occ_bw'] == occ_bw) + + # test fb_abs + occ = occlusion_estimation( + flow_fw, + flow_bw, + mode='fb_abs', + warp_cfg=dict(type='Warp', align_corners=True), + diff=1.) + assert torch.all(occ['occ_fw'] == occ_fw) + assert torch.all(occ['occ_bw'] == occ_bw) + + # test range map + occ = occlusion_estimation(flow_fw, flow_bw, mode='range_map') + assert torch.all(occ['occ_fw'] == occ_fw) + assert torch.all(occ['occ_bw'] == occ_bw)