From e954d982e72e845d6c315e7a6d433856fb7167e4 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 14 Oct 2022 17:05:07 +0900 Subject: [PATCH] Show warning when importing _pyrannc fails --- docker/ubuntu/base/Dockerfile | 2 +- pyrannc/__init__.py | 40 +++++++++++++++++++++++++++++++++-- pyrannc/torch_version.py | 2 ++ setup.py | 5 +++++ 4 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 pyrannc/torch_version.py diff --git a/docker/ubuntu/base/Dockerfile b/docker/ubuntu/base/Dockerfile index a5b5dd8..31612ce 100644 --- a/docker/ubuntu/base/Dockerfile +++ b/docker/ubuntu/base/Dockerfile @@ -10,7 +10,7 @@ ENV BOOST_VERSION 1.78.0 ENV OPENUCX_VERSION 1.9.0 ENV OPENMPI_VERSION 4.0.7 ENV CONDA_VERSION 4.9.2 -ENV PYTORCH_VERSION 1.12.1 +ENV PYTORCH_VERSION 1.11.0 ENV NCCL_VERSION 2.14.3-1 SHELL ["/bin/bash", "-c"] diff --git a/pyrannc/__init__.py b/pyrannc/__init__.py index 1d8ac51..c79f513 100644 --- a/pyrannc/__init__.py +++ b/pyrannc/__init__.py @@ -1,6 +1,7 @@ import copy import inspect import logging +import sys from collections import OrderedDict import torch @@ -8,7 +9,21 @@ import torch.onnx.utils import torch.random -from . import _pyrannc, utils +from .torch_version import BUILD_TORCH_VER + +try: + from . import _pyrannc +except ImportError as e: + import re + + torch_ver = re.sub(r"\+.*", "", torch.__version__) + build_torch_ver = re.sub(r"\+.*", "", BUILD_TORCH_VER) + if torch_version != build_torch_ver: + print("RaNNC was compiled with PyTorch {}, but the current PyTorch version is {}.".format( + BUILD_TORCH_VER, torch.__version__), file=sys.stderr) + raise e + +from . import utils from .dist_param import store_dist_param, load_dist_param, set_dist_param, get_dist_param_range, set_dist_param_dtype, \ DistributeModelParams from .opt import patch_optimizer @@ -442,6 +457,12 @@ def zero_grad(self): super().zero_grad() def get_param(self, name, amp_master_param=False): + r""" + Gets a parameter tensor specified by ``name``. + + :param args: Name of a parameter. + :param amp_master_param: Gets Apex amp master parameter if ``True``. + """ if name not in self.name_to_pid or name not in self.name_to_param: raise RuntimeError("No parameter found: {}".format(name)) @@ -450,6 +471,12 @@ def get_param(self, name, amp_master_param=False): return self.name_to_param[name] def get_param_grad(self, name, amp_master_param=False): + r""" + Gets the gradient of a parameter tensor specified by ``name``. + + :param args: Name of a parameter. + :param amp_master_param: Gets Apex amp master gradient if ``True``. + """ if name not in self.name_to_pid or name not in self.name_to_param: raise RuntimeError("No parameter found: {}".format(name)) @@ -482,6 +509,12 @@ def undeploy(self): super().undeploy() def enable_dropout(self, enable): + r""" + Enables/disables dropout layers. + This method is useful for evaluation because model.eval() does not work for a RaNNCModule. + + :param enable: Set ``True`` to enable and ``False`` to disable dropout layers. + """ if self.ready: super().enable_dropout(enable) else: @@ -565,11 +598,14 @@ def _run_dp_dry(path): def recreate_all_communicators(): + r""" + Destroy and recreate all communicators. + """ _pyrannc.recreate_all_communicators() def show_deployment(path, batch_size): - """ + r""" Show a deployment (Subgraphs and micro-batch sizes in pipeline parallelism) saved in a file. This is used for debugging. diff --git a/pyrannc/torch_version.py b/pyrannc/torch_version.py new file mode 100644 index 0000000..62212ff --- /dev/null +++ b/pyrannc/torch_version.py @@ -0,0 +1,2 @@ +BUILD_TORCH_VER = "1.11.0" +BUILD_TORCH_CUDA_VER = "11.3" diff --git a/setup.py b/setup.py index b3edca0..1ad56fa 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ import sys from distutils.version import LooseVersion +import torch from setuptools import setup, Extension, find_packages from setuptools.command.build_ext import build_ext @@ -70,6 +71,10 @@ def build_extension(self, ext): version_nums = os.environ["CUDA_VERSION"].split(".") VERSION += "+cu{}{}".format(version_nums[0], version_nums[1]) +with open('pyrannc/torch_version.py', mode='w') as f: + f.write('BUILD_TORCH_VER="{}"\n'.format(torch.__version__)) + f.write('BUILD_TORCH_CUDA_VER="{}"\n'.format(torch.version.cuda)) + setup( name='pyrannc', packages=find_packages(),