diff --git a/CHANGELOG.md b/CHANGELOG.md index b2969d7fe..199d1de17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ console for easier diagnostics. Additionally, the `TrainHelloWorldAndHelloContainer` job in the PR build has been split into two jobs, `TrainHelloWorld` and `TrainHelloContainer`. A pytest marker `after_training_hello_container` has been added to run tests after training is finished in the `TrainHelloContainer` job. +- ([#456](https://github.com/microsoft/InnerEye-DeepLearning/pull/456)) Adding configs to train Covid detection models. ### Changed diff --git a/InnerEye/ML/SSL/datamodules_and_datasets/cxr_datasets.py b/InnerEye/ML/SSL/datamodules_and_datasets/cxr_datasets.py index 3742b13d7..5d6d38d64 100644 --- a/InnerEye/ML/SSL/datamodules_and_datasets/cxr_datasets.py +++ b/InnerEye/ML/SSL/datamodules_and_datasets/cxr_datasets.py @@ -4,11 +4,12 @@ # ------------------------------------------------------------------------------------------ import logging from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Tuple import numpy as np import pandas as pd from PIL import Image +from torch.utils.data import Subset from torchvision.datasets import VisionDataset from InnerEye.Common.type_annotations import PathOrString @@ -175,3 +176,39 @@ def _prepare_dataset(self) -> None: self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:]) self.indices = np.arange(len(self.dataset_dataframe)) self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values] + + +class CovidDataset(InnerEyeCXRDatasetWithReturnIndex): + """ + Dataset class to load CovidDataset dataset as datamodule for monitoring SSL training quality directly on + CovidDataset data. + We use CVX03 against CVX12 as proxy task. + """ + + def _prepare_dataset(self) -> None: + self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv") + mapping = {0: 0, 3: 0, 1: 1, 2: 1} + # For monitoring purpose with use binary classification CV03vsCV12 + self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x]) + self.indices = np.arange(len(self.dataset_dataframe)) + self.subject_ids = self.dataset_dataframe.subject.values + self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values] + self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1) + + @property + def num_classes(self) -> int: + return 2 + + def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]: + """ + Implements val - train split. + :param val_split: proportion to use for validation + :param seed: random seed for splitting + :return: dataset_train, dataset_val + """ + shuffled_subject_ids = np.random.RandomState(seed).permutation(np.unique(self.subject_ids)) + n_val = int(len(shuffled_subject_ids) * val_split) + val_subjects, train_subjects = shuffled_subject_ids[:n_val], shuffled_subject_ids[n_val:] + train_ids, val_ids = np.where(np.isin(self.subject_ids, train_subjects))[0], \ + np.where(np.isin(self.subject_ids, val_subjects))[0] + return Subset(self, train_ids), Subset(self, val_ids) diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index 781617946..12e971679 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -13,7 +13,7 @@ from yacs.config import CfgNode from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10, InnerEyeCIFAR100 -from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, NIHCXR, RSNAKaggleCXR +from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, CovidDataset, NIHCXR, RSNAKaggleCXR from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \ InnerEyeCIFARTrainTransform, \ @@ -42,11 +42,12 @@ class EncoderName(Enum): class SSLDatasetName(Enum): - RSNAKaggleCXR = "RSNAKaggleCXR" - NIHCXR = "NIHCXR" CIFAR10 = "CIFAR10" CIFAR100 = "CIFAR100" + RSNAKaggleCXR = "RSNAKaggleCXR" + NIHCXR = "NIHCXR" CheXpert = "CheXpert" + Covid = "CovidDataset" InnerEyeDataModuleTypes = Union[InnerEyeVisionDataModule, CombinedDataModule] @@ -62,11 +63,12 @@ class SSLContainer(LightningContainer): Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share setup and datamodule methods. """ - _SSLDataClassMappings = {SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR, - SSLDatasetName.NIHCXR.value: NIHCXR, - SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10, + _SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10, SSLDatasetName.CIFAR100.value: InnerEyeCIFAR100, - SSLDatasetName.CheXpert.value: CheXpert} + SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR, + SSLDatasetName.NIHCXR.value: NIHCXR, + SSLDatasetName.CheXpert.value: CheXpert, + SSLDatasetName.Covid.value: CovidDataset} ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True, doc="The path to the yaml config defining the parameters of the " @@ -87,11 +89,13 @@ class SSLContainer(LightningContainer): "Used for debugging and tests.") linear_head_augmentation_config = param.ClassSelector(class_=Path, doc="The path to the yaml config for the linear head " - "augmentations") + "augmentations") linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="Name of the dataset to use for the linear head training") linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning") - learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, doc="Learning rate for linear head training during SSL training.") + learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, + doc="Learning rate for linear head training during " + "SSL training.") def setup(self) -> None: from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer @@ -173,7 +177,8 @@ def get_data_module(self) -> InnerEyeDataModuleTypes: return self.data_module encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True) linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False) - return CombinedDataModule(encoder_data_module, linear_data_module, self.use_balanced_binary_loss_for_linear_head) + return CombinedDataModule(encoder_data_module, linear_data_module, + self.use_balanced_binary_loss_for_linear_head) def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule: """ @@ -220,7 +225,10 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode], applied on. If False, return only one transformation. :return: training transformation pipeline and validation transformation pipeline. """ - if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, SSLDatasetName.NIHCXR.value, SSLDatasetName.CheXpert.value]: + if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, + SSLDatasetName.NIHCXR.value, + SSLDatasetName.CheXpert.value, + SSLDatasetName.Covid.value]: assert augmentation_config is not None train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config, return_two_views_per_sample=is_ssl_encoder_module, diff --git a/InnerEye/ML/configs/classification/CovidHierarchicalModel.py b/InnerEye/ML/configs/classification/CovidHierarchicalModel.py new file mode 100644 index 000000000..b6a176742 --- /dev/null +++ b/InnerEye/ML/configs/classification/CovidHierarchicalModel.py @@ -0,0 +1,247 @@ +import codecs +import logging +import pickle +from pathlib import Path + +from typing import Any, Callable + +import PIL +import pandas as pd +import param +import torch +from PIL import Image +from pytorch_lightning import LightningModule +from torchvision.transforms import Compose + +from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path +from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import create_chest_xray_transform +from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName + +from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier +from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_ssl_augmentation_config +from InnerEye.ML.common import ModelExecutionMode + +from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr +from InnerEye.ML.deep_learning_config import LRSchedulerType, MultiprocessingStartMethod, \ + OptimizerType + +from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode +from InnerEye.ML.model_testing import MODEL_OUTPUT_CSV + +from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType +from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty + +from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase +from InnerEye.ML.utils.augmentation import ScalarItemAugmentation +from InnerEye.ML.utils.run_recovery import RunRecovery +from InnerEye.ML.utils.split_dataset import DatasetSplits + +from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID +from InnerEye.Common import fixed_paths as fixed_paths_innereye + + +class CovidHierarchicalModel(ScalarModelBase): + """ + Model to train a CovidDataset model from scratch or finetune from SSL-pretrained model. + + For AML you need to provide the run_id of your SSL training job as a command line argument + --pretraining_run_recovery_id=id_of_your_ssl_model, this will download the checkpoints of the run to your + machine and load the corresponding pretrained model. + + To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use hte + --name_of_checkpoint argument. + """ + use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with SSL." + "If False, start training a DenseNet model from scratch" + "(random initialization).") + freeze_encoder = param.Boolean(default=False, doc="Whether to freeze the pretrained encoder or not.") + name_of_checkpoint = param.String(default=None, doc="Filename of checkpoint to use for recovery") + test_set_ids_csv = param.String(default=None, + doc="Name of the csv file in the dataset folder with the test set ids. The dataset" + "is expected to have a 'series' and a 'subject' column. The subject column" + "is assumed to contain unique ids.") + + def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any): + learning_rate = 1e-5 if self.use_pretrained_model else 1e-4 + super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'], + loss_type=ScalarLoss.CustomClassification, + class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'], + max_num_gpus=1, + azure_dataset_id=covid_dataset_id, + subject_column="series", + image_file_column="filepath", + label_value_column="final_label", + non_image_feature_channels=[], + numerical_columns=[], + use_mixed_precision=False, + num_dataload_workers=12, + multiprocessing_start_method=MultiprocessingStartMethod.fork, + train_batch_size=64, + optimizer_type=OptimizerType.Adam, + num_epochs=50, + l_rate_scheduler=LRSchedulerType.Step, + l_rate_step_gamma=1.0, + l_rate=learning_rate, + l_rate_multi_step_milestones=None, + **kwargs) + self.num_classes = 3 + if not self.use_pretrained_model and self.freeze_encoder: + raise ValueError("No encoder to freeze when training from scratch. You requested training from scratch and" + "encoder freezing.") + + def should_generate_multilabel_report(self) -> bool: + return False + + def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits: + if self.test_set_ids_csv: + test_df = pd.read_csv(self.local_dataset / self.test_set_ids_csv) + in_test_set = dataset_df.series.isin(test_df.series) + train_ids = dataset_df.series[~in_test_set].values + test_ids = dataset_df.series[in_test_set].values + num_val_samples = 400 + val_ids = train_ids[:num_val_samples] + train_ids = train_ids[num_val_samples:] + return DatasetSplits.from_subject_ids(dataset_df, train_ids=train_ids, val_ids=val_ids, test_ids=test_ids, + subject_column="series", group_column="subject") + else: + return DatasetSplits.from_proportions(dataset_df, + proportion_train=0.8, + proportion_val=0.1, + proportion_test=0.1, + subject_column="series", + group_column="subject", + shuffle=True) + + # noinspection PyTypeChecker + def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode: + config = load_ssl_augmentation_config(path_linear_head_augmentation_cxr) + train_transforms = ScalarItemAugmentation( + Compose( + [DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=True)])) + val_transforms = ScalarItemAugmentation( + Compose( + [DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=False)])) + + return ModelTransformsPerExecutionMode(train=train_transforms, + val=val_transforms, + test=val_transforms) + + def create_model(self) -> LightningModule: + """ + This method must create the actual Lightning model that will be trained. + """ + if self.use_pretrained_model: + path_to_checkpoint = self._get_ssl_checkpoint_path() + + model = create_ssl_image_classifier( + num_classes=self.num_classes, + pl_checkpoint_path=str(path_to_checkpoint), + freeze_encoder=self.freeze_encoder) + + else: + encoder = create_ssl_encoder(encoder_name=EncoderName.densenet121.value) + model = SSLClassifier(num_classes=self.num_classes, + encoder=encoder, + freeze_encoder=self.freeze_encoder, + class_weights=None) + # Next args are just here because we are using this model within an InnerEyeContainer + model.imaging_feature_type = ImagingFeatureType.Image # type: ignore + model.num_non_image_features = 0 # type: ignore + model.encode_channels_jointly = True # type: ignore + return model + + def _get_ssl_checkpoint_path(self) -> Path: + # Get the SSL weights from the AML run provided via "pretraining_run_recovery_id" command line argument. + # Accessible via extra_downloaded_run_id field of the config. + assert self.extra_downloaded_run_id is not None + assert isinstance(self.extra_downloaded_run_id, RunRecovery) + ssl_path = self.checkpoint_folder / "ssl_checkpoint.ckpt" + + if not ssl_path.exists(): # for test (when it is already present) we don't need to redo this. + if self.name_of_checkpoint is not None: + logging.info(f"Using checkpoint: {self.name_of_checkpoint} as starting point.") + path_to_checkpoint = self.extra_downloaded_run_id.checkpoints_roots[0] / self.name_of_checkpoint + else: + path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()[0] + if not path_to_checkpoint.exists(): + logging.info("No best checkpoint found for this model. Getting the latest recovery " + "checkpoint instead.") + path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()[0] + assert path_to_checkpoint.exists() + path_to_checkpoint.rename(ssl_path) + return ssl_path + + def pre_process_dataset_dataframe(self) -> None: + pass + + @staticmethod + def get_posthoc_label_transform() -> Callable: + import torch + + def multiclass_to_hierarchical_labels(classes: torch.Tensor) -> torch.Tensor: + classes = classes.clone() + cvx03vs12 = classes[..., 1] + classes[..., 2] + cvx0vs3 = classes[..., 3] + cvx1vs2 = classes[..., 2] + cvx0vs3[cvx03vs12 == 1] = float('nan') # CVX0vs3 only gets gradient for CVX03 + cvx1vs2[cvx03vs12 == 0] = float('nan') # CVX1vs2 only gets gradient for CVX12 + return torch.stack([cvx03vs12, cvx0vs3, cvx1vs2], -1) + + return multiclass_to_hierarchical_labels + + @staticmethod + def get_loss_function() -> Callable: + import torch + import torch.nn.functional as F + + def nan_bce_with_logits(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute BCE with logits, ignoring NaN values""" + valid = labels.isfinite() + losses = F.binary_cross_entropy_with_logits(output[valid], labels[valid], reduction='none') + return losses.sum() / labels.shape[0] + + return nan_bce_with_logits + + def generate_custom_report(self, report_dir: Path, model_proc: ModelProcessing) -> Path: + """ + Generate a custom report for the CovidDataset Hierarchical model. At the moment, this report will read the + file model_output.csv generated for the training, validation or test sets and compute a 4 class accuracy + and confusion matrix based on this. + :param report_dir: Directory report is to be written to + :param model_proc: Whether this is a single or ensemble model (model_output.csv will be located in different + paths for single vs ensemble runs.) + """ + + def get_output_csv_path(mode: ModelExecutionMode) -> Path: + p = get_best_epoch_results_path(mode=mode, model_proc=model_proc) + return self.outputs_folder / p / MODEL_OUTPUT_CSV + + train_metrics = get_output_csv_path(ModelExecutionMode.TRAIN) + val_metrics = get_output_csv_path(ModelExecutionMode.VAL) + test_metrics = get_output_csv_path(ModelExecutionMode.TEST) + + notebook_params = \ + { + 'innereye_path': str(fixed_paths_innereye.repository_root_directory()), + 'train_metrics_csv': str_or_empty(train_metrics), + 'val_metrics_csv': str_or_empty(val_metrics), + 'test_metrics_csv': str_or_empty(test_metrics), + "config": codecs.encode(pickle.dumps(self), "base64").decode(), + "is_crossval_report": False + } + template = Path(__file__).absolute().parent.parent / "reports" / "CovidHierarchicalModelReport.ipynb" + return generate_notebook(template, + notebook_params=notebook_params, + result_notebook=report_dir / get_ipynb_report_name( + f"{self.model_category.value}_hierarchical")) + + +class DicomPreparation: + def __call__(self, item: torch.Tensor) -> PIL.Image: + # Item will be of dimension [C, Z, X, Y] + images = item.numpy() + assert images.shape[0] == 1 and images.shape[1] == 1 + images = images.reshape(images.shape[2:]) + normalized_image = (images - images.min()) * 255. / (images.max() - images.min()) + image = Image.fromarray(normalized_image).convert("L") + return image diff --git a/InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb b/InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb new file mode 100644 index 000000000..07afe2543 --- /dev/null +++ b/InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "%%javascript\n", + "IPython.OutputArea.prototype._should_scroll = function(lines) {\n", + " return false;\n", + "}\n", + "// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "# Default parameter values. They will be overwritten by papermill notebook parameters.\n", + "# This cell must carry the tag \"parameters\" in its metadata.\n", + "from pathlib import Path\n", + "import pickle\n", + "import codecs\n", + "\n", + "innereye_path = Path.cwd().parent.parent.parent.parent\n", + "train_metrics_csv = \"\"\n", + "val_metrics_csv = \"\"\n", + "test_metrics_csv = \"\"\n", + "config = \"\"\n", + "is_crossval_report = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "print(f\"Adding to path: {innereye_path}\")\n", + "if str(innereye_path) not in sys.path:\n", + " sys.path.append(str(innereye_path))\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n", + "\n", + "from InnerEye.ML.common import ModelExecutionMode\n", + "from InnerEye.ML.reports.notebook_report import print_header\n", + "from InnerEye.ML.configs.reports.covid_hierarchical_model_report import print_metrics_from_csv\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "plt.rcParams['figure.figsize'] = (20, 10)\n", + "\n", + "#convert params to Path\n", + "train_metrics_csv = Path(train_metrics_csv)\n", + "val_metrics_csv = Path(val_metrics_csv)\n", + "test_metrics_csv = Path(test_metrics_csv)" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "# Metrics\n", + "## Train Set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "if train_metrics_csv.is_file():\n", + " print_metrics_from_csv(csv_to_set_optimal_threshold=train_metrics_csv,\n", + " csv_to_compute_metrics=train_metrics_csv,\n", + " config=config, is_crossval_report=is_crossval_report)" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Validation Set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "if val_metrics_csv.is_file():\n", + " print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n", + " csv_to_compute_metrics=val_metrics_csv,\n", + " config=config, is_crossval_report=is_crossval_report)" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## Test Set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n", + " print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n", + " csv_to_compute_metrics=test_metrics_csv,\n", + " config=config, is_crossval_report=is_crossval_report)" + ] + } + ], + "metadata": { + "celltoolbar": "Tags", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/InnerEye/ML/configs/reports/covid_hierarchical_model_report.py b/InnerEye/ML/configs/reports/covid_hierarchical_model_report.py new file mode 100644 index 000000000..f46e797e6 --- /dev/null +++ b/InnerEye/ML/configs/reports/covid_hierarchical_model_report.py @@ -0,0 +1,104 @@ +import pandas as pd +import numpy as np + +from pathlib import Path +from sklearn.metrics import accuracy_score, confusion_matrix +from typing import Dict + +from InnerEye.Common.metrics_constants import LoggingColumns +from InnerEye.ML.reports.classification_report import get_labels_and_predictions_from_dataframe, LabelsAndPredictions +from InnerEye.ML.reports.notebook_report import print_table +from InnerEye.ML.scalar_config import ScalarModelBase + +TARGET_NAMES = ['CVX03vs12', 'CVX0vs3', 'CVX1vs2'] +MULTICLASS_HUE_NAME = "Multiclass" + + +def get_label_from_label_dict(label_dict: Dict[str, float]) -> int: + """ + Converts strings CVX03vs12, CVX1vs2, CVX0vs3 to the corresponding class as int. + """ + if label_dict['CVX03vs12'] == 0: + assert np.isnan(label_dict['CVX1vs2']) + if label_dict['CVX0vs3'] == 0: + label = 0 + elif label_dict['CVX0vs3'] == 1: + label = 3 + else: + raise ValueError("CVX0vs3 should be 0 or 1.") + elif label_dict['CVX03vs12'] == 1: + assert np.isnan(label_dict['CVX0vs3']) + if label_dict['CVX1vs2'] == 0: + label = 1 + elif label_dict['CVX1vs2'] == 1: + label = 2 + else: + raise ValueError("CVX1vs2 should be 0 or 1.") + else: + raise ValueError("CVX03vs12 should be 0 or 1.") + return label + + +def get_model_prediction_by_probabilities(output_dict: Dict[str, float]) -> int: + """ + Based on the values for CVX03vs12, CVX0vs3 and CVX1vs2 predicted by the model, predict the CVX scores as followed: + score(CVX0) = [1 - score(CVX03vs12)][1 - score(CVX0vs3)] + score(CVX1) = score(CVX03vs12)[1 - score(CVX1vs2)] + score(CVX2) = score(CVX03vs12)score(CVX1vs2) + score(CVX3) = [1 - score(CVX03vs12)]score(CVX0vs3) + """ + cvx0 = (1 - output_dict['CVX03vs12']) * (1 - output_dict['CVX0vs3']) + cvx3 = (1 - output_dict['CVX03vs12']) * output_dict['CVX0vs3'] + cvx1 = output_dict['CVX03vs12'] * (1 - output_dict['CVX1vs2']) + cvx2 = output_dict['CVX03vs12'] * output_dict['CVX1vs2'] + return np.argmax([cvx0, cvx1, cvx2, cvx3]) + + +def get_dataframe_with_covid_labels(metrics_df: pd.DataFrame) -> pd.DataFrame: + def get_CVX_labels(df: pd.DataFrame) -> pd.DataFrame: + """ + Given a dataframe (with only one subject) with the model outputs for CVX03vs12, CVX0vs3 and CVX1vs2, + returns a corresponding dataframe with scores for CVX0, CVX1, CVX2 and CVX3 for this subject. See + `get_model_prediction_by_probabilities` for details on mapping the model output to CVX labels. + """ + df_by_hue = df[df[LoggingColumns.Hue.value].isin(TARGET_NAMES)].set_index(LoggingColumns.Hue.value) + model_output = get_model_prediction_by_probabilities(df_by_hue[LoggingColumns.ModelOutput.value].to_dict()) + label = get_label_from_label_dict(df_by_hue[LoggingColumns.Label.value].to_dict()) + + return pd.DataFrame.from_dict({LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]], + LoggingColumns.ModelOutput.value: [model_output], + LoggingColumns.Label.value: [label]}) + + df = metrics_df.copy() + # Group by subject, and for each subject, convert the CVX03vs12, CVX0vs3 and CVX1vs2 predictions to CVX labels. + df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_CVX_labels).reset_index(drop=True) + df[LoggingColumns.Hue.value] = [MULTICLASS_HUE_NAME] * len(df) + return df + + +def get_labels_and_predictions_covid_labels(csv: Path) -> LabelsAndPredictions: + metrics_df = pd.read_csv(csv) + df = get_dataframe_with_covid_labels(metrics_df=metrics_df) + return get_labels_and_predictions_from_dataframe(df) + + +def print_metrics_from_csv(csv_to_set_optimal_threshold: Path, + csv_to_compute_metrics: Path, + config: ScalarModelBase, + is_crossval_report: bool) -> None: + assert config.target_names == TARGET_NAMES + + predictions_to_compute_metrics = get_labels_and_predictions_covid_labels( + csv=csv_to_compute_metrics) + + acc = accuracy_score(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs) + rows = [[f"{acc:.4f}"]] + print_table(rows, header=["Multiclass Accuracy"]) + + conf_matrix = confusion_matrix(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs) + rows = [] + header = ["", "CVX0 predicted", "CVX1 predicted", "CVX2 predicted", "CVX3 predicted"] + for i in range(conf_matrix.shape[0]): + line = [f"CVX{i} GT"] + list(conf_matrix[i]) + rows.append(line) + print_table(rows, header=header) diff --git a/InnerEye/ML/configs/ssl/CovidContainers.py b/InnerEye/ML/configs/ssl/CovidContainers.py new file mode 100644 index 000000000..2941b1b39 --- /dev/null +++ b/InnerEye/ML/configs/ssl/CovidContainers.py @@ -0,0 +1,36 @@ +from typing import Any + +from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName +from InnerEye.ML.SSL.utils import SSLTrainingType +from InnerEye.ML.configs.ssl.CXR_SSL_configs import NIH_AZURE_DATASET_ID, path_encoder_augmentation_cxr, \ + path_linear_head_augmentation_cxr + +COVID_DATASET_ID = "id-of-your-dataset" + + +class NIH_COVID_BYOL(SSLContainer): + """ + Class to train a SSL model on NIH dataset and monitor embeddings quality on a Covid Dataset. + """ + + def __init__(self, + covid_dataset_id: str = COVID_DATASET_ID, + **kwargs: Any): + super().__init__(ssl_training_dataset_name=SSLDatasetName.NIHCXR, + linear_head_dataset_name=SSLDatasetName.Covid, + random_seed=1, + recovery_checkpoint_save_interval=50, + recovery_checkpoints_save_last_k=3, + num_epochs=500, + ssl_training_batch_size=1200, # This runs with 16 gpus (4 nodes) + num_workers=12, + ssl_encoder=EncoderName.densenet121, + ssl_training_type=SSLTrainingType.BYOL, + use_balanced_binary_loss_for_linear_head=True, + ssl_augmentation_config=path_encoder_augmentation_cxr, + extra_azure_dataset_ids=[covid_dataset_id], + azure_dataset_id=NIH_AZURE_DATASET_ID, + linear_head_augmentation_config=path_linear_head_augmentation_cxr, + online_evaluator_lr=1e-5, + linear_head_batch_size=64, + **kwargs) diff --git a/Tests/ML/configs/utils/test_hierarchical_covid_model_report.py b/Tests/ML/configs/utils/test_hierarchical_covid_model_report.py new file mode 100644 index 000000000..1dea1bff5 --- /dev/null +++ b/Tests/ML/configs/utils/test_hierarchical_covid_model_report.py @@ -0,0 +1,22 @@ +import pandas as pd +from math import nan + +from InnerEye.Common.metrics_constants import LoggingColumns +from InnerEye.ML.configs.reports.covid_hierarchical_model_report import MULTICLASS_HUE_NAME, \ + get_dataframe_with_covid_labels + + +def test_get_dataframe_with_covid_labels() -> None: + + df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], + LoggingColumns.Hue.value: ['CVX03vs12', 'CVX0vs3', 'CVX1vs2'] * 4, + LoggingColumns.Label.value: [0, 0, nan, 0, 1, nan, 1, nan, 0, 1, nan, 1], + LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.5, 0.1, 0.9, 0.5, 0.9, 0.9, 0.9, 0.1, 0.2, 0.1]}) + expected_df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 2, 3, 4], + LoggingColumns.ModelOutput.value: [0, 3, 2, 0], + LoggingColumns.Label.value: [0, 3, 1, 2], + LoggingColumns.Hue.value: [MULTICLASS_HUE_NAME] * 4 + }) + + multiclass_df = get_dataframe_with_covid_labels(df) + assert expected_df.equals(multiclass_df)