from typing import Any, Callable, Dict, List, Optional, Tuple

import pandas as pd
+ from pytorch_lightning.core.datamodule import LightningDataModule
import stopit
import torch.multiprocessing
from azureml._restclient.constants import RunStatus
@@ -120,19 +121,16 @@ def download_dataset(azure_dataset_id: str,
    return expected_dataset_path


- def log_metrics(val_metrics: Optional[InferenceMetricsForSegmentation],
-                 test_metrics: Optional[InferenceMetricsForSegmentation],
-                 train_metrics: Optional[InferenceMetricsForSegmentation],
+ def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
                 run_context: Run) -> None:
    """
    Log metrics for each split to the provided run, or the current run context if None is provided.
-     :param val_metrics: Inference results for the validation split
-     :param test_metrics: Inference results for the test split
-     :param train_metrics: Inference results for the train split
+     :param metrics: Dictionary of inference results for each split.
    :param run_context: Run to log the metrics to; uses the current run context if None is provided.
    """
-     for split in [x for x in [val_metrics, test_metrics, train_metrics] if x]:
-         split.log_metrics(run_context)
+     for split in metrics.values():
+         if isinstance(split, InferenceMetricsForSegmentation):
+             split.log_metrics(run_context)


class MLRunner:
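For orientation, a minimal self-contained sketch of how the new dictionary-based `log_metrics` signature is consumed. `ModelExecutionMode` and the metrics class below are simplified stand-ins for the real InnerEye types, and `print` replaces logging to an AzureML run:

```python
# Minimal sketch of the dict-based log_metrics API; ModelExecutionMode and
# InferenceMetricsForSegmentation are simplified stand-ins, not InnerEye's types.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


class ModelExecutionMode(Enum):
    TRAIN = "Train"
    TEST = "Test"
    VAL = "Val"


@dataclass
class InferenceMetricsForSegmentation:
    data_split: ModelExecutionMode
    mean_dice: float

    def log_metrics(self, run_context: Any) -> None:
        # The real method logs to an AzureML Run; print stands in here.
        print(f"{self.data_split.value}: mean Dice = {self.mean_dice:.3f}")


def log_metrics(metrics: Dict[ModelExecutionMode, Any], run_context: Any) -> None:
    # Only segmentation metrics know how to log themselves, hence the
    # isinstance filter carried over from the diff above.
    for split in metrics.values():
        if isinstance(split, InferenceMetricsForSegmentation):
            split.log_metrics(run_context)


log_metrics(
    metrics={
        ModelExecutionMode.TEST: InferenceMetricsForSegmentation(ModelExecutionMode.TEST, 0.87),
        ModelExecutionMode.VAL: InferenceMetricsForSegmentation(ModelExecutionMode.VAL, 0.85),
    },
    run_context=None,
)
```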
@@ -390,7 +388,7 @@ def run(self) -> None:

        # If this is a cross-validation run, and the present run is child run 0, then wait for the sibling
        # runs, build the ensemble model, and write a report for that.
-         if self.container.number_of_cross_validation_splits > 0:
+         if self.container.perform_cross_validation:
            should_wait_for_other_child_runs = (not self.is_offline_run) and \
                                               self.container.cross_validation_split_index == 0
            if should_wait_for_other_child_runs:
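Replacing the inline `number_of_cross_validation_splits > 0` checks with a named flag centralizes the policy at the call sites. A hedged sketch of what such a derived property can look like; the real InnerEye definition may use a different threshold or live on a different class:

```python
# Hypothetical sketch of perform_cross_validation as a derived property;
# the threshold mirrors the inline check it replaces in this diff.
from dataclasses import dataclass


@dataclass
class WorkflowConfig:
    number_of_cross_validation_splits: int = 0
    cross_validation_split_index: int = -1

    @property
    def perform_cross_validation(self) -> bool:
        return self.number_of_cross_validation_splits > 0


print(WorkflowConfig().perform_cross_validation)  # False
print(WorkflowConfig(number_of_cross_validation_splits=5).perform_cross_validation)  # True
```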
@@ -420,10 +418,24 @@ def is_normal_run_or_crossval_child_0(self) -> bool:
        """
        Returns True if the present run is a non-crossvalidation run, or child run 0 of a crossvalidation run.
        """
-         if self.container.number_of_cross_validation_splits > 0:
+         if self.container.perform_cross_validation:
            return self.container.cross_validation_split_index == 0
        return True

+     @staticmethod
+     def lightning_data_module_dataloaders(data: LightningDataModule) -> Dict[ModelExecutionMode, Callable]:
+         """
+         Given a Lightning data module, return a dictionary of dataloaders, one for each model execution mode.
+
+         :param data: Lightning data module.
+         :return: Dataloader for each model execution mode.
+         """
+         return {
+             ModelExecutionMode.TEST: data.test_dataloader,
+             ModelExecutionMode.VAL: data.val_dataloader,
+             ModelExecutionMode.TRAIN: data.train_dataloader
+         }
+
    def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
        """
        Run inference on the test set for all models that are specified via a LightningContainer.
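Note the new helper maps each mode to the *bound dataloader method*, not its result, so a `DataLoader` is only built for splits that are actually selected later. A rough usage sketch, with a hypothetical `DummyDataModule` standing in for a real datamodule:

```python
# Usage sketch: the helper returns bound methods, so building a DataLoader
# is deferred until a split is actually chosen for inference.
# DummyDataModule is a hypothetical stand-in for a real datamodule.
from typing import Callable, Dict

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset


class DummyDataModule(LightningDataModule):
    def _loader(self) -> DataLoader:
        return DataLoader(TensorDataset(torch.zeros(4, 1)), batch_size=2)

    def train_dataloader(self) -> DataLoader:
        return self._loader()

    def val_dataloader(self) -> DataLoader:
        return self._loader()

    def test_dataloader(self) -> DataLoader:
        return self._loader()


data = DummyDataModule()
# Mirrors lightning_data_module_dataloaders, keyed by split name for brevity:
dataloaders: Dict[str, Callable[[], DataLoader]] = {
    "Test": data.test_dataloader,
    "Val": data.val_dataloader,
    "Train": data.train_dataloader,
}
test_loader = dataloaders["Test"]()  # only now is the DataLoader constructed
print(len(test_loader))  # 2 batches
```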
@@ -439,11 +451,10 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
        # Read the data modules before changing the working directory, in case the code relies on relative paths
        data = self.container.get_inference_data_module()
        dataloaders: List[Tuple[DataLoader, ModelExecutionMode]] = []
-         if self.container.perform_validation_and_test_set_inference:
-             dataloaders.append((data.test_dataloader(), ModelExecutionMode.TEST))  # type: ignore
-             dataloaders.append((data.val_dataloader(), ModelExecutionMode.VAL))  # type: ignore
-         if self.container.perform_training_set_inference:
-             dataloaders.append((data.train_dataloader(), ModelExecutionMode.TRAIN))  # type: ignore
+         data_dataloaders = MLRunner.lightning_data_module_dataloaders(data)
+         for data_split, dataloader in data_dataloaders.items():
+             if self.container.inference_on_set(ModelProcessing.DEFAULT, data_split):
+                 dataloaders.append((dataloader(), data_split))
        checkpoint = load_checkpoint(checkpoint_paths[0], use_gpu=self.container.use_gpu)
        lightning_model.load_state_dict(checkpoint['state_dict'])
        lightning_model.eval()
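The removed boolean pair (`perform_validation_and_test_set_inference`, `perform_training_set_inference`) is subsumed by a single per-split predicate. A sketch of plausible `inference_on_set` semantics; the policy below is a stand-in, with the ensemble-skips-VAL rule taken from the comment removed further down this diff (the current validation fold is in the training data of at least one ensemble member):

```python
# Sketch of a per-split inference predicate; ModelProcessing values and the
# policy below are stand-ins, not InnerEye's actual configuration logic.
from enum import Enum


class ModelExecutionMode(Enum):
    TRAIN = "Train"
    TEST = "Test"
    VAL = "Val"


class ModelProcessing(Enum):
    DEFAULT = "Default"
    ENSEMBLE_CREATION = "EnsembleCreation"


def inference_on_set(model_proc: ModelProcessing, data_split: ModelExecutionMode) -> bool:
    # Skip VAL when building an ensemble: the current validation fold was in
    # the training data of at least one ensemble member.
    if model_proc is ModelProcessing.ENSEMBLE_CREATION and data_split is ModelExecutionMode.VAL:
        return False
    # Stand-in default: run inference on TEST and VAL, but not TRAIN.
    return data_split is not ModelExecutionMode.TRAIN


selected = [split for split in ModelExecutionMode
            if inference_on_set(ModelProcessing.DEFAULT, split)]
print([s.value for s in selected])  # ['Test', 'Val']
```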
@@ -491,8 +502,8 @@ def run_inference(self, checkpoint_handler: CheckpointHandler,
        """

        # run full image inference on existing or newly trained model on the training and testing sets
-         test_metrics, val_metrics, _ = self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
-                                                                            model_proc=model_proc)
+         self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
+                                             model_proc=model_proc)

        self.try_compare_scores_against_baselines(model_proc)
@@ -752,37 +763,25 @@ def copy_file(source: Path, destination_file: str) -> None:
    def model_inference_train_and_test(self,
                                       checkpoint_handler: CheckpointHandler,
                                       model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
-             Tuple[Optional[InferenceMetrics], Optional[InferenceMetrics], Optional[InferenceMetrics]]:
-         train_metrics = None
-         val_metrics = None
-         test_metrics = None
+             Dict[ModelExecutionMode, InferenceMetrics]:
+         metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}

        config = self.innereye_config

-         def run_model_test(data_split: ModelExecutionMode) -> Optional[InferenceMetrics]:
-             return model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,  # type: ignore
-                               model_proc=model_proc)
-
-         if config.perform_validation_and_test_set_inference:
-             # perform inference on test set
-             test_metrics = run_model_test(ModelExecutionMode.TEST)
-             # perform inference on validation set (not for ensemble as current val is in the training fold
-             # for at least one of the models).
-             if model_proc != ModelProcessing.ENSEMBLE_CREATION:
-                 val_metrics = run_model_test(ModelExecutionMode.VAL)
-
-         if config.perform_training_set_inference:
-             # perform inference on training set if required
-             train_metrics = run_model_test(ModelExecutionMode.TRAIN)
+         for data_split in ModelExecutionMode:
+             if self.container.inference_on_set(model_proc, data_split):
+                 opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
+                                          model_proc=model_proc)
+                 if opt_metrics is not None:
+                     metrics[data_split] = opt_metrics

        # log the metrics to the AzureML experiment if possible. When doing ensemble runs, log to the Hyperdrive
        # parent run, so that we get the metrics of child run 0 and the ensemble separated.
        if config.is_segmentation_model and not self.is_offline_run:
            run_for_logging = PARENT_RUN_CONTEXT if model_proc == ModelProcessing.ENSEMBLE_CREATION else RUN_CONTEXT
-             log_metrics(val_metrics=val_metrics, test_metrics=test_metrics,  # type: ignore
-                         train_metrics=train_metrics, run_context=run_for_logging)  # type: ignore
+             log_metrics(metrics=metrics, run_context=run_for_logging)  # type: ignore

-         return test_metrics, val_metrics, train_metrics
+         return metrics

    @stopit.threading_timeoutable()
    def wait_for_runs_to_finish(self, delay: int = 60) -> None:
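The refactored `model_inference_train_and_test` reduces three optionals to one dictionary that simply omits splits without results. A minimal stand-alone sketch of that collection pattern; names are stand-ins, and `model_test` here fakes a split that yields no metrics:

```python
# Sketch of collecting per-split inference metrics into a dictionary,
# skipping splits whose inference returned no metrics. Names are stand-ins.
from enum import Enum
from typing import Dict, Optional


class ModelExecutionMode(Enum):
    TRAIN = "Train"
    TEST = "Test"
    VAL = "Val"


def model_test(data_split: ModelExecutionMode) -> Optional[float]:
    # Stand-in for the real model_test: pretend TRAIN produced no metrics.
    return None if data_split is ModelExecutionMode.TRAIN else 0.9


metrics: Dict[ModelExecutionMode, float] = {}
for data_split in ModelExecutionMode:
    opt_metrics = model_test(data_split)
    if opt_metrics is not None:
        metrics[data_split] = opt_metrics

print({k.value: v for k, v in metrics.items()})  # {'Test': 0.9, 'Val': 0.9}
```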