diff --git a/op_builder/builder.py b/op_builder/builder.py index 79692ce05878..3613791c938d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -35,10 +35,19 @@ TORCH_MINOR = int(torch.__version__.split('.')[1]) +class MissingCUDAException(Exception): + pass + + +class CUDAMismatchException(Exception): + pass + + def installed_cuda_version(name=""): import torch.utils.cpp_extension cuda_home = torch.utils.cpp_extension.CUDA_HOME - assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" + if cuda_home is None: + raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)") # Ensure there is not a cuda version mismatch between torch and nvcc compiler output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) output_split = output.split() @@ -89,9 +98,10 @@ def assert_no_cuda_mismatch(name=""): "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior." ) return True - raise Exception(f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " - f"version torch was compiled with {torch.version.cuda}, unable to compile " - "cuda/cpp extensions without a matching cuda version.") + raise CUDAMismatchException( + f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") return True @@ -339,7 +349,7 @@ def is_cuda_enable(self): try: assert_no_cuda_mismatch(self.name) return '-D__ENABLE_CUDA__' - except BaseException: + except MissingCUDAException: print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " "only cpu ops can be compiled!") return '-D__DISABLE_CUDA__' @@ -601,7 +611,7 @@ def builder(self): if not self.is_rocm_pytorch(): assert_no_cuda_mismatch(self.name) self.build_for_cpu = False - except BaseException: + except MissingCUDAException: self.build_for_cpu = True if self.build_for_cpu: