This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit 96840e8

Author: Tim Regan
Merge branch 'main' into alberto/inference
2 parents f1fd3c8 + cab68cc; commit 96840e8

12 files changed: +371 -108 lines

CHANGELOG.md

+2 lines

@@ -15,10 +15,12 @@ created.
 ### Added
 - ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference
   module in the test data without or partial ground truth files.
+- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
 - ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
   jobs that run in AzureML.

 ### Changed
+- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) Renamed command line option 'perform_training_set_inference' to 'inference_on_train_set'. Replaced command line option 'perform_validation_and_test_set_inference' with the pair of options 'inference_on_val_set' and 'inference_on_test_set'.
 - ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
 - ([#497](https://github.com/microsoft/InnerEye-DeepLearning/pull/497)) Reducing the size of the code snapshot that
   gets uploaded to AzureML, by skipping all test folders.
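
A quick sketch of what the #502 rename means for existing command lines. The helper below is hypothetical — it is not part of this commit and only restates the mapping given in the CHANGELOG entries:

# Hypothetical migration helper illustrating the #502 option renames.
from typing import List

RENAMES = {
    "perform_training_set_inference": ["inference_on_train_set"],
    # The old combined switch maps to a pair of independent switches.
    "perform_validation_and_test_set_inference": ["inference_on_val_set",
                                                  "inference_on_test_set"],
}

def migrate_args(argv: List[str]) -> List[str]:
    """Rewrite old-style '--name=value' switches to their new names."""
    result = []
    for arg in argv:
        name, eq, value = arg.lstrip("-").partition("=")
        if eq and name in RENAMES:
            result.extend(f"--{new_name}={value}" for new_name in RENAMES[name])
        else:
            result.append(arg)
    return result

print(migrate_args(["--model=MyModel", "--perform_validation_and_test_set_inference=False"]))
# ['--model=MyModel', '--inference_on_val_set=False', '--inference_on_test_set=False']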

InnerEye/Azure/azure_config.py

+2 -26 lines

@@ -19,7 +19,7 @@
 from azureml.train.hyperdrive import HyperDriveConfig
 from git import Repo

-from InnerEye.Azure.azure_util import fetch_run, is_offline_run_context
+from InnerEye.Azure.azure_util import fetch_run, is_offline_run_context, remove_arg
 from InnerEye.Azure.secrets_handling import SecretsHandling, read_all_settings
 from InnerEye.Common import fixed_paths
 from InnerEye.Common.generic_parsing import GenericConfig
@@ -324,31 +324,7 @@ def set_script_params_except_submit_flag(self) -> None:
         Populates the script_param field of the present object from the arguments in sys.argv, with the exception
         of the "azureml" flag.
         """
-        args = sys.argv[1:]
-        submit_flag = f"--{AZURECONFIG_SUBMIT_TO_AZUREML}"
-        retained_args = []
-        i = 0
-        while i < len(args):
-            arg = args[i]
-            if arg.startswith(submit_flag):
-                if len(arg) == len(submit_flag):
-                    # The commandline argument is "--azureml", with something possibly following: This can either be
-                    # "--azureml True" or "--azureml --some_other_param"
-                    if i < (len(args) - 1):
-                        # If the next argument starts with a "-" then assume that it does not belong to the --azureml
-                        # flag. If there is no "-", assume it belongs to the --azureml flag, and skip both
-                        if not args[i + 1].startswith("-"):
-                            i = i + 1
-                elif arg[len(submit_flag)] == "=":
-                    # The commandline argument is "--azureml=True" or "--azureml=False": Continue with next arg
-                    pass
-                else:
-                    # The argument list contains a flag like "--azureml_foo": Keep that.
-                    retained_args.append(arg)
-            else:
-                retained_args.append(arg)
-            i = i + 1
-        self.script_params = retained_args
+        self.script_params = remove_arg(AZURECONFIG_SUBMIT_TO_AZUREML, sys.argv[1:])


 @dataclass

InnerEye/Azure/azure_util.py

+40 lines

@@ -455,3 +455,43 @@ def step_up_directories(path: Path) -> Generator[Path, None, None]:
         if parent == path:
             break
         path = parent
+
+
+def remove_arg(arg: str, args: List[str]) -> List[str]:
+    """
+    Remove an argument from a list of arguments. The argument list is assumed to contain
+    elements of the form:
+    "-a", "--arg1", "--arg2", "value2", or "--arg3=value"
+    If there is an item matching "--arg" then it will be removed from the list.
+
+    :param arg: Argument to look for.
+    :param args: List of arguments to scan.
+    :return: List of arguments with --arg removed, if present.
+    """
+    arg_opt = f"--{arg}"
+    no_arg_opt = f"--no-{arg}"
+    retained_args = []
+    i = 0
+    while i < len(args):
+        arg = args[i]
+        if arg.startswith(arg_opt):
+            if len(arg) == len(arg_opt):
+                # The commandline argument is "--arg", with something possibly following: This can either be
+                # "--arg_opt value" or "--arg_opt --some_other_param"
+                if i < (len(args) - 1):
+                    # If the next argument starts with a "-" then assume that it does not belong to the --arg
+                    # argument. If there is no "-", assume it belongs to the --arg_opt argument, and skip both
+                    if not args[i + 1].startswith("-"):
+                        i = i + 1
+            elif arg[len(arg_opt)] == "=":
+                # The commandline argument is "--arg=value": Continue with next arg
+                pass
+            else:
+                # The argument list contains an argument like "--arg_other_param": Keep that.
+                retained_args.append(arg)
+        elif arg == no_arg_opt:
+            pass
+        else:
+            retained_args.append(arg)
+        i = i + 1
+    return retained_args
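
Assuming InnerEye is importable, the new helper behaves as follows on the argument shapes listed in its docstring (a sketch; the expected outputs follow from the code above):

from InnerEye.Azure.azure_util import remove_arg

# Bare switch followed by its value: both items are dropped.
print(remove_arg("azureml", ["--model=X", "--azureml", "True", "--other"]))
# ['--model=X', '--other']

# "--azureml=True" form: the single item is dropped.
print(remove_arg("azureml", ["--azureml=True", "--other"]))
# ['--other']

# A different switch that merely shares the prefix is kept.
print(remove_arg("azureml", ["--azureml_foo", "--other"]))
# ['--azureml_foo', '--other']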

InnerEye/ML/SSL/lightning_containers/ssl_container.py

+3 -2 lines

@@ -121,8 +121,9 @@ def setup(self) -> None:
                 dataset_path=self.local_dataset,
                 batch_size=self.ssl_training_batch_size)})
         self.data_module: InnerEyeDataModuleTypes = self.get_data_module()
-        self.perform_validation_and_test_set_inference = False
-        if self.number_of_cross_validation_splits > 1:
+        self.inference_on_val_set = False
+        self.inference_on_test_set = False
+        if self.perform_cross_validation:
             raise NotImplementedError("Cross-validation logic is not implemented for this module.")

     def _load_config(self) -> None:

InnerEye/ML/deep_learning_config.py

+69 -13 lines

@@ -5,17 +5,16 @@
 from __future__ import annotations

 import logging
-from enum import Enum, unique
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-
 import param
+from enum import Enum, unique
 from pandas import DataFrame
 from param import Parameterized
+from pathlib import Path
+from typing import Any, Dict, List, Optional

 from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, RUN_CONTEXT, is_offline_run_context
 from InnerEye.Common import fixed_paths
-from InnerEye.Common.common_util import is_windows
+from InnerEye.Common.common_util import ModelProcessing, is_windows
 from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
 from InnerEye.Common.generic_parsing import GenericConfig
 from InnerEye.Common.type_annotations import PathOrString, TupleFloat2
@@ -199,14 +198,24 @@ class WorkflowParams(param.Parameterized):
     cross_validation_split_index: int = param.Integer(DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, bounds=(-1, None),
                                                       doc="The index of the cross validation fold this model is "
                                                           "associated with when performing k-fold cross validation")
-    perform_training_set_inference: bool = \
-        param.Boolean(False,
-                      doc="If True, run full image inference on the training set at the end of training. If False and "
-                          "perform_validation_and_test_set_inference is True (default), only run inference on "
-                          "validation and test set. If both flags are False do not run inference.")
-    perform_validation_and_test_set_inference: bool = \
-        param.Boolean(True,
-                      doc="If True (default), run full image inference on validation and test set after training.")
+    inference_on_train_set: Optional[bool] = \
+        param.Boolean(None,
+                      doc="If set, enable/disable full image inference on training set after training.")
+    inference_on_val_set: Optional[bool] = \
+        param.Boolean(None,
+                      doc="If set, enable/disable full image inference on validation set after training.")
+    inference_on_test_set: Optional[bool] = \
+        param.Boolean(None,
+                      doc="If set, enable/disable full image inference on test set after training.")
+    ensemble_inference_on_train_set: Optional[bool] = \
+        param.Boolean(None,
+                      doc="If set, enable/disable full image inference on the training set after ensemble training.")
+    ensemble_inference_on_val_set: Optional[bool] = \
+        param.Boolean(None,
+                      doc="If set, enable/disable full image inference on validation set after ensemble training.")
+    ensemble_inference_on_test_set: Optional[bool] = \
+        param.Boolean(None,
+                      doc="If set, enable/disable full image inference on test set after ensemble training.")
     weights_url: str = param.String(doc="If provided, a url from which weights will be downloaded and used for model "
                                         "initialization.")
     local_weights_path: Optional[Path] = param.ClassSelector(class_=Path,
@@ -254,6 +263,53 @@ def validate(self) -> None:
                              f"found number_of_cross_validation_splits = {self.number_of_cross_validation_splits} "
                              f"and cross_validation_split_index={self.cross_validation_split_index}")

+    """ Defaults for when to run inference in the absence of any command line switches. """
+    INFERENCE_DEFAULTS: Dict[ModelProcessing, Dict[ModelExecutionMode, bool]] = {
+        ModelProcessing.DEFAULT: {
+            ModelExecutionMode.TRAIN: False,
+            ModelExecutionMode.TEST: True,
+            ModelExecutionMode.VAL: True,
+        },
+        ModelProcessing.ENSEMBLE_CREATION: {
+            ModelExecutionMode.TRAIN: False,
+            ModelExecutionMode.TEST: True,
+            ModelExecutionMode.VAL: False,
+        }
+    }
+
+    def inference_options(self) -> Dict[ModelProcessing, Dict[ModelExecutionMode, Optional[bool]]]:
+        """
+        Return a mapping from ModelProcesing and ModelExecutionMode to command line switch.
+
+        :return: Command line switch for each combination of ModelProcessing and ModelExecutionMode.
+        """
+        return {
+            ModelProcessing.DEFAULT: {
+                ModelExecutionMode.TRAIN: self.inference_on_train_set,
+                ModelExecutionMode.TEST: self.inference_on_test_set,
+                ModelExecutionMode.VAL: self.inference_on_val_set,
+            },
+            ModelProcessing.ENSEMBLE_CREATION: {
+                ModelExecutionMode.TRAIN: self.ensemble_inference_on_train_set,
+                ModelExecutionMode.TEST: self.ensemble_inference_on_test_set,
+                ModelExecutionMode.VAL: self.ensemble_inference_on_val_set,
+            }
+        }
+
+    def inference_on_set(self, model_proc: ModelProcessing, data_split: ModelExecutionMode) -> bool:
+        """
+        Returns True if inference is required for this model_proc and data_split.
+
+        :param model_proc: Whether we are testing an ensemble or single model.
+        :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
+        :return: True if inference required.
+        """
+        inference_option = self.inference_options()[model_proc][data_split]
+        if inference_option is not None:
+            return inference_option
+
+        return WorkflowParams.INFERENCE_DEFAULTS[model_proc][data_split]
+
     @property
     def is_offline_run(self) -> bool:
         """

InnerEye/ML/run_ml.py

+38 -39 lines

@@ -11,6 +11,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import pandas as pd
+from pytorch_lightning.core.datamodule import LightningDataModule
 import stopit
 import torch.multiprocessing
 from azureml._restclient.constants import RunStatus
@@ -120,19 +121,16 @@ def download_dataset(azure_dataset_id: str,
     return expected_dataset_path


-def log_metrics(val_metrics: Optional[InferenceMetricsForSegmentation],
-                test_metrics: Optional[InferenceMetricsForSegmentation],
-                train_metrics: Optional[InferenceMetricsForSegmentation],
+def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
                 run_context: Run) -> None:
     """
     Log metrics for each split to the provided run, or the current run context if None provided
-    :param val_metrics: Inference results for the validation split
-    :param test_metrics: Inference results for the test split
-    :param train_metrics: Inference results for the train split
+    :param metrics: Dictionary of inference results for each split.
     :param run_context: Run for which to log the metrics to, use the current run context if None provided
     """
-    for split in [x for x in [val_metrics, test_metrics, train_metrics] if x]:
-        split.log_metrics(run_context)
+    for split in metrics.values():
+        if isinstance(split, InferenceMetricsForSegmentation):
+            split.log_metrics(run_context)


 class MLRunner:
@@ -390,7 +388,7 @@ def run(self) -> None:

         # If this is an cross validation run, and the present run is child run 0, then wait for the sibling
         # runs, build the ensemble model, and write a report for that.
-        if self.container.number_of_cross_validation_splits > 0:
+        if self.container.perform_cross_validation:
             should_wait_for_other_child_runs = (not self.is_offline_run) and \
                                                self.container.cross_validation_split_index == 0
             if should_wait_for_other_child_runs:
@@ -420,10 +418,24 @@ def is_normal_run_or_crossval_child_0(self) -> bool:
         """
         Returns True if the present run is a non-crossvalidation run, or child run 0 of a crossvalidation run.
         """
-        if self.container.number_of_cross_validation_splits > 0:
+        if self.container.perform_cross_validation:
             return self.container.cross_validation_split_index == 0
         return True

+    @staticmethod
+    def lightning_data_module_dataloaders(data: LightningDataModule) -> Dict[ModelExecutionMode, Callable]:
+        """
+        Given a lightning data module, return a dictionary of dataloader for each model execution mode.
+
+        :param data: Lightning data module.
+        :return: Data loader for each model execution mode.
+        """
+        return {
+            ModelExecutionMode.TEST: data.test_dataloader,
+            ModelExecutionMode.VAL: data.val_dataloader,
+            ModelExecutionMode.TRAIN: data.train_dataloader
+        }
+
     def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
         """
         Run inference on the test set for all models that are specified via a LightningContainer.
@@ -439,11 +451,10 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
         # Read the data modules before changing the working directory, in case the code relies on relative paths
         data = self.container.get_inference_data_module()
         dataloaders: List[Tuple[DataLoader, ModelExecutionMode]] = []
-        if self.container.perform_validation_and_test_set_inference:
-            dataloaders.append((data.test_dataloader(), ModelExecutionMode.TEST))  # type: ignore
-            dataloaders.append((data.val_dataloader(), ModelExecutionMode.VAL))  # type: ignore
-        if self.container.perform_training_set_inference:
-            dataloaders.append((data.train_dataloader(), ModelExecutionMode.TRAIN))  # type: ignore
+        data_dataloaders = MLRunner.lightning_data_module_dataloaders(data)
+        for data_split, dataloader in data_dataloaders.items():
+            if self.container.inference_on_set(ModelProcessing.DEFAULT, data_split):
+                dataloaders.append((dataloader(), data_split))
         checkpoint = load_checkpoint(checkpoint_paths[0], use_gpu=self.container.use_gpu)
         lightning_model.load_state_dict(checkpoint['state_dict'])
         lightning_model.eval()
@@ -491,8 +502,8 @@ def run_inference(self, checkpoint_handler: CheckpointHandler,
         """

         # run full image inference on existing or newly trained model on the training, and testing set
-        test_metrics, val_metrics, _ = self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
-                                                                           model_proc=model_proc)
+        self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
+                                            model_proc=model_proc)

         self.try_compare_scores_against_baselines(model_proc)

@@ -752,37 +763,25 @@ def copy_file(source: Path, destination_file: str) -> None:
     def model_inference_train_and_test(self,
                                        checkpoint_handler: CheckpointHandler,
                                        model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
-            Tuple[Optional[InferenceMetrics], Optional[InferenceMetrics], Optional[InferenceMetrics]]:
-        train_metrics = None
-        val_metrics = None
-        test_metrics = None
+            Dict[ModelExecutionMode, InferenceMetrics]:
+        metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}

         config = self.innereye_config

-        def run_model_test(data_split: ModelExecutionMode) -> Optional[InferenceMetrics]:
-            return model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,  # type: ignore
-                              model_proc=model_proc)
-
-        if config.perform_validation_and_test_set_inference:
-            # perform inference on test set
-            test_metrics = run_model_test(ModelExecutionMode.TEST)
-            # perform inference on validation set (not for ensemble as current val is in the training fold
-            # for at least one of the models).
-            if model_proc != ModelProcessing.ENSEMBLE_CREATION:
-                val_metrics = run_model_test(ModelExecutionMode.VAL)
-
-        if config.perform_training_set_inference:
-            # perform inference on training set if required
-            train_metrics = run_model_test(ModelExecutionMode.TRAIN)
+        for data_split in ModelExecutionMode:
+            if self.container.inference_on_set(model_proc, data_split):
+                opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
+                                         model_proc=model_proc)
+                if opt_metrics is not None:
+                    metrics[data_split] = opt_metrics

         # log the metrics to AzureML experiment if possible. When doing ensemble runs, log to the Hyperdrive parent run,
         # so that we get the metrics of child run 0 and the ensemble separated.
         if config.is_segmentation_model and not self.is_offline_run:
             run_for_logging = PARENT_RUN_CONTEXT if model_proc.ENSEMBLE_CREATION else RUN_CONTEXT
-            log_metrics(val_metrics=val_metrics, test_metrics=test_metrics,  # type: ignore
-                        train_metrics=train_metrics, run_context=run_for_logging)  # type: ignore
+            log_metrics(metrics=metrics, run_context=run_for_logging)  # type: ignore

-        return test_metrics, val_metrics, train_metrics
+        return metrics

     @stopit.threading_timeoutable()
     def wait_for_runs_to_finish(self, delay: int = 60) -> None:
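
Net effect of the run_ml.py changes: inference results flow through a single Dict[ModelExecutionMode, InferenceMetrics] instead of three positional values. A minimal standalone sketch of that shape, with stubs standing in for model_test and container.inference_on_set:

from enum import Enum
from typing import Dict, Optional

class Split(Enum):
    TRAIN = "Train"
    TEST = "Test"
    VAL = "Val"

def inference_on_set_stub(split: Split) -> bool:
    # Stand-in for the single-model defaults: inference on Val and Test only.
    return split is not Split.TRAIN

def model_test_stub(split: Split) -> Optional[dict]:
    # Stand-in for model_test; a split may also yield no metrics (None).
    return {"split": split.value, "dice": 0.85}

metrics: Dict[Split, dict] = {}
for split in Split:                      # one loop replaces three if-branches
    if inference_on_set_stub(split):
        result = model_test_stub(split)
        if result is not None:
            metrics[split] = result

for split_metrics in metrics.values():   # mirrors the new log_metrics loop
    print(split_metrics)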

Tests/ML/configs/lightning_test_containers.py

+1 -1 lines

@@ -190,7 +190,7 @@ class DummyContainerWithModel(LightningContainer):

     def __init__(self) -> None:
         super().__init__()
-        self.perform_training_set_inference = True
+        self.inference_on_train_set = True
         self.num_epochs = 50
         self.l_rate = 1e-1
