diff --git a/.gitignore b/.gitignore index dff0c9792..120d80cf5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ Tests/logs Tests/ML/models/logs saved_models/ Tests/ML/test_data/outputs +Tests/ML/test_data/cxr_test_dataset azureml-models # Tests junit diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fab68290..f03936a47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ created. ### Added +- ([#454](https://github.com/microsoft/InnerEye-DeepLearning/pull/454)) Checking that labels are mutually exclusive. - ([#447](https://github.com/microsoft/InnerEye-DeepLearning/pull/447/)) Added a sanity check to ensure there are no missing channels, nor missing files. If missing channels in the csv file or filenames associated with channels are incorrect, pipeline exits with error report before running training or inference. diff --git a/InnerEye/ML/config.py b/InnerEye/ML/config.py index 796d626e6..79a8a4cb0 100644 --- a/InnerEye/ML/config.py +++ b/InnerEye/ML/config.py @@ -479,6 +479,12 @@ class SegmentationModelBase(ModelConfigBase): "patch sampling will be shown. Nifti images and thumbnails for each" "of the first N subjects in the training set will be " "written to the outputs folder.") + #: If true an error is raised in InnerEye.ML.utils.io_util.load_labels_from_dataset_source if the labels are not + #: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others (e.g. + #: FocalLoss) will fail with a cryptic error message. Set to false if you are sure that you want to use labels that + #: are not mutually exclusive. + + check_exclusive: bool = param.Boolean(True, doc="Raise an error if the segmentation labels are not mutually exclusive.") def __init__(self, center_size: Optional[TupleInt3] = None, inference_stride_size: Optional[TupleInt3] = None, diff --git a/InnerEye/ML/configs/segmentation/HelloWorld.py b/InnerEye/ML/configs/segmentation/HelloWorld.py index 174d97c45..e9fd7d242 100644 --- a/InnerEye/ML/configs/segmentation/HelloWorld.py +++ b/InnerEye/ML/configs/segmentation/HelloWorld.py @@ -77,6 +77,9 @@ def __init__(self, **kwargs: Any) -> None: # we define the structure names and colours to use. ground_truth_ids_display_names=fg_classes, colours=generate_random_colours_list(Random(5), len(fg_classes)), + + # The HelloWorld model uses dummy data with overlapping segmentation labels + check_exclusive=False ) self.add_and_validate(kwargs) diff --git a/InnerEye/ML/dataset/full_image_dataset.py b/InnerEye/ML/dataset/full_image_dataset.py index 5e81461a7..aa3bf2dff 100644 --- a/InnerEye/ML/dataset/full_image_dataset.py +++ b/InnerEye/ML/dataset/full_image_dataset.py @@ -250,7 +250,7 @@ def _extension_from_df_file_paths(file_paths: List[str]) -> str: def get_samples_at_index(self, index: int) -> List[Sample]: # load the channels into memory ds = self.dataset_sources[self.dataset_indices[index]] - samples = [io_util.load_images_from_dataset_source(dataset_source=ds)] # type: ignore + samples = [io_util.load_images_from_dataset_source(dataset_source=ds, check_exclusive=self.args.check_exclusive)] # type: ignore return [Compose3D.apply(self.full_image_sample_transforms, x) for x in samples] def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]: diff --git a/InnerEye/ML/utils/io_util.py b/InnerEye/ML/utils/io_util.py index 06194ff00..8229aa823 100644 --- a/InnerEye/ML/utils/io_util.py +++ b/InnerEye/ML/utils/io_util.py @@ -412,22 +412,30 @@ def load_image_in_known_formats(file: Path, raise ValueError(f"Unsupported image file type for path {file}") -def load_labels_from_dataset_source(dataset_source: PatientDatasetSource) -> np.ndarray: +def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_exclusive: bool = True) -> np.ndarray: """ Load labels containing segmentation binary labels in one-hot-encoding. In the future, this function will be used to load global class and non-imaging information as well. :param dataset_source: The dataset source for which channels are to be loaded into memory. - :return A label sample object containing ground-truth information. + :param check_exclusive: Check that the labels are mutually exclusive (defaults to True) + :return: A label sample object containing ground-truth information. """ labels = np.stack( [load_image(gt, ImageDataType.SEGMENTATION.value).image for gt in dataset_source.ground_truth_channels]) + if check_exclusive and (sum(labels) > 1.).any(): # type: ignore + raise ValueError(f'The labels for patient {dataset_source.metadata.patient_id} are not mutually exclusive. ' + 'Some loss functions (e.g. SoftDice) may produce results on overlapping labels, while others (e.g. FocalLoss) will fail. ' + 'If you are sure that you want to use mutually exclusive labels, ' + 'then re-run with the check_exclusive flag set to false in the settings file. ' + 'Note that this is the first error encountered, other samples/patients may also have overlapping labels.') + # Add the background binary map background = np.ones_like(labels[0]) for c in range(len(labels)): background[labels[c] == 1] = 0 - background = background[None, ...] + background = background[np.newaxis, ...] return np.vstack((background, labels)) @@ -475,12 +483,13 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW raise ValueError(f"Invalid file type {path}") -def load_images_from_dataset_source(dataset_source: PatientDatasetSource) -> Sample: +def load_images_from_dataset_source(dataset_source: PatientDatasetSource, check_exclusive: bool = True) -> Sample: """ Load images. ground truth labels and masks from the provided dataset source. With an inferred label class for the background (assumed to be not provided in the input) :param dataset_source: The dataset source for which channels are to be loaded into memory. + :param check_exclusive: Check that the labels are mutually exclusive (defaults to True) :return: a Sample object with the loaded volume (image), labels, mask and metadata. """ images = [load_image(channel, ImageDataType.IMAGE.value) for channel in dataset_source.image_channels] @@ -492,7 +501,7 @@ def load_images_from_dataset_source(dataset_source: PatientDatasetSource) -> Sam # create raw sample to return metadata = copy(dataset_source.metadata) metadata.image_header = images[0].header - labels = load_labels_from_dataset_source(dataset_source) + labels = load_labels_from_dataset_source(dataset_source, check_exclusive=check_exclusive) return Sample(image=image, labels=labels, mask=mask, diff --git a/Tests/ML/datasets/test_dataset.py b/Tests/ML/datasets/test_dataset.py index 9ffad6b00..6fc446f2d 100644 --- a/Tests/ML/datasets/test_dataset.py +++ b/Tests/ML/datasets/test_dataset.py @@ -28,7 +28,6 @@ crop_size = [55, 55, 55] - @pytest.fixture def num_dataload_workers() -> int: """PyTorch support for multiple dataloader workers is flaky on Windows (so return 0)""" diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index a29dbb960..6b37481a2 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -100,6 +100,7 @@ def _mean_list(lists: List[List[float]]) -> List[float]: train_config.class_weights = [0.5, 0.25, 0.25] train_config.store_dataset_sample = True train_config.recovery_checkpoint_save_interval = 1 + train_config.check_exclusive = False if machine_has_gpu: expected_train_losses = [0.4552919, 0.4548529] diff --git a/Tests/ML/utils/test_io_util.py b/Tests/ML/utils/test_io_util.py index 5399516a3..141c997b8 100644 --- a/Tests/ML/utils/test_io_util.py +++ b/Tests/ML/utils/test_io_util.py @@ -72,27 +72,35 @@ def test_nii_load_zyx(test_output_dirs: OutputFolderForTests) -> None: @pytest.mark.parametrize("ground_truth_channel", [None, known_nii_path, f"{good_h5_path}|segmentation|0|1", good_npy_path]) @pytest.mark.parametrize("mask_channel", [None, known_nii_path, good_npy_path]) +@pytest.mark.parametrize("check_exclusive", [True, False]) def test_load_images_from_dataset_source( metadata: Optional[str], image_channel: Optional[str], ground_truth_channel: Optional[str], - mask_channel: Optional[str]) -> None: + mask_channel: Optional[str], + check_exclusive: bool) -> None: """ Test if images are loaded as expected from channels """ # metadata, image and GT channels must be present. Mask is optional if None in [metadata, image_channel, ground_truth_channel]: with pytest.raises(Exception): - _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel) + _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive) else: - _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel) + if check_exclusive: + with pytest.raises(ValueError) as mutually_exclusive_labels_error: + _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive) + assert 'not mutually exclusive' in str(mutually_exclusive_labels_error.value) + else: + _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive) def _test_load_images_from_channels( metadata: Any, image_channel: Any, ground_truth_channel: Any, - mask_channel: Any) -> None: + mask_channel: Any, + check_exclusive: bool) -> None: """ Test if images are loaded as expected from channels """ @@ -102,7 +110,8 @@ def _test_load_images_from_channels( image_channels=[image_channel] * 2, ground_truth_channels=[ground_truth_channel] * 4, mask_channel=mask_channel - ) + ), + check_exclusive=check_exclusive ) if image_channel: image_with_header = io_util.load_image(image_channel)