From bf5d7c3fa37b6fa530642887a4d8f8548fd2db2c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C3=A9tan=20Lepage?= <33058747+GaetanLepage@users.noreply.github.com>
Date: Tue, 10 Dec 2024 18:19:30 +0100
Subject: [PATCH] Only import torch.distributed if it is available (#35133)

---
 src/transformers/pytorch_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 5bdf8a355ddf..fab1b9118d18 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -38,8 +38,10 @@
 is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
 
+# Cache this result, as it's a C FFI call which can be pretty time-consuming
+_torch_distributed_available = torch.distributed.is_available()
 
-if is_torch_greater_or_equal("2.5"):
+if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import Replicate
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,
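
Note (not part of the patch): a minimal standalone sketch of the guarded-import pattern this change applies. It assumes a PyTorch build where distributed support may have been compiled out (e.g. some macOS wheels); the else-branch placeholders are illustrative and do not appear in the patched file.

# Sketch of the guard introduced above. torch.distributed always exists as a
# Python module, but on builds without distributed support is_available()
# returns False and importing its submodules can fail.
import torch

# Cache the check once at module import time: is_available() crosses into C
# via FFI, so repeated calls would be needlessly expensive.
_torch_distributed_available = torch.distributed.is_available()

if _torch_distributed_available:
    # In the patched file these imports are additionally gated on
    # torch >= 2.5, where torch.distributed.tensor is public API.
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.parallel import ColwiseParallel
else:
    # Illustrative fallback only: leave placeholders so misuse fails loudly
    # at the point of use rather than at import time.
    Replicate = ColwiseParallel = None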