From feb505c2525fa2f9a1fd7dc80ced11ba64cb0d19 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 8 Aug 2024 15:32:14 +0800 Subject: [PATCH 1/4] Mingzhi/topk prob (#28) * [MoE] enable topk > 2 gate * print_version * refine code * deepspeed/moe/sharded_moe.py * func verify * refine code * refine code * refine code * refine code * refine code * remove duplicate topk * update * refine code * fix format * update * fix == * update * add ut * rm tt * update * add top3 ut * revert note * remove - --------- Co-authored-by: Kurt Chen Co-authored-by: Jin, Youzhi --- deepspeed/moe/sharded_moe.py | 98 ++++++++++++++++++++++++++++++++++-- tests/unit/moe/test_moe.py | 53 ++++++++++++++++++- 2 files changed, 145 insertions(+), 6 deletions(-) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index 416f01b82e3d..d7b118f411cd 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -124,6 +124,8 @@ def einsum(rule, a, b): return a.unsqueeze(2) * b.unsqueeze(1) elif rule == 'se,se->s': return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == 'se,sec->sec': + return a.unsqueeze(2) * b elif rule == 'sec,sm->ecm': s = a.shape[0] e = a.shape[1] @@ -191,8 +193,8 @@ def top1gating(logits: Tensor, if noisy_gate_policy == 'RSample': logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) # everything is in fp32 in this function - gates = F.softmax(logits, dim=1) + gates = F.softmax(logits, dim=1) capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity)) # Create a mask for 1st's expert per token @@ -298,6 +300,7 @@ def top2gating(logits: Tensor, # Create a mask for 1st's expert per token indices1_s = torch.argmax(gates, dim=1) num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) if top2_2nd_expert_sampling: @@ -369,6 +372,87 @@ def top2gating(logits: Tensor, return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu') +def topkgating( + logits: Tensor, + k: int, + capacity_factor: float, + min_capacity: int, + drop_tokens: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, + drop_policy: str = "probs", +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements TopKGating on logits.""" + + # everything is in fp32 in this function + # get topk gates + top_gate, top_idx = torch.topk(logits, k=k, dim=1) + gates = F.softmax(logits, dim=1) + num_experts = int(gates.shape[1]) + + # get topk mask + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate) + + mask = torch.zeros_like(gates, dtype=torch.int64).scatter_(1, top_idx, 1) + + # Compute tokens per expert + exp_counts = torch.sum(mask, dim=0).detach() + + # gating decisions + exp_counts = torch.sum(mask, dim=0).detach().to(logits.device) + + # Compute l_aux + me = torch.mean(gates, dim=0) + # HPU Enable Begin + ce = torch.mean(mask.float(), dim=0, dtype=torch.float) + # HPU Enable End + l_aux = torch.mean(me * ce) * num_experts * num_experts / k + + if drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity)) + # update mask and locations by capacity + # mask *= torch.lt(locations, capacity) + if drop_policy == 'probs': + capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False) + capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1) + mask = torch.logical_and(mask, 
capacity_mask) + locations = torch.cumsum(mask, dim=0) - 1 + + elif drop_policy == "position": + locations = torch.cumsum(mask, dim=0) - 1 + mask *= torch.lt(locations, capacity) + else: + raise ValueError(f"Invalid drop_policy: {drop_policy}") + + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = torch.max(exp_counts) + if ep_group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) + if groups._get_expert_model_parallel_world_size() == 1: + # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. + # This is since we are going to activate drop_tokens() to drop duplicate tokens. + tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) + new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) + capacity = new_capacity + + + # normalize gates + gates_masked = gates * mask + gates_s = torch.sum(gates_masked, dim=-1, keepdim=True) + denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps) + gates_masked = gates_masked / denom_s + + # dispatch_mask + locations_sc = _one_hot_to_float((locations * mask), capacity) + + combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc) + + dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts + + class TopKGate(Module): """Gate module which implements Top2Gating as described in Gshard_. :: @@ -402,8 +486,10 @@ def __init__(self, super().__init__() # Only top-1 and top-2 are supported at the moment. - if k != 1 and k != 2: - raise ValueError('Only top-1 and top-2 gatings are supported.') + #if k != 1 and k != 2: + # raise ValueError('Only top-1 and top-2 gatings are supported.') + # HPU Enable Begin + # self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) self.ep_group = ep_group self.k = k @@ -441,9 +527,13 @@ def forward(self, self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, self.drop_tokens, self.use_rts, self.ep_group, use_tutel) - else: + elif self.k == 2: gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling) + else: + gate_output = topkgating(logits, self.k, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, self.drop_tokens, self.ep_group) if self.wall_clock_breakdown: self.timers(TOPK_GATE_TIMER).stop() diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index fdff9430a4e6..e92106aee239 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -7,11 +7,13 @@ import deepspeed import pytest import gc + +import deepspeed.comm as dist from unit.common import DistributedTest from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader import deepspeed.comm as dist from deepspeed import get_accelerator -from deepspeed.moe.sharded_moe import top1gating +from deepspeed.moe.sharded_moe import top1gating,topkgating from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param from deepspeed.utils.torch import required_torch_version @@ -172,7 +174,6 @@ def test(self, ep_size, use_residual): model.backward(loss) model.step() - class TestTopk(DistributedTest): world_size = 2 @@ -191,3 +192,51 @@ def test(self): drop_tokens=False, use_rts=True, use_tutel=False) + 
+class TestTopkGate(DistributedTest): + def test(self): + def check_equal(logits,cap,sparse_truth,res): + m,n=logits.shape + dispatch_mask_truth=torch.zeros(m,n,cap) + i,j,k=sparse_truth.t() + dispatch_mask_truth[i,j,k]=1 + assert(torch.equal(dispatch_mask_truth, res)) + #s=4 e=4 topk=2 cap=2(s*topk/e) + logits=torch.tensor([ + [0.11,0.2,0.1,0.3], + [0.3,0.4,0.11,0.1], + [0.11,0.1,0.6,0.5], + [0.1,0.11,0.7,0.8]]) + logits*=dist.get_rank()+1 + probs_dispatch_res=topkgating(logits,2,1,min_capacity=1,drop_policy='probs')[2] + probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]]) + check_equal(logits,2,probs_sec_sparse,probs_dispatch_res) + + + position_sec_sparse =torch.tensor([[0,1,0],[0,3,0],[1,0,0],[1,1,1],[2,2,0],[2,3,1],[3,2,1]]) + position_dispatch_res=topkgating(logits,2,1,min_capacity=1,drop_policy='position')[2] + check_equal(logits,2,position_sec_sparse,position_dispatch_res) + + + #s=4 e=6 topk=3 cap=2(s*topk/e) + logits2=torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034], + [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450], + [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068], + [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]]) + logits2*=dist.get_rank()+1 + + #top3 full mask #prob_mask #postion_mask + #0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1 + #0 1 0 1 0 1 #0 0 0 1 0 0 #0 1 0 1 0 1 + #0 1 1 1 0 0 #0 1 1 1 0 0 #0 1 1 1 0 0 + #1 1 0 0 0 1 #1 1 0 0 0 1 #1 0 0 0 0 0 + probs_dispatch_res=topkgating(logits2,3,1,min_capacity=1,drop_policy='probs')[2] + probs_sec_sparse = torch.tensor([[0,2,0],[0,4,0],[0,5,0],[1,3,0],[2,1,0],[2,2,1],[2,3,1],[3,0,0],[3,1,1],[3,5,1]]) + check_equal(logits2,2, probs_sec_sparse,probs_dispatch_res) + + position_sec_sparse =torch.tensor([[0,2,0],[0,4,0],[0,5,0],[1,1,0],[1,3,0],[1,5,1],[2,1,1],[2,2,1],[2,3,1],[3,0,0]]) + position_dispatch_res=topkgating(logits2,3,1,min_capacity=1,drop_policy='position')[2] + check_equal(logits2,2,position_sec_sparse,position_dispatch_res) + + + From 153af8e2c7d84391a4b67754b67c2bedb19b5196 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 8 Aug 2024 09:29:59 +0000 Subject: [PATCH 2/4] refine --- deepspeed/moe/sharded_moe.py | 21 +++------- tests/unit/moe/test_moe.py | 76 ++++++++++++++++++------------------ 2 files changed, 43 insertions(+), 54 deletions(-) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index d7b118f411cd..b32a6ce6e9b4 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -386,32 +386,27 @@ def topkgating( # everything is in fp32 in this function # get topk gates top_gate, top_idx = torch.topk(logits, k=k, dim=1) + # gating decisions gates = F.softmax(logits, dim=1) num_experts = int(gates.shape[1]) # get topk mask topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate) - mask = torch.zeros_like(gates, dtype=torch.int64).scatter_(1, top_idx, 1) - - # Compute tokens per expert - exp_counts = torch.sum(mask, dim=0).detach() + mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1) - # gating decisions exp_counts = torch.sum(mask, dim=0).detach().to(logits.device) # Compute l_aux me = torch.mean(gates, dim=0) - # HPU Enable Begin - ce = torch.mean(mask.float(), dim=0, dtype=torch.float) - # HPU Enable End + ce = torch.mean(mask.float(), dim=0) l_aux = torch.mean(me * ce) * num_experts * num_experts / k if drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask capacity = _capacity(gates, torch.tensor(capacity_factor * k), 
torch.tensor(min_capacity)) # update mask and locations by capacity - # mask *= torch.lt(locations, capacity) + if drop_policy == 'probs': capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False) capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1) @@ -421,7 +416,7 @@ def topkgating( elif drop_policy == "position": locations = torch.cumsum(mask, dim=0) - 1 mask *= torch.lt(locations, capacity) - else: + else: raise ValueError(f"Invalid drop_policy: {drop_policy}") else: @@ -436,7 +431,6 @@ def topkgating( new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) capacity = new_capacity - # normalize gates gates_masked = gates * mask gates_s = torch.sum(gates_masked, dim=-1, keepdim=True) @@ -485,11 +479,6 @@ def __init__(self, top2_2nd_expert_sampling: bool = True) -> None: super().__init__() - # Only top-1 and top-2 are supported at the moment. - #if k != 1 and k != 2: - # raise ValueError('Only top-1 and top-2 gatings are supported.') - # HPU Enable Begin - # self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) self.ep_group = ep_group self.k = k diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index e92106aee239..28caaa428e6c 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -13,7 +13,7 @@ from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader import deepspeed.comm as dist from deepspeed import get_accelerator -from deepspeed.moe.sharded_moe import top1gating,topkgating +from deepspeed.moe.sharded_moe import top1gating, topkgating from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param from deepspeed.utils.torch import required_torch_version @@ -174,6 +174,7 @@ def test(self, ep_size, use_residual): model.backward(loss) model.step() + class TestTopk(DistributedTest): world_size = 2 @@ -192,51 +193,50 @@ def test(self): drop_tokens=False, use_rts=True, use_tutel=False) - + + class TestTopkGate(DistributedTest): + def test(self): - def check_equal(logits,cap,sparse_truth,res): - m,n=logits.shape - dispatch_mask_truth=torch.zeros(m,n,cap) - i,j,k=sparse_truth.t() - dispatch_mask_truth[i,j,k]=1 - assert(torch.equal(dispatch_mask_truth, res)) + + def check_equal(logits, cap, sparse_truth, res): + m, n = logits.shape + dispatch_mask_truth = torch.zeros(m, n, cap) + i, j, k = sparse_truth.t() + dispatch_mask_truth[i, j, k] = 1 + assert (torch.equal(dispatch_mask_truth, res)) + #s=4 e=4 topk=2 cap=2(s*topk/e) - logits=torch.tensor([ - [0.11,0.2,0.1,0.3], - [0.3,0.4,0.11,0.1], - [0.11,0.1,0.6,0.5], - [0.1,0.11,0.7,0.8]]) - logits*=dist.get_rank()+1 - probs_dispatch_res=topkgating(logits,2,1,min_capacity=1,drop_policy='probs')[2] + logits = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5], + [0.1, 0.11, 0.7, 0.8]]) + logits *= dist.get_rank() + 1 + probs_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='probs')[2] probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]]) - check_equal(logits,2,probs_sec_sparse,probs_dispatch_res) - - - position_sec_sparse =torch.tensor([[0,1,0],[0,3,0],[1,0,0],[1,1,1],[2,2,0],[2,3,1],[3,2,1]]) - position_dispatch_res=topkgating(logits,2,1,min_capacity=1,drop_policy='position')[2] - check_equal(logits,2,position_sec_sparse,position_dispatch_res) - + check_equal(logits, 2, probs_sec_sparse, 
probs_dispatch_res) + + position_sec_sparse = torch.tensor([[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1], + [3, 2, 1]]) + position_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='position')[2] + check_equal(logits, 2, position_sec_sparse, position_dispatch_res) #s=4 e=6 topk=3 cap=2(s*topk/e) - logits2=torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034], - [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450], - [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068], - [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]]) - logits2*=dist.get_rank()+1 + logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034], + [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450], + [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068], + [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]]) + logits2 *= dist.get_rank() + 1 #top3 full mask #prob_mask #postion_mask - #0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1 + #0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1 #0 1 0 1 0 1 #0 0 0 1 0 0 #0 1 0 1 0 1 #0 1 1 1 0 0 #0 1 1 1 0 0 #0 1 1 1 0 0 #1 1 0 0 0 1 #1 1 0 0 0 1 #1 0 0 0 0 0 - probs_dispatch_res=topkgating(logits2,3,1,min_capacity=1,drop_policy='probs')[2] - probs_sec_sparse = torch.tensor([[0,2,0],[0,4,0],[0,5,0],[1,3,0],[2,1,0],[2,2,1],[2,3,1],[3,0,0],[3,1,1],[3,5,1]]) - check_equal(logits2,2, probs_sec_sparse,probs_dispatch_res) - - position_sec_sparse =torch.tensor([[0,2,0],[0,4,0],[0,5,0],[1,1,0],[1,3,0],[1,5,1],[2,1,1],[2,2,1],[2,3,1],[3,0,0]]) - position_dispatch_res=topkgating(logits2,3,1,min_capacity=1,drop_policy='position')[2] - check_equal(logits2,2,position_sec_sparse,position_dispatch_res) - - - + probs_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='probs')[2] + probs_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 3, 0], [2, 1, 0], [2, 2, 1], [2, 3, 1], + [3, 0, 0], [3, 1, 1], [3, 5, 1]]) + check_equal(logits2, 2, probs_sec_sparse, probs_dispatch_res) + + position_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1], + [2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]]) + position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2] + check_equal(logits2, 2, position_sec_sparse, position_dispatch_res) From 80a0c5f2c9f3495b8a832ab75631a060e5aa263f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 8 Aug 2024 17:40:20 +0800 Subject: [PATCH 3/4] remove empty line --- deepspeed/moe/sharded_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index b32a6ce6e9b4..c09a11e213db 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -300,7 +300,6 @@ def top2gating(logits: Tensor, # Create a mask for 1st's expert per token indices1_s = torch.argmax(gates, dim=1) num_experts = int(gates.shape[1]) - mask1 = F.one_hot(indices1_s, num_classes=num_experts) if top2_2nd_expert_sampling: From 3bea84b881715d6663f5b1f782081f2e36fd2889 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 12 Aug 2024 15:01:33 +0800 Subject: [PATCH 4/4] fix format --- tests/unit/moe/test_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index 28caaa428e6c..f65d5e2a03bc 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -7,8 +7,6 @@ import deepspeed import pytest import gc - -import deepspeed.comm as dist from unit.common import DistributedTest from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader 
import deepspeed.comm as dist
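
A note on the two drop policies introduced by topkgating (PATCH 1/4, refined in PATCH 2/4): 'probs' keeps, per expert, the `capacity` tokens with the highest gate scores, while 'position' keeps the first `capacity` tokens in row order. A minimal standalone sketch of both branches, using the same shapes as the first unit-test case (the logits values are illustrative):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[0.11, 0.20, 0.10, 0.30],
                           [0.30, 0.40, 0.11, 0.10],
                           [0.11, 0.10, 0.60, 0.50],
                           [0.10, 0.11, 0.70, 0.80]])
    k, capacity = 2, 2  # s=4 tokens, e=4 experts -> cap = s*k/e = 2

    gates = F.softmax(logits, dim=1)
    top_gate, top_idx = torch.topk(logits, k=k, dim=1)
    # (s, e) matrix holding the raw top-k scores, zero elsewhere
    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
    mask = torch.zeros_like(gates, dtype=torch.int64).scatter_(1, top_idx, 1)

    # 'probs': per expert (dim=0), keep the `capacity` highest-scoring tokens
    _, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
    capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
    probs_mask = torch.logical_and(mask, capacity_mask)

    # 'position': per expert, keep the first `capacity` tokens in row order
    locations = torch.cumsum(mask, dim=0) - 1
    position_mask = mask * torch.lt(locations, capacity)

Printing `probs_mask.int()` and `position_mask` reproduces the prob_mask/position_mask columns drawn in the test's comment block.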
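
The `_capacity` helper that topkgating calls is not shown in this diff; in DeepSpeed it is roughly ceil(tokens / experts * capacity_factor), floored at min_capacity. topkgating scales the factor by k, which is what makes the test comments `cap=2(s*topk/e)` come out. A hedged sketch of that arithmetic:

    import math

    def capacity_sketch(num_tokens, num_experts, capacity_factor, min_capacity):
        # approximates deepspeed.moe.sharded_moe._capacity (not part of this diff)
        return max(math.ceil(num_tokens / num_experts * capacity_factor), min_capacity)

    # topkgating passes capacity_factor * k:
    assert capacity_sketch(4, 4, 1.0 * 2, 1) == 2  # s=4, e=4, k=2 test case
    assert capacity_sketch(4, 6, 1.0 * 3, 1) == 2  # s=4, e=6, k=3 test case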
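
The auxiliary loss follows the familiar GShard form but divides by k: with k experts active per token, the per-expert assignment rates in `ce` sum to k rather than 1, so the extra /k normalizes the perfectly balanced case back to 1.0 for any k. A small sketch with random logits (shapes illustrative):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(8, 4)          # 8 tokens, 4 experts (illustrative)
    k = 3
    gates = F.softmax(logits, dim=1)
    _, top_idx = torch.topk(logits, k=k, dim=1)
    mask = torch.zeros_like(gates).scatter_(1, top_idx, 1)

    me = gates.mean(dim=0)              # mean gate probability per expert
    ce = mask.mean(dim=0)               # mean assignment rate per expert (sums to k)
    num_experts = gates.shape[1]
    l_aux = torch.mean(me * ce) * num_experts * num_experts / k
    # at perfect balance, me = 1/e and ce = k/e, so l_aux == 1.0 regardless of k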
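
In the drop_tokens=False branch, capacity is taken from the observed per-expert load, all-reduced with MAX across the expert-parallel group, and then padded up to a multiple of the tensor-parallel world size so the later duplicate-token dropping divides evenly. Worked through with illustrative numbers (the all_reduce is elided):

    import torch

    exp_counts = torch.tensor([3, 1, 5, 2])   # tokens routed to each expert (illustrative)
    new_capacity = torch.max(exp_counts)      # 5; dist.all_reduce(MAX, ep_group) would go here
    tp = 2                                    # assumed tensor-parallel world size
    new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
    assert new_capacity.item() == 6           # padded from 5 up to the next multiple of tp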
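
Finally, a single-process smoke test of the new k > 2 path, assuming a DeepSpeed build that contains this series; with ep_group=None and the default drop_tokens=True no process group is needed. The model_dim of 8 and the final einsum are illustrative of how the sharded MoE layer consumes the gate outputs (the 'sec,sm->ecm' rule already exists in sharded_moe.py):

    import torch
    from deepspeed.moe.sharded_moe import topkgating

    logits = torch.randn(4, 6)                # 4 tokens, 6 experts
    l_aux, combine_weights, dispatch_mask, exp_counts = topkgating(
        logits, k=3, capacity_factor=1.0, min_capacity=1, drop_policy='probs')

    tokens = torch.randn(4, 8)                # (s, m), model_dim m = 8
    # dispatch tokens to (expert, capacity) slots, as the MoE layer does
    dispatched = torch.einsum('sec,sm->ecm', dispatch_mask.float(), tokens)
    print(l_aux.item(), exp_counts.tolist(), dispatched.shape)  # shape (6, 2, 8)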