-
Notifications
You must be signed in to change notification settings - Fork 143
Run inference using checkpoints from registered models #509
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Checked through and all looks good. The test_model_inference_on_single_run
test may prove useful for the PR I'm working on now!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, looking good. Some parameters could find better homes/classes.
Can you please add documentation around how this new flag should be used to run inference? Also, ensure that the documentation is cleared of any references to functionality that no longer exists.
run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id' " | ||
"to use for inference, recovering a model training run or to register " | ||
"a model.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to your PR, but why are these living in AzureConfig?
InnerEye/ML/run_ml.py
Outdated
self.container.extra_downloaded_run_id = run_recovery_object | ||
else: | ||
self.container.extra_downloaded_run_id = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The field name reads a bit strange - it seems to indicate that this is a run ID, but it's a RunRecovery object (which in turn is nothing but a list of Paths).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed, see my comment below.
InnerEye/ML/run_ml.py
Outdated
run_to_recover, | ||
EXTRA_RUN_SUBFOLDER, | ||
only_return_path=not is_global_rank_zero()) | ||
self.container.extra_downloaded_run_id = run_recovery_object |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This field "extra_downloaded_run_id" could do with some documentation. Also, we define that twice in DeepLearningConfig and in the container. Maybe it is better places in WorkflowParams?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've renamed extra_downloaded_run_id
to pretraining_run_checkpoints
and initialized it once in WorkflowParams. I've also moved pretraining_run_recovery_id
from AzureConfig to WorkflowParams.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moving pretraining_run_checkpoints
to WorkflowParams causes issues with initialization, I'm reverting this change for now.
@@ -49,7 +49,7 @@ steps: | |||
# hence don't set PYTHONPATH | |||
- bash: | | |||
source activate InnerEye | |||
pytest ./Tests/ -m "not (gpu or azureml or after_training_single_run or after_training_ensemble_run or inference or after_training_2node or after_training_glaucoma_cv_run)" --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-config=.coveragerc --cov-report=xml -n 2 --dist=loadscope --verbose | |||
pytest ./Tests/ -m "not (gpu or azureml or after_training_single_run or after_training_ensemble_run or inference or after_training_2node or after_training_glaucoma_cv_run or after_training_hello_container)" --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-config=.coveragerc --cov-report=xml -n 2 --dist=loadscope --verbose |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See this SO question: https://stackoverflow.com/a/55921954 - effectively we are running everything that does not have a mark, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit complicated - there does not seem to be an easy way to check for custom markers only and ignore markers such as "skipif" and "parametrize".
@@ -246,8 +256,13 @@ class WorkflowParams(param.Parameterized): | |||
"be relative to the repository root directory.") | |||
|
|||
def validate(self) -> None: | |||
if self.weights_url and self.local_weights_path: | |||
raise ValueError("Cannot specify both local_weights_path and weights_url.") | |||
if sum([bool(param) for param in [self.weights_url, self.local_weights_path, self.model_id]]) > 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not worth an extra push, but I think any([bool(param)...])
would have been clearer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code throws an error if 2 or more options are set, but needs to allow the case were zero or one option is set.
Adds the ability to run inference on registered models using the parameter
model_id
.