From eb33fb15ddcbb9e54f4b7a2f484551c44feb8785 Mon Sep 17 00:00:00 2001
From: iLeGend
Date: Mon, 4 Mar 2024 21:53:59 +0800
Subject: [PATCH 1/2] Replace the ProcessGroup from torch.distributed to
 deepspeed.comm

---
 deepspeed/moe/sharded_moe.py                    | 3 +--
 deepspeed/runtime/comm/coalesced_collectives.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index d92211b9d220..6abc0a4333a3 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -97,8 +97,7 @@ class _AllToAll(torch.autograd.Function):
     @staticmethod
     def forward(
             ctx: Any,
-            # TODO: replace with DS process group
-            group: torch.distributed.ProcessGroup,
+            group: dist.ProcessGroup,
             input: Tensor) -> Tensor:  # type: ignore
         ctx.group = group
         input = input.contiguous()
diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py
index d63d7e985e07..543795126fab 100644
--- a/deepspeed/runtime/comm/coalesced_collectives.py
+++ b/deepspeed/runtime/comm/coalesced_collectives.py
@@ -12,8 +12,7 @@
 import torch
 from torch import Tensor
 from deepspeed import comm as dist
-# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
-from torch.distributed import ProcessGroup, all_to_all_single
+from deepspeed.comm import ProcessGroup, all_to_all_single
 from deepspeed.accelerator import get_accelerator
 from deepspeed.utils import instrument_w_nvtx
 from deepspeed.ops import op_builder

From 55dbaf54848cefe10b8d6de639ac5d03c9d4bfca Mon Sep 17 00:00:00 2001
From: iLeGend
Date: Mon, 4 Mar 2024 14:18:23 +0000
Subject: [PATCH 2/2] fix format

---
 deepspeed/moe/sharded_moe.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 6abc0a4333a3..e6a5292d7e4f 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -95,10 +95,7 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 class _AllToAll(torch.autograd.Function):
 
     @staticmethod
-    def forward(
-            ctx: Any,
-            group: dist.ProcessGroup,
-            input: Tensor) -> Tensor:  # type: ignore
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
         ctx.group = group
         input = input.contiguous()
         output = torch.empty_like(input)