Commit

Show warning when importing _pyrannc fails

Masahiro Tanaka committed Oct 14, 2022
1 parent 700edbb commit e954d98
Showing 4 changed files with 46 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docker/ubuntu/base/Dockerfile
@@ -10,7 +10,7 @@ ENV BOOST_VERSION 1.78.0
 ENV OPENUCX_VERSION 1.9.0
 ENV OPENMPI_VERSION 4.0.7
 ENV CONDA_VERSION 4.9.2
-ENV PYTORCH_VERSION 1.12.1
+ENV PYTORCH_VERSION 1.11.0
 ENV NCCL_VERSION 2.14.3-1
 
 SHELL ["/bin/bash", "-c"]
40 changes: 38 additions & 2 deletions pyrannc/__init__.py
@@ -1,14 +1,29 @@
 import copy
 import inspect
 import logging
+import sys
 from collections import OrderedDict
 
 import torch
 import torch.cuda
 import torch.onnx.utils
 import torch.random
 
-from . import _pyrannc, utils
+from .torch_version import BUILD_TORCH_VER
+
+try:
+    from . import _pyrannc
+except ImportError as e:
+    import re
+
+    torch_ver = re.sub(r"\+.*", "", torch.__version__)
+    build_torch_ver = re.sub(r"\+.*", "", BUILD_TORCH_VER)
+    if torch_ver != build_torch_ver:
+        print("RaNNC was compiled with PyTorch {}, but the current PyTorch version is {}.".format(
+            BUILD_TORCH_VER, torch.__version__), file=sys.stderr)
+    raise e
+
+from . import utils
 from .dist_param import store_dist_param, load_dist_param, set_dist_param, get_dist_param_range, set_dist_param_dtype, \
     DistributeModelParams
 from .opt import patch_optimizer
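
For reference, the ``re.sub(r"\+.*", "", ...)`` calls above normalize a version string by dropping the local build suffix that pip wheels append, so that e.g. "1.11.0+cu113" compares equal to "1.11.0". A minimal, self-contained sketch of that normalization (the ``normalize`` helper is hypothetical):

import re

def normalize(version):
    # Drop a local build suffix such as "+cu113" or "+cpu", if present.
    return re.sub(r"\+.*", "", version)

assert normalize("1.11.0+cu113") == "1.11.0"
assert normalize("1.11.0+cpu") == "1.11.0"
assert normalize("1.11.0") == "1.11.0"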
@@ -442,6 +457,12 @@ def zero_grad(self):
         super().zero_grad()
 
     def get_param(self, name, amp_master_param=False):
+        r"""
+        Gets a parameter tensor specified by ``name``.
+
+        :param name: Name of a parameter.
+        :param amp_master_param: Gets the Apex amp master parameter if ``True``.
+        """
         if name not in self.name_to_pid or name not in self.name_to_param:
             raise RuntimeError("No parameter found: {}".format(name))
 
@@ -450,6 +471,12 @@ def get_param(self, name, amp_master_param=False):
         return self.name_to_param[name]
 
     def get_param_grad(self, name, amp_master_param=False):
+        r"""
+        Gets the gradient of a parameter tensor specified by ``name``.
+
+        :param name: Name of a parameter.
+        :param amp_master_param: Gets the Apex amp master gradient if ``True``.
+        """
         if name not in self.name_to_pid or name not in self.name_to_param:
             raise RuntimeError("No parameter found: {}".format(name))
 
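For illustration, a hypothetical usage sketch of these accessors; ``model`` and the parameter name "fc.weight" are assumptions, with ``model`` standing for a RaNNCModule:

w = model.get_param("fc.weight")
g = model.get_param_grad("fc.weight")
# With Apex amp, the fp32 master copies can be requested instead:
w_master = model.get_param("fc.weight", amp_master_param=True)
g_master = model.get_param_grad("fc.weight", amp_master_param=True)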
@@ -482,6 +509,12 @@ def undeploy(self):
         super().undeploy()
 
     def enable_dropout(self, enable):
+        r"""
+        Enables/disables dropout layers.
+        This method is useful for evaluation because ``model.eval()`` does not work for a RaNNCModule.
+
+        :param enable: Set ``True`` to enable and ``False`` to disable dropout layers.
+        """
         if self.ready:
             super().enable_dropout(enable)
         else:
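
For illustration, a hypothetical evaluation loop using this method; ``model`` (a RaNNCModule) and ``batch`` are assumptions:

model.enable_dropout(False)  # model.eval() would not affect a RaNNCModule
with torch.no_grad():
    output = model(batch)
model.enable_dropout(True)   # restore dropout before resuming training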
@@ -565,11 +598,14 @@ def _run_dp_dry(path):
 
 
 def recreate_all_communicators():
+    r"""
+    Destroy and recreate all communicators.
+    """
     _pyrannc.recreate_all_communicators()
 
 
 def show_deployment(path, batch_size):
-    """
+    r"""
     Show a deployment (Subgraphs and micro-batch sizes in pipeline parallelism) saved in a file.
     This is used for debugging.
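For illustration, a hypothetical invocation of this debugging helper; the deployment file path and batch size are assumptions:

import pyrannc

pyrannc.show_deployment("/path/to/saved_deployment", 64)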
2 changes: 2 additions & 0 deletions pyrannc/torch_version.py
@@ -0,0 +1,2 @@
+BUILD_TORCH_VER = "1.11.0"
+BUILD_TORCH_CUDA_VER = "11.3"
5 changes: 5 additions & 0 deletions setup.py
@@ -6,6 +6,7 @@
 import sys
 from distutils.version import LooseVersion
 
+import torch
 from setuptools import setup, Extension, find_packages
 from setuptools.command.build_ext import build_ext
 
@@ -70,6 +71,10 @@ def build_extension(self, ext):
     version_nums = os.environ["CUDA_VERSION"].split(".")
     VERSION += "+cu{}{}".format(version_nums[0], version_nums[1])
 
+with open('pyrannc/torch_version.py', mode='w') as f:
+    f.write('BUILD_TORCH_VER="{}"\n'.format(torch.__version__))
+    f.write('BUILD_TORCH_CUDA_VER="{}"\n'.format(torch.version.cuda))
+
 setup(
     name='pyrannc',
     packages=find_packages(),
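Since ``torch.__version__`` in pip-distributed CUDA wheels often carries a local build suffix, the generated module may record one as well. For example, building against a CUDA 11.3 wheel of PyTorch 1.11.0 would plausibly write (illustrative, assumed values):

BUILD_TORCH_VER="1.11.0+cu113"
BUILD_TORCH_CUDA_VER="11.3"

That suffix is why ``__init__.py`` strips everything after ``+`` before comparing the build-time and runtime versions.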
