[Feature] Support mmcls with NPU backend. (#1072)
* init npu

* Avoid importing the latest MMCV code to stay compatible with old versions.

Co-authored-by: mzr1996 <mzr1996@163.com>
wangjiangben-hw and mzr1996 authored Oct 24, 2022
1 parent 38040d5 commit 17ed870
Showing 3 changed files with 19 additions and 5 deletions.
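
With this change, mmcls can run on Ascend NPU devices through MMCV's NPU dispatch, alongside the existing CUDA, CPU and IPU backends. As a rough, hypothetical sketch of how the device string ends up on the config (the config path is only an example, and the exact wiring in tools/train.py may differ):

from mmcv import Config
from mmcls.utils import auto_select_device

cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py')
# auto_select_device() is expected to return 'npu' on an Ascend host with an
# NPU-enabled MMCV build, falling back to 'cuda' or 'cpu' elsewhere (assumption).
cfg.device = auto_select_device()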
5 changes: 4 additions & 1 deletion mmcls/apis/train.py
@@ -131,7 +131,6 @@ def train_model(model,
         model = wrap_distributed_model(
             model,
             cfg.device,
-            device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
@@ -173,6 +172,10 @@ def train_model(model,
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
+
+    if fp16_cfg is None and device == 'npu':
+        fp16_cfg = {'loss_scale': 'dynamic'}
+
     if fp16_cfg is not None:
         if device == 'ipu':
             from mmcv.device.ipu import IPUFp16OptimizerHook
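
With this default, NPU runs get mixed-precision training with a dynamic loss scale even when the config has no fp16 section; an explicit fp16 setting still takes precedence, because the fallback only fires when cfg.get('fp16', None) returns None. As a sketch, the implicit default behaves like adding this line to a config file (standard mmcv fp16 convention):

# Equivalent explicit config entry (sketch); set a numeric value such as 512.0
# instead of 'dynamic' to pin a static loss scale.
fp16 = dict(loss_scale='dynamic')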
3 changes: 2 additions & 1 deletion mmcls/datasets/samplers/distributed_sampler.py
@@ -4,6 +4,7 @@
 
 from mmcls.core.utils import sync_random_seed
 from mmcls.datasets import SAMPLERS
+from mmcls.utils import auto_select_device
 
 
 @SAMPLERS.register_module()
@@ -30,7 +31,7 @@ def __init__(self,
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        self.seed = sync_random_seed(seed, device=auto_select_device())
 
     def __iter__(self):
         # deterministically shuffle based on epoch
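
sync_random_seed broadcasts one seed from rank 0 so that every rank shuffles the dataset in the same order; passing device=auto_select_device() lets that broadcast allocate its tensor on the NPU instead of assuming CUDA. The helper itself is not part of this diff; the following is only an illustrative guess at what such a device probe could look like, not the actual mmcls implementation:

import torch

def auto_select_device() -> str:
    # Illustrative only: prefer NPU if torch_npu is loaded, then CUDA, then CPU.
    if getattr(torch, 'npu', None) is not None and torch.npu.is_available():
        return 'npu'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'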
16 changes: 13 additions & 3 deletions mmcls/utils/distribution.py
@@ -16,7 +16,10 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     Returns:
         model(nn.Module): the model to be parallelized.
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDataParallel
+        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
     elif device == 'cpu':
@@ -49,9 +52,16 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
         .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
                DistributedDataParallel.html
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from torch.npu import current_device
+        model = NPUDistributedDataParallel(
+            model.npu(), *args, device_ids=[current_device()], **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
-        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
+        from torch.cuda import current_device
+        model = MMDistributedDataParallel(
+            model.cuda(), *args, device_ids=[current_device()], **kwargs)
     else:
         raise RuntimeError(f'Unavailable device "{device}"')
 
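
Keeping the NPU imports inside the device == 'npu' branch means mmcv.device.npu is only imported when an NPU is actually requested, so older MMCV releases without that module keep working, which is what the commit message refers to. A minimal usage sketch, assuming an Ascend host with torch_npu and an NPU-enabled MMCV (the model is a placeholder):

import torch.nn as nn

from mmcls.utils import wrap_distributed_model, wrap_non_distributed_model

model = nn.Linear(8, 2)  # placeholder model

# Single-process data parallel on NPU; other device strings fall through to
# the CUDA/CPU branches shown above.
model = wrap_non_distributed_model(model, device='npu', device_ids=[0])

# In a distributed job, the wrapper now pins device_ids to the process-local
# device for both NPU and CUDA:
# model = wrap_distributed_model(model, device='npu', broadcast_buffers=False)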
