Commit 176f380

j0rd1smit (with carmocca, justusschock, pre-commit-ci[bot], and Borda) authored and committed Aug 27, 2022
Disable non blocking to device with MPS (#14368)
* disable non-blocking for mps due to race condition bug
* fixed typo
* fixed: unknown mps device for non arm systems
* Removed unrobust test case
* moved _MPS_DEVICES such that we used in apply_func
* Resolve circular dependencies
* Comment rewording
* changed torchElasticEnvironment to a global import
* simplified if statement to blocking device type
* Added change to CHANGELOG
* Update src/pytorch_lightning/utilities/apply_func.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fixed mypy not detecting casting of device
* Moved check into if statement to maintain original behavior

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent 30d51e7 commit 176f380

File tree

7 files changed: +16 -9 lines changed
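
As the commit message notes, the core change makes transfers to "cpu" and "mps" devices blocking, while other accelerators (e.g. CUDA) keep `non_blocking=True`. A minimal sketch of that decision logic, assuming only `torch` is installed (the helper name `_transfer` is illustrative, not part of the Lightning API):

```python
import torch

_BLOCKING_DEVICE_TYPES = ("cpu", "mps")


def _transfer(tensor: torch.Tensor, device) -> torch.Tensor:
    # Normalize strings such as "mps" or "cuda:0" to torch.device,
    # mirroring the new `isinstance(device, str)` branch in move_data_to_device.
    if isinstance(device, str):
        device = torch.device(device)
    kwargs = {}
    # Skip non-blocking copies for CPU, and for MPS because of the race
    # condition tracked in https://github.com/pytorch/pytorch/issues/83015.
    if device.type not in _BLOCKING_DEVICE_TYPES:
        kwargs["non_blocking"] = True
    return tensor.to(device, **kwargs)
```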
 

Diff for: src/pytorch_lightning/CHANGELOG.md (+1)

@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))
 - Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846)
 - Fixed restoring the trainer after using `lr_find()` so that the correct LR schedule is used for the actual training ([#14113](https://github.com/Lightning-AI/lightning/pull/14113))
+- Fixed incorrect values after transferring data to a MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285))
 
 
 ## [1.7.3] - 2022-08-25

Diff for: src/pytorch_lightning/accelerators/cpu.py (+3 -3)

@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities.device_parser import parse_cpu_cores
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
 from pytorch_lightning.utilities.types import _DEVICE
@@ -42,13 +42,13 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> int:
         """Accelerator device parsing logic."""
-        devices = device_parser.parse_cpu_cores(devices)
+        devices = parse_cpu_cores(devices)
         return devices
 
     @staticmethod
     def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = device_parser.parse_cpu_cores(devices)
+        devices = parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
 
     @staticmethod
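
The commit message attributes the narrower import (pulling `parse_cpu_cores` straight from `pytorch_lightning.utilities.device_parser` rather than the `pytorch_lightning.utilities` package) to resolving circular dependencies. A short usage sketch of the helper as the accelerator now calls it; the example values are illustrative:

```python
import torch

from pytorch_lightning.utilities.device_parser import parse_cpu_cores

# parse_cpu_cores normalizes the Trainer `devices` argument for the CPU
# accelerator: an integer core count is validated and returned as-is.
num_cores = parse_cpu_cores(3)

# get_parallel_devices then expands that count into CPU device handles.
parallel_devices = [torch.device("cpu")] * num_cores
```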

Diff for: src/pytorch_lightning/accelerators/hpu.py (+3 -2)

@@ -17,8 +17,9 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser
+from pytorch_lightning.utilities.device_parser import parse_hpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug
 
 if _HPU_AVAILABLE:
@@ -61,7 +62,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]:
         """Accelerator device parsing logic."""
-        return device_parser.parse_hpus(devices)
+        return parse_hpus(devices)
 
     @staticmethod
     def get_parallel_devices(devices: int) -> List[torch.device]:

Diff for: src/pytorch_lightning/accelerators/ipu.py (+1 -1)

@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _IPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _IPU_AVAILABLE
 
 
 class IPUAccelerator(Accelerator):

Diff for: src/pytorch_lightning/plugins/environments/xla_environment.py (+1 -1)

@@ -15,7 +15,7 @@
 import os
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv

Diff for: src/pytorch_lightning/utilities/apply_func.py (+6 -2)

@@ -38,7 +38,7 @@
 Batch = type(None)
 
 
-_CPU_DEVICES = ("cpu", torch.device("cpu"))
+_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
 
 
 def to_dtype_tensor(
@@ -322,6 +322,9 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
        - :class:`torch.device`
     """
 
+    if isinstance(device, str):
+        device = torch.device(device)
+
     def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_LEGACY and isinstance(data, Batch):
@@ -342,7 +345,8 @@ def batch_to(data: Any) -> Any:
 
         kwargs = {}
         # Don't issue non-blocking transfers to CPU
-        if isinstance(data, Tensor) and device not in _CPU_DEVICES:
+        # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
+        if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
             kwargs["non_blocking"] = True
         data_output = data.to(device, **kwargs)
         if data_output is not None:
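
With the string-to-`torch.device` normalization added above, callers can keep passing plain device strings to `move_data_to_device`; a short usage sketch, assuming a PyTorch build with MPS support (the batch contents are made up for illustration):

```python
import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device

batch = {"x": torch.rand(2, 3), "y": torch.tensor([0, 1])}

if torch.backends.mps.is_available():
    # "mps" is normalized to torch.device("mps") and the copy is issued as a
    # blocking transfer, sidestepping pytorch/pytorch#83015.
    batch_on_mps = move_data_to_device(batch, "mps")

if torch.cuda.is_available():
    # CUDA transfers still request non_blocking=True.
    batch_on_cuda = move_data_to_device(batch, "cuda")
```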

Diff for: src/pytorch_lightning/utilities/device_parser.py (+1)

@@ -110,6 +110,7 @@ def parse_gpu_ids(
     gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
+
     if (
         TorchElasticEnvironment.detect()
         and len(gpus) != 1
