Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Adding mutually exclusive label check to load_labels_from_dataset_source #454

Merged
merged 3 commits into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Tests/logs
Tests/ML/models/logs
saved_models/
Tests/ML/test_data/outputs
Tests/ML/test_data/cxr_test_dataset
azureml-models
# Tests
junit
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ created.

### Added

- ([#454](https://github.com/microsoft/InnerEye-DeepLearning/pull/454)) Checking that labels are mutually exclusive.
- ([#447](https://github.com/microsoft/InnerEye-DeepLearning/pull/447/)) Added a sanity check to ensure there are no
missing channels, nor missing files. If missing channels in the csv file or filenames associated with channels are
incorrect, pipeline exits with error report before running training or inference.
Expand Down
6 changes: 6 additions & 0 deletions InnerEye/ML/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,12 @@ class SegmentationModelBase(ModelConfigBase):
"patch sampling will be shown. Nifti images and thumbnails for each"
"of the first N subjects in the training set will be "
"written to the outputs folder.")
#: If true an error is raised in InnerEye.ML.utils.io_util.load_labels_from_dataset_source if the labels are not
#: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others (e.g.
#: FocalLoss) will fail with a cryptic error message. Set to false if you are sure that you want to use labels that
#: are not mutually exclusive.

check_exclusive: bool = param.Boolean(True, doc="Raise an error if the segmentation labels are not mutually exclusive.")
dumbledad marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, center_size: Optional[TupleInt3] = None,
inference_stride_size: Optional[TupleInt3] = None,
Expand Down
3 changes: 3 additions & 0 deletions InnerEye/ML/configs/segmentation/HelloWorld.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def __init__(self, **kwargs: Any) -> None:
# we define the structure names and colours to use.
ground_truth_ids_display_names=fg_classes,
colours=generate_random_colours_list(Random(5), len(fg_classes)),

# The HelloWorld model uses dummy data with overlapping segmentation labels
check_exclusive=False
)
self.add_and_validate(kwargs)

Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/dataset/full_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _extension_from_df_file_paths(file_paths: List[str]) -> str:
def get_samples_at_index(self, index: int) -> List[Sample]:
# load the channels into memory
ds = self.dataset_sources[self.dataset_indices[index]]
samples = [io_util.load_images_from_dataset_source(dataset_source=ds)] # type: ignore
samples = [io_util.load_images_from_dataset_source(dataset_source=ds, check_exclusive=self.args.check_exclusive)] # type: ignore
return [Compose3D.apply(self.full_image_sample_transforms, x) for x in samples]

def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
Expand Down
19 changes: 14 additions & 5 deletions InnerEye/ML/utils/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,22 +412,30 @@ def load_image_in_known_formats(file: Path,
raise ValueError(f"Unsupported image file type for path {file}")


def load_labels_from_dataset_source(dataset_source: PatientDatasetSource) -> np.ndarray:
def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_exclusive: bool = True) -> np.ndarray:
"""
Load labels containing segmentation binary labels in one-hot-encoding.
In the future, this function will be used to load global class and non-imaging information as well.

:param dataset_source: The dataset source for which channels are to be loaded into memory.
:return A label sample object containing ground-truth information.
:param check_exclusive: Check that the labels are mutually exclusive (defaults to True)
:return: A label sample object containing ground-truth information.
"""
labels = np.stack(
[load_image(gt, ImageDataType.SEGMENTATION.value).image for gt in dataset_source.ground_truth_channels])

if check_exclusive and (sum(labels) > 1.).any(): # type: ignore
raise ValueError(f'The labels for patient {dataset_source.metadata.patient_id} are not mutually exclusive. '
'Some loss functions (e.g. SoftDice) may produce results on overlapping labels, while others (e.g. FocalLoss) will fail. '
'If you are sure that you want to use mutually exclusive labels, '
'then re-run with the check_exclusive flag set to false in the settings file. '
'Note that this is the first error encountered, other samples/patients may also have overlapping labels.')

# Add the background binary map
background = np.ones_like(labels[0])
for c in range(len(labels)):
background[labels[c] == 1] = 0
background = background[None, ...]
background = background[np.newaxis, ...]
return np.vstack((background, labels))


Expand Down Expand Up @@ -475,12 +483,13 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW
raise ValueError(f"Invalid file type {path}")


def load_images_from_dataset_source(dataset_source: PatientDatasetSource) -> Sample:
def load_images_from_dataset_source(dataset_source: PatientDatasetSource, check_exclusive: bool = True) -> Sample:
"""
Load images. ground truth labels and masks from the provided dataset source.
With an inferred label class for the background (assumed to be not provided in the input)

:param dataset_source: The dataset source for which channels are to be loaded into memory.
:param check_exclusive: Check that the labels are mutually exclusive (defaults to True)
:return: a Sample object with the loaded volume (image), labels, mask and metadata.
"""
images = [load_image(channel, ImageDataType.IMAGE.value) for channel in dataset_source.image_channels]
Expand All @@ -492,7 +501,7 @@ def load_images_from_dataset_source(dataset_source: PatientDatasetSource) -> Sam
# create raw sample to return
metadata = copy(dataset_source.metadata)
metadata.image_header = images[0].header
labels = load_labels_from_dataset_source(dataset_source)
labels = load_labels_from_dataset_source(dataset_source, check_exclusive=check_exclusive)
return Sample(image=image,
labels=labels,
mask=mask,
Expand Down
1 change: 0 additions & 1 deletion Tests/ML/datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

crop_size = [55, 55, 55]


@pytest.fixture
def num_dataload_workers() -> int:
"""PyTorch support for multiple dataloader workers is flaky on Windows (so return 0)"""
Expand Down
1 change: 1 addition & 0 deletions Tests/ML/test_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _mean_list(lists: List[List[float]]) -> List[float]:
train_config.class_weights = [0.5, 0.25, 0.25]
train_config.store_dataset_sample = True
train_config.recovery_checkpoint_save_interval = 1
train_config.check_exclusive = False

if machine_has_gpu:
expected_train_losses = [0.4552919, 0.4548529]
Expand Down
19 changes: 14 additions & 5 deletions Tests/ML/utils/test_io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,35 @@ def test_nii_load_zyx(test_output_dirs: OutputFolderForTests) -> None:
@pytest.mark.parametrize("ground_truth_channel",
[None, known_nii_path, f"{good_h5_path}|segmentation|0|1", good_npy_path])
@pytest.mark.parametrize("mask_channel", [None, known_nii_path, good_npy_path])
@pytest.mark.parametrize("check_exclusive", [True, False])
def test_load_images_from_dataset_source(
metadata: Optional[str],
image_channel: Optional[str],
ground_truth_channel: Optional[str],
mask_channel: Optional[str]) -> None:
mask_channel: Optional[str],
check_exclusive: bool) -> None:
"""
Test if images are loaded as expected from channels
"""
# metadata, image and GT channels must be present. Mask is optional
if None in [metadata, image_channel, ground_truth_channel]:
with pytest.raises(Exception):
_test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel)
_test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive)
else:
_test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel)
if check_exclusive:
with pytest.raises(ValueError) as mutually_exclusive_labels_error:
_test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive)
assert 'not mutually exclusive' in str(mutually_exclusive_labels_error.value)
else:
_test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive)


def _test_load_images_from_channels(
metadata: Any,
image_channel: Any,
ground_truth_channel: Any,
mask_channel: Any) -> None:
mask_channel: Any,
check_exclusive: bool) -> None:
"""
Test if images are loaded as expected from channels
"""
Expand All @@ -102,7 +110,8 @@ def _test_load_images_from_channels(
image_channels=[image_channel] * 2,
ground_truth_channels=[ground_truth_channel] * 4,
mask_channel=mask_channel
)
),
check_exclusive=check_exclusive
)
if image_channel:
image_with_header = io_util.load_image(image_channel)
Expand Down