[NPU] Add NPU support for unit test (#4569)
Unit tests fail or are skipped when device=npu, and these features should be
covered by the official unit tests on NPU as well. This commit adds NPU support
to the unit tests. See #4567 for what has already been done.


**What this commit does**
1. Add an NPU logic branch to the test helpers:
feat: Add NPU support for skip_on_arch in tests/unit/util.py
feat: Add NPU support for skip_on_cuda in tests/unit/util.py
feat: Add NPU support to tests/unit/common.py

2. Call the accelerator's set_device before deepspeed.init_distributed in
tests/unit/common.py.
Setting the accelerator device before init_distributed is friendlier and easier
for devices such as NPU, and setting the device before initialization is the
more natural order; a minimal sketch of the intended ordering follows this list.

3. Fix calling get_accelerator().random().fork_rng on non-CUDA devices.
`train_cifar()` in `tests/unit/alexnet_model.py` calls
`get_accelerator().random().fork_rng` without passing `device_type`
explicitly. `torch.random.fork_rng()` defaults to `device_type="cuda"`, so
non-CUDA devices fail to run. The fix is to pass
`device_type=get_accelerator().device_name()` explicitly, so that both CUDA and
non-CUDA devices behave correctly (see the sketch after this list).
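
A minimal sketch of the device-setup ordering from item 2, assuming a
simplified per-process launcher (the real logic lives in `_dist_run` in
`tests/unit/common.py`; `launch_worker` and the `hccl` backend default are
illustrative assumptions):

```python
import deepspeed
from deepspeed.accelerator import get_accelerator


def launch_worker(local_rank: int, backend: str = "hccl") -> None:
    # Bind this process to its accelerator device *before* bringing up the
    # distributed backend; NPU and other non-CUDA devices expect the device
    # to be selected up front.
    if get_accelerator().is_available():
        get_accelerator().set_device(local_rank)

    # Only after the device is set do we initialize the process group.
    deepspeed.init_distributed(dist_backend=backend)
```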
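
And a sketch of the fork_rng fix from item 3; the call mirrors the change in
`tests/unit/alexnet_model.py` and assumes a PyTorch version whose
`torch.random.fork_rng()` accepts the `device_type` argument:

```python
from deepspeed.accelerator import get_accelerator

# Fork the RNG state for the current device only. Without device_type,
# torch.random.fork_rng() assumes CUDA and fails on npu/xpu/cpu-only setups.
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
                                         device_type=get_accelerator().device_name()):
    # Anything seeded inside this block is isolated from the surrounding
    # RNG state, e.g. ds_utils.set_random_seed(seed) in train_cifar().
    pass
```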

---------

Co-authored-by: ryan <ruanzhixiang1@huawei.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
4 people authored Nov 13, 2023
1 parent 0a6095f commit 4b7cae7
Showing 5 changed files with 37 additions and 24 deletions.
2 changes: 1 addition & 1 deletion accelerator/__init__.py
@@ -4,4 +4,4 @@
# DeepSpeed Team

from .abstract_accelerator import DeepSpeedAccelerator
from .real_accelerator import get_accelerator, set_accelerator
from .real_accelerator import get_accelerator, set_accelerator, is_current_accelerator_supported
27 changes: 15 additions & 12 deletions accelerator/real_accelerator.py
@@ -20,6 +20,8 @@
except ImportError as e:
dsa2 = None

SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']

ds_accelerator = None


@@ -34,14 +36,18 @@ def _validate_accelerator(accel_obj):
# accelerator.abstractor_accelerator
# or deepspeed.accelerator.abstract_accelerator, consider accel_obj
# is a conforming object
if not ((dsa1 != None and isinstance(accel_obj, dsa1)) or (dsa2 != None and isinstance(accel_obj, dsa2))):
if not ((dsa1 is not None and isinstance(accel_obj, dsa1)) or (dsa2 is not None and isinstance(accel_obj, dsa2))):
raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator")

# TODO: turn off is_available test since this breaks tests
# assert accel_obj.is_available(), \
# f'{accel_obj.__class__.__name__} accelerator fails is_available() test'


def is_current_accelerator_supported():
return get_accelerator() in SUPPORTED_ACCELERATOR_LIST


def get_accelerator():
global ds_accelerator
if ds_accelerator is not None:
@@ -50,7 +56,6 @@ def get_accelerator():
accelerator_name = None
ds_set_method = None
# 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
DS_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']
if "DS_ACCELERATOR" in os.environ.keys():
accelerator_name = os.environ["DS_ACCELERATOR"]
if accelerator_name == "xpu":
@@ -79,15 +84,13 @@
torch.mps.current_allocated_memory()
except (RuntimeError, ImportError) as e:
raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.")
elif accelerator_name == "cuda":
pass
else:
raise ValueError(
f'DS_ACCELERATOR must be one of {DS_ACCELERATOR_LIST}. Value "{accelerator_name}" is not supported')
elif is_current_accelerator_supported():
raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
f'Value "{accelerator_name}" is not supported')
ds_set_method = "override"

# 2. If no override, detect which accelerator to use automatically
if accelerator_name == None:
if accelerator_name is None:
# We need a way to choose among different accelerator types.
# Currently we detect which accelerator extension is installed
# in the environment and use it if the installing answer is True.
@@ -105,21 +108,21 @@
accelerator_name = "xpu"
except ImportError as e:
pass
if accelerator_name == None:
if accelerator_name is None:
try:
import intel_extension_for_pytorch # noqa: F401,F811 # type: ignore

accelerator_name = "cpu"
except ImportError as e:
pass
if accelerator_name == None:
if accelerator_name is None:
try:
import torch_npu # noqa: F401,F811 # type: ignore

accelerator_name = "npu"
except ImportError as e:
pass
if accelerator_name == None:
if accelerator_name is None:
try:
import torch.mps

@@ -128,7 +131,7 @@
accelerator_name = "mps"
except (RuntimeError, ImportError) as e:
pass
if accelerator_name == None:
if accelerator_name is None:
accelerator_name = "cuda"

ds_set_method = "auto detect"
3 changes: 2 additions & 1 deletion tests/unit/alexnet_model.py
@@ -111,7 +111,8 @@ def cifar_trainset(fp16=False):


def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()]):
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
device_type=get_accelerator().device_name()):
ds_utils.set_random_seed(seed)

# disable dropout
9 changes: 6 additions & 3 deletions tests/unit/common.py
@@ -81,6 +81,9 @@ def set_accelerator_visible():
match = re.search('Device Type.*GPU', line)
if match:
num_accelerators += 1
elif get_accelerator().device_name() == 'npu':
npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
else:
assert get_accelerator().device_name() == 'cpu'
cpu_sockets = int(
@@ -204,13 +207,13 @@ def _dist_run(self, local_rank, num_procs, master_port):
if get_accelerator().is_available():
set_accelerator_visible()

if get_accelerator().is_available():
get_accelerator().set_device(local_rank)

if self.init_distributed:
deepspeed.init_distributed(dist_backend=self.backend)
dist.barrier()

if get_accelerator().is_available():
get_accelerator().set_device(local_rank)

try:
self.run(**self._fixture_kwargs)
except BaseException as e:
20 changes: 13 additions & 7 deletions tests/unit/util.py
@@ -5,29 +5,29 @@

import pytest
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported
from deepspeed.git_version_info import torch_info
from packaging import version as pkg_version


def skip_on_arch(min_arch=7):
if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
if get_accelerator().device_name() == 'cuda':
if torch.cuda.get_device_capability()[0] < min_arch: #ignore-cuda
pytest.skip(f"needs higher compute capability than {min_arch}")
else:
assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
assert is_current_accelerator_supported()
return


def skip_on_cuda(valid_cuda):
split_version = lambda x: map(int, x.split('.')[:2])
if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
if get_accelerator().device_name() == 'cuda':
CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
CUDA_VERSION = (CUDA_MAJOR * 10) + CUDA_MINOR
if valid_cuda.count(CUDA_VERSION) == 0:
pytest.skip(f"requires cuda versions {valid_cuda}")
else:
assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
assert is_current_accelerator_supported()
return


@@ -43,8 +43,14 @@ def bf16_required_version_check(accelerator_check=True):
else:
accelerator_pass = True

if (TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and (
NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and accelerator_pass:
torch_version_available = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
cuda_version_available = CUDA_MAJOR >= 11
nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
npu_available = get_accelerator().device_name() == 'npu'

if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass:
return True
elif npu_available:
return True
else:
return False
