Commit 873723b

Merge pull request #128 from yanboliang/mixtral_improvements
Mixtral MoE improvements: transposed w2 to have reduction dim be innermost dim
2 parents f68e81e + 7e50fcc commit 873723b

File tree: 6 files changed (+27 -26)

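The change itself is small: every expert matmul in the MoE feed-forward now contracts over the weight tensor's innermost (stride-1) dimension, and w2 is stored pre-transposed as [num_experts, dim, intermediate_size] so no runtime transpose is needed. A minimal sketch with toy sizes (illustrative names and shapes, not the Mixtral config) shows the old and new einsum formulations compute the same thing:

```python
# Illustrative sketch only; T/A/D/I are toy sizes, not Mixtral's.
import torch

T, A, D, I = 5, 2, 16, 32                # tokens, active experts per token, model dim, intermediate dim
x = torch.randn(T, D)
w1_sel = torch.randn(T, A, I, D)         # w1 rows already gathered for the selected experts

# Old path: transpose, then contract over D (not the innermost dim of the transposed weight).
out_old = torch.einsum('ti,taio -> tao', x, w1_sel.transpose(-1, -2))
# New path: no transpose; contract directly over D, the innermost (contiguous) dim.
out_new = torch.einsum('ti,taoi -> tao', x, w1_sel)

assert torch.allclose(out_old, out_new)
```

For w2 the contraction runs over the intermediate dimension instead, which is why the checkpoint conversion below stores it as [num_experts, dim, intermediate_size] up front.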

README.md (+3 -3)

```diff
@@ -22,10 +22,10 @@ Please check the rest of this page about benchmark of LLaMA family models.
 ### Mixtral 8x7B
 We also supported [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) which is a high-quality sparse mixture of experts (MoE) model, the average token generation rates are:
 
-|                  | 1 GPU | 2 GPU  | 4 GPU  | 8 GPU  |
+|                  | 1 GPU | 2 GPU  | 4 GPU  | 8 GPU  |
 |------------------|-------|--------|--------|--------|
-|baseline(bfloat16)|  OOM  |  78.75 | 118.23 | 203.69 |
-|       int8       | 56.04 |  99.91 | 149.53 | 218.48 |
+|baseline(bfloat16)|  OOM  |  96.67 | 155.35 | 227.82 |
+|       int8       | 97.92 | 155.03 | 216.87 | 279.35 |
 
 Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
```

mixtral-moe/README.md (+3 -4)

```diff
@@ -12,11 +12,10 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ## Benchmarks
 Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
-|                  | 1 GPU | 2 GPU  | 4 GPU  | 8 GPU  |
+|                  | 1 GPU | 2 GPU  | 4 GPU  | 8 GPU  |
 |------------------|-------|--------|--------|--------|
-|baseline(bfloat16)|  OOM  |  78.75 | 118.23 | 203.69 |
-|       int8       | 56.04 |  99.91 | 149.53 | 218.48 |
-
+|baseline(bfloat16)|  OOM  |  96.67 | 155.35 | 227.82 |
+|       int8       | 97.92 | 155.03 | 216.87 | 279.35 |
 
 
 ## Generate Text
```

mixtral-moe/model.py (+6 -6)

```diff
@@ -188,16 +188,16 @@ class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
 
     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-        w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D]
-        w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D]
+        w1_weights = self.w1[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3[expert_indices] # [T, A, D, D]
         w2_weights = self.w2[expert_indices] # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
         return expert_outs
 
 
```
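As a sanity check, here is a self-contained before/after version of the forward pass above with toy shapes (again illustrative, not the real config), confirming that flipping w2's storage order leaves the expert outputs unchanged:

```python
# Toy-sized replica of the old vs. new ConditionalFeedForward math (illustrative only).
import torch
import torch.nn.functional as F

E, T, A, D, I = 8, 5, 2, 16, 32
x = torch.randn(T, D)
expert_indices = torch.randint(0, E, (T, A))

w1 = torch.randn(E, I, D)
w3 = torch.randn(E, I, D)
w2_old = torch.randn(E, I, D)                   # old layout: reduction dim (I) is not innermost
w2_new = w2_old.permute(0, 2, 1).contiguous()   # new layout: [E, D, I], reduction dim innermost

# Old forward
x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1[expert_indices].transpose(-1, -2)))
x3 = torch.einsum('ti,taio -> tao', x, w3[expert_indices].transpose(-1, -2))
out_old = torch.einsum('tao,taoi -> tai', x1 * x3, w2_old[expert_indices])

# New forward (this commit)
x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1[expert_indices]))
x3 = torch.einsum('ti,taoi -> tao', x, w3[expert_indices])
out_new = torch.einsum('tao,taio -> tai', x1 * x3, w2_new[expert_indices])

assert torch.allclose(out_old, out_new, atol=1e-5)
```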

mixtral-moe/quantize.py (+9 -9)

```diff
@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
                 cur_state_dict[f"{fqn}.weight"] = int8_weight
                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
             elif isinstance(mod, ConditionalFeedForward):
-                num_experts, intermediate_size, dim = mod.w1.shape
                 for weight_idx in range(0, 3):
                     weight_name = f"w{weight_idx + 1}"
                     scales_name = f"scales{weight_idx + 1}"
                     weight = getattr(mod, weight_name)
+                    num_experts, intermediate_size, dim = weight.shape
 
                     bit8_weight_list = []
                     scales_list = []
@@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype):
         self.target_dtype = target_dtype
 
         self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
-        self.register_buffer("w2", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
+        self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
         self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
 
         self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
-        self.register_buffer("scales2", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
+        self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
         self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
 
     def forward(self, x, expert_indices):
-        w1_weights = (self.w1.to(x.dtype)[expert_indices] * self.scales1[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D]
-        w3_weights = (self.w3.to(x.dtype)[expert_indices] * self.scales3[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D]
-        w2_weights = (self.w2.to(x.dtype)[expert_indices] * self.scales2[expert_indices].to(x.dtype).unsqueeze(-1)) # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D]
+        w2_weights = self.w2.to(x.dtype)[expert_indices]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D]
         return expert_outs
 
 
```
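The quantized path now applies the per-output-channel scales to the einsum outputs rather than dequantizing and scaling the weights up front. Because each scale is constant along the reduction dimension, the two orderings agree; a quick sketch with toy float32 tensors (hypothetical shapes, plain PyTorch only):

```python
# Scale folding: scaling weights before the contraction == scaling outputs after it,
# when there is exactly one scale per output channel. Illustrative sizes only.
import torch

T, A, D, I = 5, 2, 16, 32
x = torch.randn(T, D)
w_q = torch.randn(T, A, I, D)            # stand-in for the dequantized int8 expert weights
scales = torch.rand(T, A, I)             # one scale per output channel (intermediate dim)

scale_weights_first = torch.einsum('ti,taoi -> tao', x, w_q * scales.unsqueeze(-1))
scale_outputs_after = torch.einsum('ti,taoi -> tao', x, w_q) * scales

assert torch.allclose(scale_weights_first, scale_outputs_after, atol=1e-5)
```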

mixtral-moe/scripts/convert_hf_checkpoint.py (+4 -2)

```diff
@@ -76,9 +76,11 @@ def convert_hf_checkpoint(
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-        if "w1" in key or "w2" in key or "w3" in key:
+        elif "w1" in key or "w3" in key:
             final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
-        if "gate" in key:
+        elif "w2" in key:
+            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+        elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
 
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
```

mixtral-moe/tp.py (+2 -2)

```diff
@@ -99,12 +99,12 @@ def shard_qkv(qkv, dim, weight_splits):
 def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
     mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
     mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
-    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 1), requires_grad=False)
+    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False)
 
     if hasattr(mlp.cond_ffn, "scales1"):
         mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
         mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False)
-        mlp.cond_ffn.scales2 = nn.Parameter(shard(mlp.cond_ffn.scales2, 1), requires_grad=False)
+        mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False)
 
     world_size = _get_world_size()
     mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
```
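Because w2 now stores the intermediate dimension last, the tensor-parallel split of that dimension moves from dim 1 to dim 2, and scales2 (which covers the un-sharded output dim of size `dim`) is replicated instead of sharded; the partial expert outputs are still summed by the existing all-reduce hook. A toy sketch, assuming a stand-in `shard` helper rather than the repo's:

```python
# Illustrative tensor-parallel slicing of the new layouts (hypothetical world_size/rank).
import torch

world_size, rank = 4, 1
E, D, I = 8, 16, 32

def shard(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Minimal stand-in: keep this rank's slice along `dim`.
    return torch.tensor_split(x, world_size, dim=dim)[rank]

w1      = shard(torch.randn(E, I, D), 1)   # [E, I/world_size, D]
w2      = shard(torch.randn(E, D, I), 2)   # [E, D, I/world_size]  <- split moved to dim 2
scales2 = torch.rand(E, D)                 # replicated on every rank, no sharding needed

print(w1.shape, w2.shape, scales2.shape)   # torch.Size([8, 8, 16]) torch.Size([8, 16, 8]) torch.Size([8, 16])
```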

0 commit comments
