diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py
index 40e9531d400..909b116d163 100644
--- a/mmcls/apis/train.py
+++ b/mmcls/apis/train.py
@@ -131,7 +131,6 @@ def train_model(model,
         model = wrap_distributed_model(
             model,
             cfg.device,
-            device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
@@ -173,6 +172,10 @@ def train_model(model,
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
+
+    if fp16_cfg is None and device == 'npu':
+        fp16_cfg = {'loss_scale': 'dynamic'}
+
     if fp16_cfg is not None:
         if device == 'ipu':
             from mmcv.device.ipu import IPUFp16OptimizerHook
diff --git a/mmcls/datasets/samplers/distributed_sampler.py b/mmcls/datasets/samplers/distributed_sampler.py
index a38c5ac1d58..9e78c400693 100644
--- a/mmcls/datasets/samplers/distributed_sampler.py
+++ b/mmcls/datasets/samplers/distributed_sampler.py
@@ -4,6 +4,7 @@
 
 from mmcls.core.utils import sync_random_seed
 from mmcls.datasets import SAMPLERS
+from mmcls.utils import auto_select_device
 
 
 @SAMPLERS.register_module()
@@ -30,7 +31,7 @@ def __init__(self,
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        self.seed = sync_random_seed(seed, device=auto_select_device())
 
     def __iter__(self):
         # deterministically shuffle based on epoch
diff --git a/mmcls/utils/distribution.py b/mmcls/utils/distribution.py
index e3da97ee9a6..d57bd2b53ba 100644
--- a/mmcls/utils/distribution.py
+++ b/mmcls/utils/distribution.py
@@ -16,7 +16,10 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     Returns:
         model(nn.Module): the model to be parallelized.
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDataParallel
+        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
     elif device == 'cpu':
@@ -49,9 +52,16 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
         .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
                DistributedDataParallel.html
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from torch.npu import current_device
+        model = NPUDistributedDataParallel(
+            model.npu(), *args, device_ids=[current_device()], **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
-        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
+        from torch.cuda import current_device
+        model = MMDistributedDataParallel(
+            model.cuda(), *args, device_ids=[current_device()], **kwargs)
     else:
         raise RuntimeError(f'Unavailable device "{device}"')
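
For reference, below is a minimal usage sketch of the patched helpers; it is not part of the patch itself. It assumes an Ascend host with `torch_npu` installed and an mmcv build that ships `mmcv.device.npu`, that `auto_select_device()` reports `'npu'` on such a host, and it uses a throwaway `nn.Linear` module as a stand-in for a real classifier.

```python
# Usage sketch only. Assumptions: torch_npu is installed, mmcv provides
# mmcv.device.npu, and auto_select_device() returns 'npu' on Ascend hardware.
import torch.distributed as dist
import torch.nn as nn

from mmcls.utils import auto_select_device
from mmcls.utils.distribution import (wrap_distributed_model,
                                      wrap_non_distributed_model)

model = nn.Linear(8, 2)        # placeholder for a real classifier
device = auto_select_device()  # 'npu' on Ascend, 'cuda' on GPU, else 'cpu'

# Single-process run: dispatches to NPUDataParallel when device == 'npu',
# otherwise to MMDataParallel, exactly as in wrap_non_distributed_model above.
dp_model = wrap_non_distributed_model(model, device=device, device_ids=[0])

# Distributed run: device_ids is now filled in by the wrapper itself via
# torch.npu.current_device() / torch.cuda.current_device(), so callers no
# longer pass it explicitly (mirroring the train.py hunk above). Requires an
# initialized process group, e.g. via mmcv.runner.init_dist.
if dist.is_available() and dist.is_initialized():
    ddp_model = wrap_distributed_model(
        model,
        device,
        broadcast_buffers=False,
        find_unused_parameters=False)
```

Moving the `current_device()` lookup into `wrap_distributed_model` keeps the `train.py` call site device-agnostic, and the NPU path additionally falls back to fp16 with a dynamic loss scale when no `fp16` config is given, per the first hunk above.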