This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Enable multi-node training #385

Merged
19 commits merged on Feb 8, 2021
2 changes: 2 additions & 0 deletions .amlignore
@@ -12,3 +12,5 @@ run_outputs
# Test output from model registration
TestsOutsidePackage/azureml-models
tensorboard_runs
InnerEyeTestVariables.txt
InnerEyePrivateSettings.yml
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -11,8 +11,15 @@ created.
## Upcoming

### Added
- ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) Add the ability to train a model on multiple
nodes in AzureML. Example: Add `--num_nodes=2` to the commandline arguments to train on 2 nodes.

### Changed
- ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) Starting an AzureML run now uses the
`ScriptRunConfig` object, rather than the deprecated `Estimator` object.
- ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) When registering a model, the name of the
Python execution environment is added as a tag. This tag is read when running inference, and the execution environment
is re-used.

### Fixed

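As context for the CHANGELOG entries above: with the switch from `Estimator` to `ScriptRunConfig`, multi-node training amounts to attaching an `MpiConfiguration` to the run when more than one node is requested (see the `azure_runner.py` changes further down). The following minimal sketch illustrates that pattern; the workspace, cluster, environment, and script names are placeholders, not values taken from this PR:

```python
from azureml.core import Environment, Experiment, ScriptRunConfig, Workspace
from azureml.core.runconfig import MpiConfiguration

# Placeholder names throughout; substitute your own workspace, cluster and entry script.
workspace = Workspace.from_config()
environment = Environment.get(workspace, name="my-training-environment", version="1")

num_nodes = 2  # mirrors the new --num_nodes command line argument
script_run_config = ScriptRunConfig(
    source_directory=".",
    script="InnerEye/ML/runner.py",
    arguments=[f"--num_nodes={num_nodes}"],
    compute_target="my-gpu-cluster",
    environment=environment,
    # Ask AzureML to allocate several VMs and launch the script on each of them via MPI.
    distributed_job_config=MpiConfiguration(node_count=num_nodes) if num_nodes > 1 else None,
)
run = Experiment(workspace, "multi-node-example").submit(script_run_config)
```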
15 changes: 7 additions & 8 deletions InnerEye/Azure/azure_config.py
@@ -7,14 +7,13 @@
import getpass
import logging
import sys
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import param
from azureml.core import Run, Workspace
from azureml.core import Run, ScriptRunConfig, Workspace
from azureml.core.authentication import InteractiveLoginAuthentication, ServicePrincipalAuthentication
from azureml.train.estimator import MMLBaseEstimator
from azureml.train.hyperdrive import HyperDriveConfig
from git import Repo

@@ -108,6 +107,8 @@ class AzureConfig(GenericConfig):
max_run_duration: str = param.String(doc="The maximum runtime that is allowed for this job when running in "
"AzureML. This is a floating point number with a string suffix s, m, h, d "
"for seconds, minutes, hours, day. Examples: '3.5h', '2d'")
num_nodes: int = param.Integer(default=1, doc="The number of virtual machines that will be allocated for this "
"job in AzureML.")
_workspace: Workspace = param.ClassSelector(class_=Workspace,
doc="The cached workspace object that has been created in the first"
"call to get_workspace")
@@ -245,8 +246,8 @@ class SourceConfig:
root_folder: Path
entry_script: Path
conda_dependencies_files: List[Path]
script_params: Optional[Dict[str, str]] = None
hyperdrive_config_func: Optional[Callable[[MMLBaseEstimator], HyperDriveConfig]] = None
script_params: List[str] = field(default_factory=list)
hyperdrive_config_func: Optional[Callable[[ScriptRunConfig], HyperDriveConfig]] = None
upload_timeout_seconds: int = 36000
environment_variables: Optional[Dict[str, str]] = None

@@ -271,9 +272,7 @@ def set_script_params_except_submit_flag(self) -> None:
else:
retained_args.append(arg)
i = i + 1
# The AzureML documentation says that positional arguments should be passed in using an
# empty string as the value.
self.script_params = {arg: "" for arg in retained_args}
self.script_params = retained_args


@dataclass
205 changes: 137 additions & 68 deletions InnerEye/Azure/azure_runner.py
@@ -4,19 +4,22 @@
# ------------------------------------------------------------------------------------------
import argparse
import getpass
import hashlib
import logging
import os
import signal
import sys
from argparse import ArgumentError, ArgumentParser, Namespace
from datetime import date
from pathlib import Path
from typing import Any, Dict, List, Optional

from azureml.core import Dataset, Experiment, Run
from azureml.core import Dataset, Environment, Experiment, Run, ScriptRunConfig
from azureml.core.conda_dependencies import CondaDependencies
from azureml.core.datastore import Datastore
from azureml.core.runconfig import MpiConfiguration, RunConfiguration
from azureml.core.workspace import WORKSPACE_DEFAULT_BLOB_STORE_NAME
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.data import FileDataset
from azureml.train.dnn import PyTorch

from InnerEye.Azure import azure_util
@@ -34,6 +37,9 @@
INPUT_DATA_KEY = "input_data"

RUN_RECOVERY_FILE = "most_recent_run.txt"
# The version to use when creating an AzureML Python environment. We create all environments with a unique hashed
# name, hence version will always be fixed
ENVIRONMENT_VERSION = "1"


def submit_to_azureml(azure_config: AzureConfig,
@@ -135,10 +141,10 @@ def create_and_submit_experiment(
workspace = azure_config.get_workspace()
experiment_name = create_experiment_name(azure_config)
exp = Experiment(workspace=workspace, name=azure_util.to_azure_friendly_string(experiment_name))
pt_env = create_pytorch_environment(azure_config, source_config, azure_dataset_id)
script_run_config = create_run_config(azure_config, source_config, azure_dataset_id)

# submit a training/testing run associated with the experiment
run: Run = exp.submit(pt_env)
run: Run = exp.submit(script_run_config)

# set metadata for the run
set_run_tags(run, azure_config, model_config_overrides)
@@ -172,7 +178,7 @@ def create_and_submit_experiment(


def get_or_create_dataset(azure_config: AzureConfig,
azure_dataset_id: str) -> Dataset:
azure_dataset_id: str) -> FileDataset:
"""
Looks in the AzureML datastore for a dataset of the given name. If there is no such dataset, a dataset is created
and registered, assuming that the files are in a folder that has the same name as the dataset. For example, if
@@ -213,31 +219,6 @@ def get_or_create_dataset(azure_config: AzureConfig,
return azureml_dataset


def create_pytorch_environment(azure_config: AzureConfig,
source_config: SourceConfig,
azure_dataset_id: str) -> PyTorch:
"""
Creates an Estimator environment required for model execution
:param workspace: The AzureML workspace
:param azure_config: azure related configurations to use for model scaleout behaviour
:param source_config: configurations for model execution, such as name and execution mode
:param azure_dataset_id: The name of the dataset in blob storage to be used for this run.
:return: The configured PyTorch environment to be used for experimentation
"""
azureml_dataset = get_or_create_dataset(azure_config, azure_dataset_id=azure_dataset_id)
if azureml_dataset:
if azure_config.use_dataset_mount:
logging.info("Inside AzureML, the dataset will be provided as a mounted folder.")
estimator_inputs = [azureml_dataset.as_named_input(INPUT_DATA_KEY).as_mount()]
else:
logging.info("Inside AzureML, the dataset will be downloaded before training starts.")
estimator_inputs = [azureml_dataset.as_named_input(INPUT_DATA_KEY).as_download()]
else:
raise ValueError("No AzureML dataset was found.")

return create_estimator_from_configs(azure_config, source_config, estimator_inputs)


def pytorch_version_from_conda_dependencies(conda_dependencies: CondaDependencies) -> Optional[str]:
"""
Given a CondaDependencies object, look for a spec of the form "pytorch=...", and return
@@ -254,63 +235,121 @@ def pytorch_version_from_conda_dependencies(conda_dependencies
return None


def create_estimator_from_configs(azure_config: AzureConfig,
source_config: SourceConfig,
estimator_inputs: List[DatasetConsumptionConfig]) -> PyTorch:
def get_or_create_python_environment(azure_config: AzureConfig,
source_config: SourceConfig,
environment_name: str = "",
register_environment: bool = True) -> Environment:
"""
Create and return a PyTorch estimator from the provided configuration information.
:param azure_config: Azure configuration, used to store various values for the job to be submitted
:param source_config: source configuration, for other needed values
:param estimator_inputs: value for the "inputs" field of the estimator.
:return:
Creates a description for the Python execution environment in AzureML, based on the Conda environment
definition files that are specified in `source_config`. If an environment with this Conda definition already
exists, it is retrieved; otherwise it is created afresh.
:param azure_config: azure related configurations to use for model scale-out behaviour
:param source_config: configurations for model execution, such as name and execution mode
:param environment_name: If specified, try to retrieve the existing Python environment with this name. If that
is not found, create one from the Conda files provided. This parameter is meant to be used when running
inference for an existing model.
:param register_environment: If True, the Python environment will be registered in the AzureML workspace. If
False, it will only be created, but not registered. Use this for unit testing.
"""
# AzureML seems to sometimes expect the entry script path in Linux format, hence convert to posix path
entry_script_relative_path = source_config.entry_script.relative_to(source_config.root_folder).as_posix()
logging.info(f"Entry script {entry_script_relative_path} ({source_config.entry_script} relative to "
f"source directory {source_config.root_folder})")
environment_variables = {
"AZUREML_OUTPUT_UPLOAD_TIMEOUT_SEC": str(source_config.upload_timeout_seconds),
"MKL_SERVICE_FORCE_INTEL": "1",
**(source_config.environment_variables or {})
}
# Merge the project-specific dependencies with the packages that InnerEye itself needs. This should not be
# necessary if the innereye package is installed. It is necessary when working with an outer project and
# InnerEye as a git submodule and submitting jobs from the local machine.
# In case of version conflicts, the package version in the outer project is given priority.
conda_dependencies = merge_conda_dependencies(source_config.conda_dependencies_files) # type: ignore
conda_dependencies, merged_yaml = merge_conda_dependencies(source_config.conda_dependencies_files) # type: ignore
if azure_config.pip_extra_index_url:
# When an extra-index-url is supplied, swap the order in which packages are searched for.
# This is necessary if we need to consume packages from extra-index that clash with names of packages on
# pypi
conda_dependencies.set_pip_option(f"--index-url {azure_config.pip_extra_index_url}")
conda_dependencies.set_pip_option("--extra-index-url https://pypi.org/simple")
# create Estimator environment
framework_version = pytorch_version_from_conda_dependencies(conda_dependencies)
assert framework_version is not None, "The AzureML SDK is behind PyTorch, it does not yet know the version we use."
logging.info(f"PyTorch framework version: {framework_version}")
env_variables = {
"AZUREML_OUTPUT_UPLOAD_TIMEOUT_SEC": str(source_config.upload_timeout_seconds),
"MKL_SERVICE_FORCE_INTEL": "1",
**(source_config.environment_variables or {})
}
base_image = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
# Create a name for the environment that will likely uniquely identify it. AzureML does hashing on top of that,
# and will re-use existing environments even if they don't have the same name.
# Hashing should include everything that can reasonably change. Rely on hashlib here, because the built-in
# hash function gives different results for the same string in different python instances.
hash_string = "\n".join([merged_yaml, azure_config.docker_shm_size, base_image, str(env_variables)])
sha1 = hashlib.sha1(hash_string.encode("utf8"))
overall_hash = sha1.hexdigest()[:32]
unique_env_name = f"InnerEye-{overall_hash}"
try:
env_name_to_find = environment_name or unique_env_name
env = Environment.get(azure_config.get_workspace(), name=env_name_to_find, version=ENVIRONMENT_VERSION)
logging.info(f"Using existing Python environment '{env.name}'.")
return env
except Exception:
logging.info(f"Python environment '{unique_env_name}' does not yet exist, creating and registering it.")
env = Environment(name=unique_env_name)
env.docker.enabled = True
env.docker.shm_size = azure_config.docker_shm_size
env.python.conda_dependencies = conda_dependencies
env.docker.base_image = base_image
env.environment_variables = env_variables
if register_environment:
env.register(azure_config.get_workspace())
return env


def create_run_config(azure_config: AzureConfig,
source_config: SourceConfig,
azure_dataset_id: str = "",
environment_name: str = "") -> ScriptRunConfig:
"""
Creates a configuration to run the InnerEye training script in AzureML.
:param azure_config: azure related configurations to use for model scale-out behaviour
:param source_config: configurations for model execution, such as name and execution mode
:param azure_dataset_id: The name of the dataset in blob storage to be used for this run. This can be an empty
string to not use any datasets.
:param environment_name: If specified, try to retrieve the existing Python environment with this name. If that
is not found, create one from the Conda files provided in `source_config`. This parameter is meant to be used
when running inference for an existing model.
:return: The configured script run.
"""
if azure_dataset_id:
azureml_dataset = get_or_create_dataset(azure_config, azure_dataset_id=azure_dataset_id)
if not azureml_dataset:
raise ValueError(f"AzureML dataset {azure_dataset_id} could not be found or created.")
named_input = azureml_dataset.as_named_input(INPUT_DATA_KEY)
dataset_consumption = named_input.as_mount() if azure_config.use_dataset_mount else named_input.as_download()
else:
dataset_consumption = None
# AzureML seems to sometimes expect the entry script path in Linux format, hence convert to posix path
entry_script_relative_path = source_config.entry_script.relative_to(source_config.root_folder).as_posix()
logging.info(f"Entry script {entry_script_relative_path} ({source_config.entry_script} relative to "
f"source directory {source_config.root_folder})")
max_run_duration = None
if azure_config.max_run_duration:
max_run_duration = run_duration_string_to_seconds(azure_config.max_run_duration)
workspace = azure_config.get_workspace()
estimator = PyTorch(
run_config = RunConfiguration(
script=entry_script_relative_path,
arguments=source_config.script_params,
)
run_config.environment = get_or_create_python_environment(azure_config, source_config,
environment_name=environment_name)
run_config.target = azure_config.cluster
run_config.max_run_duration_seconds = max_run_duration
if azure_config.num_nodes > 1:
distributed_job_config = MpiConfiguration(node_count=azure_config.num_nodes)
run_config.mpi = distributed_job_config
run_config.framework = "Python"
run_config.communicator = "IntelMpi"
run_config.node_count = distributed_job_config.node_count
if dataset_consumption:
run_config.data = {dataset_consumption.name: dataset_consumption}
# Use blob storage for storing the source, rather than the FileShares section of the storage account.
run_config.source_directory_data_store = workspace.datastores.get(WORKSPACE_DEFAULT_BLOB_STORE_NAME).name
script_run_config = ScriptRunConfig(
source_directory=str(source_config.root_folder),
entry_script=entry_script_relative_path,
script_params=source_config.script_params,
compute_target=azure_config.cluster,
# Use blob storage for storing the source, rather than the FileShares section of the storage account.
source_directory_data_store=workspace.datastores.get(WORKSPACE_DEFAULT_BLOB_STORE_NAME),
inputs=estimator_inputs,
environment_variables=environment_variables,
shm_size=azure_config.docker_shm_size,
use_docker=True,
use_gpu=True,
framework_version=framework_version,
max_run_duration_seconds=max_run_duration
run_config=run_config,
)
estimator.run_config.environment.python.conda_dependencies = conda_dependencies
if azure_config.hyperdrive:
estimator = source_config.hyperdrive_config_func(estimator) # type: ignore
return estimator
script_run_config = source_config.hyperdrive_config_func(script_run_config) # type: ignore
return script_run_config


def create_runner_parser(model_config_class: type = None) -> argparse.ArgumentParser:
@@ -452,3 +491,33 @@ def run_duration_string_to_seconds(s: str) -> Optional[int]:
else:
raise ArgumentError("s", f"Invalid suffix: Must be one of 's', 'm', 'h', 'd', but got: {s}")
return int(float(s[:-1]) * multiplier)


def set_environment_variables_for_multi_node() -> None:
"""
Sets the environment variables that PyTorch Lightning needs for multi-node training.
"""
az_master_node = "AZ_BATCHAI_MPI_MASTER_NODE"
master_addr = "MASTER_ADDR"
master_ip = "MASTER_IP"
master_port = "MASTER_PORT"
world_rank = "OMPI_COMM_WORLD_RANK"
node_rank = "NODE_RANK"

if az_master_node in os.environ:
# For AML BATCHAI
os.environ[master_addr] = os.environ[az_master_node]
elif master_ip in os.environ:
# AKS
os.environ[master_addr] = os.environ[master_ip]
else:
logging.info("No settings for the MPI central node found. Assuming that this is a single node training job.")
return

if master_port not in os.environ:
os.environ[master_port] = "6105"

if world_rank in os.environ:
os.environ[node_rank] = os.environ[world_rank] # node rank is the world_rank from mpi run
for var in [master_addr, master_port, node_rank]:
print(f"Distributed training: {var} = {os.environ[var]}")