add moe topk(k>2) gate support #5881

Merged: 8 commits, Aug 15, 2024
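This PR generalizes the MoE gate beyond the hard-coded top-1/top-2 paths: a new topkgating function in deepspeed/moe/sharded_moe.py handles arbitrary k, and TopKGate.forward now dispatches to it for k > 2. As orientation for the diff below, here is a hedged sketch of the expert-capacity arithmetic the new path uses; the rounding is inferred from the _capacity(gates, capacity_factor * k, min_capacity) call and the "cap=2(s*topk/e)" comments in the tests, not from the helper's source.

    import math

    # Sketch (not from the PR): expected per-expert capacity for s tokens,
    # e experts, top-k routing, mirroring _capacity(gates, capacity_factor * k, min_capacity).
    def expected_capacity(s: int, e: int, k: int, capacity_factor: float, min_capacity: int) -> int:
        cap = math.ceil(s / e * capacity_factor * k)
        return max(cap, min_capacity)

    print(expected_capacity(4, 4, k=2, capacity_factor=1.0, min_capacity=1))  # 2 (first test case)
    print(expected_capacity(4, 6, k=3, capacity_factor=1.0, min_capacity=1))  # 2 (second test case)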
88 changes: 83 additions & 5 deletions deepspeed/moe/sharded_moe.py
@@ -124,6 +124,8 @@ def einsum(rule, a, b):
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == 'se,se->s':
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == 'se,sec->sec':
return a.unsqueeze(2) * b
elif rule == 'sec,sm->ecm':
s = a.shape[0]
e = a.shape[1]
@@ -191,8 +193,8 @@ def top1gating(logits: Tensor,
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# everything is in fp32 in this function
-gates = F.softmax(logits, dim=1)

+gates = F.softmax(logits, dim=1)
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

# Create a mask for 1st's expert per token
@@ -369,6 +371,81 @@ def top2gating(logits: Tensor,
return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')


def topkgating(
logits: Tensor,
k: int,
capacity_factor: float,
min_capacity: int,
drop_tokens: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
drop_policy: str = "probs",
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements TopKGating on logits."""

# everything is in fp32 in this function
# get topk gates
top_gate, top_idx = torch.topk(logits, k=k, dim=1)
# gating decisions
gates = F.softmax(logits, dim=1)
num_experts = int(gates.shape[1])

# get topk mask
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)

mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)

exp_counts = torch.sum(mask, dim=0).detach().to(logits.device)

# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts / k

if drop_tokens:
# Calculate configured capacity and remove locations outside capacity from mask
capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
# update mask and locations by capacity

if drop_policy == 'probs':
capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
mask = torch.logical_and(mask, capacity_mask)
locations = torch.cumsum(mask, dim=0) - 1

elif drop_policy == "position":
locations = torch.cumsum(mask, dim=0) - 1
mask *= torch.lt(locations, capacity)
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")

else:
# Do not drop tokens - set capacity according to current expert assignments
new_capacity = torch.max(exp_counts)
if ep_group is not None:
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
# This is because drop_tokens() will later be applied to drop duplicate tokens.
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity

# normalize gates
gates_masked = gates * mask
gates_s = torch.sum(gates_masked, dim=-1, keepdim=True)
denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps)
gates_masked = gates_masked / denom_s

# dispatch_mask
locations_sc = _one_hot_to_float((locations * mask), capacity)

combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc)

dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts


class TopKGate(Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::
@@ -401,9 +478,6 @@ def __init__(self,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()

-# Only top-1 and top-2 are supported at the moment.
-if k != 1 and k != 2:
-    raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.ep_group = ep_group
self.k = k
@@ -441,9 +515,13 @@ def forward(self,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, self.ep_group, use_tutel)

-else:
+elif self.k == 2:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+else:
+    gate_output = topkgating(logits, self.k,
+                             self.capacity_factor if self.training else self.eval_capacity_factor,
+                             self.min_capacity, self.drop_tokens, self.ep_group)

if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()
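For context, a minimal standalone sketch (not part of the PR) showing how the new topkgating function above can be exercised. It uses the default drop_tokens=True in a single process, so the ep_group all-reduce in the no-drop branch is never reached; argument names follow the signature added in this diff.

    import torch
    from deepspeed.moe.sharded_moe import topkgating

    # 8 tokens routed across 4 experts, keeping the top 3 experts per token.
    logits = torch.randn(8, 4)
    l_aux, combine_weights, dispatch_mask, exp_counts = topkgating(
        logits, k=3, capacity_factor=1.0, min_capacity=1, drop_policy="probs")

    # capacity = ceil(8 / 4 * 1.0 * 3) = 6 slots per expert
    print(combine_weights.shape)  # torch.Size([8, 4, 6]) -> (tokens, experts, capacity)
    print(exp_counts)             # per-expert token counts, taken before capacity dropping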
49 changes: 48 additions & 1 deletion tests/unit/moe/test_moe.py
@@ -11,7 +11,7 @@
from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
import deepspeed.comm as dist
from deepspeed import get_accelerator
-from deepspeed.moe.sharded_moe import top1gating
+from deepspeed.moe.sharded_moe import top1gating, topkgating
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
from deepspeed.utils.torch import required_torch_version

@@ -191,3 +191,50 @@ def test(self):
drop_tokens=False,
use_rts=True,
use_tutel=False)


class TestTopkGate(DistributedTest):

def test(self):

def check_equal(logits, cap, sparse_truth, res):
m, n = logits.shape
dispatch_mask_truth = torch.zeros(m, n, cap)
i, j, k = sparse_truth.t()
dispatch_mask_truth[i, j, k] = 1
assert (torch.equal(dispatch_mask_truth, res))

#s=4 e=4 topk=2 cap=2(s*topk/e)
logits = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5],
[0.1, 0.11, 0.7, 0.8]])
logits *= dist.get_rank() + 1
probs_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='probs')[2]
probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]])
check_equal(logits, 2, probs_sec_sparse, probs_dispatch_res)

position_sec_sparse = torch.tensor([[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1],
[3, 2, 1]])
position_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='position')[2]
check_equal(logits, 2, position_sec_sparse, position_dispatch_res)

#s=4 e=6 topk=3 cap=2(s*topk/e)
logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
[0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
[0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
[0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]])
logits2 *= dist.get_rank() + 1

#top3 full mask #prob_mask #position_mask
#0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1
#0 1 0 1 0 1 #0 0 0 1 0 0 #0 1 0 1 0 1
#0 1 1 1 0 0 #0 1 1 1 0 0 #0 1 1 1 0 0
#1 1 0 0 0 1 #1 1 0 0 0 1 #1 0 0 0 0 0
probs_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='probs')[2]
probs_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 3, 0], [2, 1, 0], [2, 2, 1], [2, 3, 1],
[3, 0, 0], [3, 1, 1], [3, 5, 1]])
check_equal(logits2, 2, probs_sec_sparse, probs_dispatch_res)

position_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1],
[2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)
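
The two drop policies exercised above differ in which tokens survive when an expert is over capacity: 'probs' keeps the highest-probability tokens per expert, while 'position' keeps the earliest tokens in sequence order. A small hedged sketch (again not part of the PR) that makes the difference visible on the second test's logits by collapsing the capacity axis of the dispatch mask:

    import torch
    from deepspeed.moe.sharded_moe import topkgating

    logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
                            [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
                            [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
                            [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]])

    for policy in ("probs", "position"):
        dispatch_mask = topkgating(logits2, 3, 1, min_capacity=1, drop_policy=policy)[2]
        # (tokens, experts) 0/1 matrix: should match the #prob_mask and #position_mask tables above
        print(policy, dispatch_mask.any(dim=-1).int(), sep="\n")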