Skip to content

Commit

Permalink
Fix semantic seg ground truth mask extraction (#3270)
Browse files Browse the repository at this point in the history
* Fix semantic seg gt mask extraction

 - ignore_index is now correctly working

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Increase test coverage

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

---------

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim authored Apr 8, 2024
1 parent 57a80b8 commit 61196fa
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/otx/core/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class DataModuleConfig:
stack_images: bool = True

include_polygons: bool = False
ignore_index: int = 255
unannotated_items_ratio: float = 0.0

auto_num_workers: bool = False
Expand Down
121 changes: 112 additions & 9 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from functools import partial
from typing import TYPE_CHECKING, Callable

import cv2
import numpy as np
import torch
from datumaro.components.annotation import Image, Mask
from torchvision import tv_tensors

Expand All @@ -23,7 +23,113 @@
from .base import OTXDataset

if TYPE_CHECKING:
from datumaro import DatasetSubset
from datumaro import DatasetItem, DatasetSubset


# NOTE: It is copied from https://github.com/openvinotoolkit/datumaro/pull/1409
# It will be replaced in the future.
def _make_index_mask(
binary_mask: np.ndarray,
index: int,
ignore_index: int = 0,
dtype: np.dtype | None = None,
) -> np.ndarray:
"""Create an index mask from a binary mask by filling a given index value.
Args:
binary_mask: Binary mask to create an index mask.
index: Scalar value to fill the ones in the binary mask.
ignore_index: Scalar value to fill in the zeros in the binary mask.
Defaults to 0.
dtype: Data type for the resulting mask. If not specified,
it will be inferred from the provided index. Defaults to None.
Returns:
np.ndarray: Index mask created from the binary mask.
Raises:
ValueError: If dtype is not specified and incompatible scalar types are used for index
and ignore_index.
Examples:
>>> binary_mask = np.eye(2, dtype=np.bool_)
>>> index_mask = make_index_mask(binary_mask, index=10, ignore_index=255, dtype=np.uint8)
>>> print(index_mask)
array([[ 10, 255],
[255, 10]], dtype=uint8)
"""
if dtype is None:
dtype = np.min_scalar_type(index)
if dtype != np.min_scalar_type(ignore_index):
raise ValueError

flipped_zero_np_scalar = ~np.full((), fill_value=0, dtype=dtype)

# NOTE: This dispatching rule is required for a performance boost
if ignore_index == flipped_zero_np_scalar:
flipped_index = ~np.full((), fill_value=index, dtype=dtype)
return ~(binary_mask * flipped_index)

mask = binary_mask * np.full((), fill_value=index, dtype=dtype)

if ignore_index == 0:
return mask

return np.where(binary_mask, mask, ignore_index)


def _extract_class_mask(item: DatasetItem, img_shape: tuple[int, int], ignore_index: int) -> np.ndarray:
"""Extract class mask from Datumaro masks.
This is a temporary workaround and will be replaced with the native Datumaro interfaces
after some works, e.g., https://github.com/openvinotoolkit/datumaro/pull/1409 are done.
Args:
item: Datumaro dataset item having mask annotations.
img_shape: Image shape (H, W).
ignore_index: Scalar value to fill in the zeros in the binary mask.
Returns:
2D numpy array
"""
if ignore_index > 255:
msg = "It is not currently support an ignore index which is more than 255."
raise ValueError(msg, ignore_index)

class_mask = np.full(shape=img_shape[:2], fill_value=ignore_index, dtype=np.uint8)

for mask in sorted(
[ann for ann in item.annotations if isinstance(ann, Mask)],
key=lambda ann: ann.z_order,
):
binary_mask = mask.image
index = mask.label

if index is None:
msg = "Mask's label index should not be None."
raise ValueError(msg)

if index > 255:
msg = "Mask's label index should not be more than 255."
raise ValueError(msg, index)

this_class_mask = _make_index_mask(
binary_mask=binary_mask,
index=index,
ignore_index=ignore_index,
dtype=np.uint8,
)

if this_class_mask.shape != img_shape:
this_class_mask = cv2.resize(
this_class_mask,
dsize=(img_shape[1], img_shape[0]), # NOTE: cv2.resize() uses (width, height) format
interpolation=cv2.INTER_NEAREST,
)

class_mask = np.where(this_class_mask != ignore_index, this_class_mask, class_mask)

return class_mask


class OTXSegmentationDataset(OTXDataset[SegDataEntity]):
Expand All @@ -38,6 +144,7 @@ def __init__(
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
ignore_index: int = 255,
) -> None:
super().__init__(
dm_subset,
Expand All @@ -52,6 +159,7 @@ def __init__(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
)
self.ignore_index = ignore_index

def _get_item_impl(self, index: int) -> SegDataEntity | None:
item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
Expand All @@ -60,13 +168,8 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None:
ignored_labels: list[int] = []
img_data, img_shape = self._get_img_data_and_shape(img)

# create 2D class mask. We use np.sum() since Datumaro returns 3D masks (one for each class)
mask_anns = np.sum(
[ann.as_class_mask() for ann in item.annotations if isinstance(ann, Mask)],
axis=0,
dtype=np.uint8,
)
mask = torch.as_tensor(mask_anns, dtype=torch.long)
mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)

# assign possible ignored labels from dataset to max label class + 1.
# it is needed to compute mDice metric.
mask[mask == 255] = num_classes
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/data/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def create( # noqa: PLR0911 # ignore too many return statements
if task == OTXTaskType.SEMANTIC_SEGMENTATION:
from .dataset.segmentation import OTXSegmentationDataset

return OTXSegmentationDataset(**common_kwargs)
return OTXSegmentationDataset(**common_kwargs, ignore_index=cfg_data_module.ignore_index)

if task == OTXTaskType.ACTION_CLASSIFICATION:
from .dataset.action_classification import OTXActionClsDataset
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/core/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def fxt_dm_item(request) -> DatasetItem:
annotations=[
Label(label=0),
Bbox(x=0, y=0, w=1, h=1, label=0),
Mask(label=0, image=np.zeros(shape=(10, 10), dtype=np.uint8)),
Mask(label=0, image=np.eye(10, dtype=np.uint8)),
Polygon(points=[399.0, 570.0, 397.0, 572.0, 397.0, 573.0, 394.0, 576.0], label=0),
],
)
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/core/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from unittest.mock import MagicMock

import pytest
from datumaro.components.annotation import Mask
from otx.core.data.dataset.action_classification import OTXActionClsDataset
from otx.core.data.dataset.classification import HLabelInfo
from otx.core.data.dataset.segmentation import OTXSegmentationDataset


class TestDataset:
Expand Down Expand Up @@ -87,3 +89,55 @@ def test_mem_cache_resize(

assert item.image.shape[:2] == (h_expected, w_expected)
assert item.img_info.img_shape == (h_expected, w_expected)


class TestOTXSegmentationDataset:
def test_ignore_index(self, fxt_mock_dm_subset):
dataset = OTXSegmentationDataset(
dm_subset=fxt_mock_dm_subset,
transforms=lambda x: x,
mem_cache_img_max_size=None,
ignore_index=100,
)

# The mask is np.eye(10) with label_id = 0,
# so that the diagonal is filled with zero
# and others are filled with ignore_index.
gt_seg_map = next(iter(dataset)).gt_seg_map
assert gt_seg_map.sum() == (10 * 10 - 10) * 100

def test_overflown_ignore_index(self, fxt_mock_dm_subset):
dataset = OTXSegmentationDataset(
dm_subset=fxt_mock_dm_subset,
transforms=lambda x: x,
mem_cache_img_max_size=None,
ignore_index=65536,
)
with pytest.raises(
ValueError,
match="It is not currently support an ignore index which is more than 255.",
):
_ = next(iter(dataset))

@pytest.fixture(params=["none", "overflow"])
def fxt_invalid_label(self, fxt_dm_item, monkeypatch, request):
for ann in fxt_dm_item.annotations:
if isinstance(ann, Mask):
if request.param == "none":
monkeypatch.setattr(ann, "label", None)
elif request.param == "overflow":
monkeypatch.setattr(ann, "label", 65536)

def test_overflown_label(self, fxt_invalid_label, fxt_mock_dm_subset):
dataset = OTXSegmentationDataset(
dm_subset=fxt_mock_dm_subset,
transforms=lambda x: x,
mem_cache_img_max_size=None,
ignore_index=100,
)

with pytest.raises(
ValueError,
match="Mask's label index should not be (.*).",
):
_ = next(iter(dataset))

0 comments on commit 61196fa

Please # to comment.