From d522b04f439f320290504ceec981ba550ec51b27 Mon Sep 17 00:00:00 2001
From: Yan Ma
Date: Mon, 17 Feb 2025 20:59:18 +0800
Subject: [PATCH] [Bugfix] fix xpu communicator (#13368)

Signed-off-by: yan ma
---
 .../device_communicators/xpu_communicator.py | 54 +++++++++++++++++++
 vllm/platforms/xpu.py                        |  4 ++
 2 files changed, 58 insertions(+)
 create mode 100644 vllm/distributed/device_communicators/xpu_communicator.py

diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
new file mode 100644
index 0000000000000..256e7965e0a72
--- /dev/null
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from .base_device_communicator import DeviceCommunicatorBase
+
+
+class XpuCommunicator(DeviceCommunicatorBase):
+
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
+
+    def all_reduce(self, input_) -> torch.Tensor:
+        dist.all_reduce(input_, group=self.device_group)
+        return input_
+
+    def gather(self,
+               input_: torch.Tensor,
+               dst: int = 0,
+               dim: int = -1) -> Optional[torch.Tensor]:
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        # For xpu path, gather doesn't work properly together with ray
+        # cluster so we use all_gather instead for now.
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty((self.world_size, ) + input_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor,
+                                    input_,
+                                    group=self.device_group)
+        if self.rank_in_group == dst:
+            # Reshape
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                  (self.world_size *
+                                                   input_size[dim], ) +
+                                                  input_size[dim + 1:])
+        else:
+            output_tensor = None
+        return output_tensor
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 81bc85f9415e8..04af319566af5 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -135,3 +135,7 @@ def device_support_bf16(cls) -> bool:
         logger.warning("Unknown device name %s, always use float16",
                        device_name)
         return False
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
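
Reviewer note (not part of the patch): the gather path above works around dist.gather misbehaving on the XPU + Ray path by performing an all_gather_into_tensor on every rank and reshaping the result only on the destination rank. Below is a minimal, framework-free sketch of that same trick in plain torch.distributed. The helper name gather_via_all_gather, the torchrun/gloo launch, and the shapes are illustrative assumptions, not vLLM API; it assumes a PyTorch version and backend that implement all_gather_into_tensor.

# Standalone sketch (illustrative, not vLLM code): emulate dist.gather()
# with all_gather_into_tensor + reshape, mirroring XpuCommunicator.gather.
import torch
import torch.distributed as dist


def gather_via_all_gather(input_: torch.Tensor,
                          dst: int = 0,
                          dim: int = -1,
                          group=None):
    """Return the tensor concatenated along `dim` on rank `dst`, None elsewhere."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    if dim < 0:
        dim += input_.dim()
    input_size = input_.size()
    # Every rank receives all shards; shard i lands at index i of dim 0.
    output = torch.empty((world_size, ) + input_size,
                         dtype=input_.dtype,
                         device=input_.device)
    dist.all_gather_into_tensor(output, input_, group=group)
    if rank != dst:
        return None
    # Move the world dimension next to `dim`, then merge the two.
    output = output.movedim(0, dim)
    return output.reshape(input_size[:dim] +
                          (world_size * input_size[dim], ) +
                          input_size[dim + 1:])


if __name__ == "__main__":
    # Example launch (assumption): torchrun --nproc-per-node=2 this_file.py
    # The gloo backend is used only so the sketch runs on CPU; vLLM's XPU
    # path runs this pattern inside its own device group instead.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    x = torch.full((2, 3), float(rank))
    result = gather_via_all_gather(x, dst=0, dim=0)
    if rank == 0:
        print(result.shape)  # torch.Size([4, 3]) with 2 ranks
    dist.destroy_process_group()

As in the patch, only the destination rank pays for the movedim/reshape; the other ranks simply discard their copy of the gathered buffer.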