Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Add Covid configs #456

Merged
merged 21 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ console for easier diagnostics.
Additionally, the `TrainHelloWorldAndHelloContainer` job in the PR build has been split into two jobs, `TrainHelloWorld` and
`TrainHelloContainer`. A pytest marker `after_training_hello_container` has been added to run tests after training is
finished in the `TrainHelloContainer` job.
- ([#456](https://github.com/microsoft/InnerEye-DeepLearning/pull/456)) Adding configs to train Covid detection models.

### Changed

Expand Down
39 changes: 38 additions & 1 deletion InnerEye/ML/SSL/datamodules_and_datasets/cxr_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Subset
from torchvision.datasets import VisionDataset

from InnerEye.Common.type_annotations import PathOrString
Expand Down Expand Up @@ -175,3 +176,39 @@ def _prepare_dataset(self) -> None:
self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:])
self.indices = np.arange(len(self.dataset_dataframe))
self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values]


class CovidDataset(InnerEyeCXRDatasetWithReturnIndex):
"""
Dataset class to load CovidDataset dataset as datamodule for monitoring SSL training quality directly on
CovidDataset data.
We use CVX03 against CVX12 as proxy task.
"""

def _prepare_dataset(self) -> None:
self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv")
mapping = {0: 0, 3: 0, 1: 1, 2: 1}
# For monitoring purpose with use binary classification CV03vsCV12
self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x])
self.indices = np.arange(len(self.dataset_dataframe))
self.subject_ids = self.dataset_dataframe.subject.values
self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values]
self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1)

@property
def num_classes(self) -> int:
return 2

def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]:
"""
Implements val - train split.
:param val_split: proportion to use for validation
:param seed: random seed for splitting
:return: dataset_train, dataset_val
"""
shuffled_subject_ids = np.random.RandomState(seed).permutation(np.unique(self.subject_ids))
n_val = int(len(shuffled_subject_ids) * val_split)
val_subjects, train_subjects = shuffled_subject_ids[:n_val], shuffled_subject_ids[n_val:]
train_ids, val_ids = np.where(np.isin(self.subject_ids, train_subjects))[0], \
np.where(np.isin(self.subject_ids, val_subjects))[0]
return Subset(self, train_ids), Subset(self, val_ids)
30 changes: 19 additions & 11 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from yacs.config import CfgNode

from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10, InnerEyeCIFAR100
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, NIHCXR, RSNAKaggleCXR
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, CovidDataset, NIHCXR, RSNAKaggleCXR
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
InnerEyeCIFARTrainTransform, \
Expand Down Expand Up @@ -42,11 +42,12 @@ class EncoderName(Enum):


class SSLDatasetName(Enum):
RSNAKaggleCXR = "RSNAKaggleCXR"
NIHCXR = "NIHCXR"
CIFAR10 = "CIFAR10"
CIFAR100 = "CIFAR100"
RSNAKaggleCXR = "RSNAKaggleCXR"
NIHCXR = "NIHCXR"
CheXpert = "CheXpert"
Covid = "CovidDataset"


InnerEyeDataModuleTypes = Union[InnerEyeVisionDataModule, CombinedDataModule]
Expand All @@ -62,11 +63,12 @@ class SSLContainer(LightningContainer):
Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share
setup and datamodule methods.
"""
_SSLDataClassMappings = {SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
_SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
SSLDatasetName.CIFAR100.value: InnerEyeCIFAR100,
SSLDatasetName.CheXpert.value: CheXpert}
SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CheXpert.value: CheXpert,
SSLDatasetName.Covid.value: CovidDataset}

ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True,
doc="The path to the yaml config defining the parameters of the "
Expand All @@ -87,11 +89,13 @@ class SSLContainer(LightningContainer):
"Used for debugging and tests.")
linear_head_augmentation_config = param.ClassSelector(class_=Path,
doc="The path to the yaml config for the linear head "
"augmentations")
"augmentations")
linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
doc="Name of the dataset to use for the linear head training")
linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, doc="Learning rate for linear head training during SSL training.")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
doc="Learning rate for linear head training during "
"SSL training.")

def setup(self) -> None:
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
Expand Down Expand Up @@ -173,7 +177,8 @@ def get_data_module(self) -> InnerEyeDataModuleTypes:
return self.data_module
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
return CombinedDataModule(encoder_data_module, linear_data_module, self.use_balanced_binary_loss_for_linear_head)
return CombinedDataModule(encoder_data_module, linear_data_module,
self.use_balanced_binary_loss_for_linear_head)

def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
"""
Expand Down Expand Up @@ -220,7 +225,10 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
applied on. If False, return only one transformation.
:return: training transformation pipeline and validation transformation pipeline.
"""
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, SSLDatasetName.NIHCXR.value, SSLDatasetName.CheXpert.value]:
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
SSLDatasetName.NIHCXR.value,
SSLDatasetName.CheXpert.value,
SSLDatasetName.Covid.value]:
assert augmentation_config is not None
train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,
Expand Down
Loading