Commit 382da55

phymbert and ThiloteE authored and committed
model: Add support for PhiMoE arch (ggml-org#11003)
* model: support phimoe

* python linter

* doc: minor

Co-authored-by: ThiloteE <73715071+ThiloteE@users.noreply.github.com>

* doc: minor

Co-authored-by: ThiloteE <73715071+ThiloteE@users.noreply.github.com>

* doc: add phimoe as supported model

ggml-ci

---------

Co-authored-by: ThiloteE <73715071+ThiloteE@users.noreply.github.com>
1 parent d8311e9 commit 382da55

10 files changed: +208 -31 lines changed

README.md

+1
@@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
 - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
 - [x] [Phi models](https://huggingface.co/models?search=microsoft/phi)
+- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003)
 - [x] [GPT-2](https://huggingface.co/gpt2)
 - [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118)
 - [x] [InternLM2](https://huggingface.co/models?search=internlm2)

convert_hf_to_gguf.py

+57
@@ -2562,6 +2562,63 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
 
+@Model.register("PhiMoEForCausalLM")
+class PhiMoeModel(Phi3MiniModel):
+    model_arch = gguf.MODEL_ARCH.PHIMOE
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("block_sparse_moe.experts") != -1:
+            n_experts = self.hparams["num_local_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["w1", "w2", "w3"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("PlamoForCausalLM")
 class PlamoModel(Model):
     model_arch = gguf.MODEL_ARCH.PLAMO
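
The buffering done by `modify_tensors` in the hunk above can be sketched without `torch`. This is a minimal illustration, not the converter's code: `ExpertMerger`, `add` and `leftover` are hypothetical names, and nested lists stand in for `torch.Tensor`, so "stacking" is just collecting the per-expert matrices into one list. It assumes every expert contributes exactly three weights (`w1`, `w2`, `w3`), which is what the `n_experts * 3` check in the diff implies.

```python
from typing import Dict, List, Tuple

# Stand-in for torch.Tensor: a plain 2-D nested list (assumption for this sketch).
Matrix = List[List[float]]


class ExpertMerger:
    """Hypothetical helper mirroring the per-layer buffering of modify_tensors."""

    def __init__(self, n_layers: int, n_experts: int) -> None:
        self.n_experts = n_experts
        # One dict of pending expert tensors per layer, like `self._experts`.
        self.pending: List[Dict[str, Matrix]] = [{} for _ in range(n_layers)]

    def add(self, name: str, data: Matrix, bid: int) -> List[Tuple[str, List[Matrix]]]:
        """Buffer one expert tensor; return merged tensors once the layer is complete."""
        self.pending[bid][name] = data
        # Each expert contributes w1, w2 and w3, so a full layer has n_experts * 3 entries.
        if len(self.pending[bid]) < self.n_experts * 3:
            return []

        merged: List[Tuple[str, List[Matrix]]] = []
        for w_name in ("w1", "w2", "w3"):
            stacked: List[Matrix] = []
            for xid in range(self.n_experts):
                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
                # pop() both reads the tensor and removes it from the buffer.
                stacked.append(self.pending[bid].pop(ename))
            merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
            merged.append((merged_name, stacked))
        return merged

    def leftover(self) -> List[str]:
        """Names still buffered; prepare_tensors raises if this is non-empty."""
        return [k for d in self.pending for k in d]


# Usage: feed two experts for layer 0; nothing is emitted until all six tensors arrive.
merger = ExpertMerger(n_layers=1, n_experts=2)
merged = []
for xid in range(2):
    for w_name in ("w1", "w2", "w3"):
        name = f"model.layers.0.block_sparse_moe.experts.{xid}.{w_name}.weight"
        merged += merger.add(name, [[float(xid)]], bid=0)
```

In the real converter the stacked result is a single 3-D tensor with the expert index as its first dimension (`torch.stack(datas, dim=0)`), and the merged name is passed through `map_tensor_name` before being returned.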

docs/development/HOWTO-add-model.md

+5 -5
@@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
 ```python
 @Model.register("MyModelForCausalLM")
 class MyModel(Model):
-    model_arch = gguf.MODEL_ARCH.GROK
+    model_arch = gguf.MODEL_ARCH.MYMODEL
 ```
 
 2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
@@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
 - `Model#set_vocab`
 - `Model#write_tensors`
 
-NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
+NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights.
 
 ### 2. Define the model architecture in `llama.cpp`
 
 The model params and tensors layout must be defined in `llama.cpp`:
 1. Define a new `llm_arch`
 2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non standard metadata in `llm_load_hparams`
+3. Add any non-standard metadata in `llm_load_hparams`
 4. Create the tensors for inference in `llm_load_tensors`
 5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
 
@@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc
 
 This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
 
-Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
+Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
 
-When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
+Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
 
 Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
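
The suffix convention from the updated NOTE in this file (tensor names end with `.weight` or `.bias`) is easy to check mechanically. The snippet below is only a sketch under that assumption: `invalid_tensor_names` is a hypothetical helper, not part of `gguf-py`, and the sample names are built from the `blk.%d.*` patterns that appear elsewhere in this commit.

```python
from typing import Iterable, List

# Suffixes allowed by the NOTE in HOWTO-add-model.md after this commit.
ALLOWED_SUFFIXES = (".weight", ".bias")


def invalid_tensor_names(names: Iterable[str]) -> List[str]:
    """Return the names that do not end with one of the allowed suffixes."""
    # str.endswith accepts a tuple and matches if any element matches.
    return [n for n in names if not n.endswith(ALLOWED_SUFFIXES)]


sample = [
    "blk.0.attn_q.weight",
    "blk.0.attn_q.bias",
    "blk.0.ffn_gate_inp",        # missing suffix
    "blk.0.ffn_up_exps.weight",
]
bad = invalid_tensor_names(sample)
print(bad)  # ['blk.0.ffn_gate_inp']
```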

gguf-py/gguf/constants.py

+20
@@ -244,6 +244,7 @@ class MODEL_ARCH(IntEnum):
     QWEN2VL = auto()
     PHI2 = auto()
     PHI3 = auto()
+    PHIMOE = auto()
     PLAMO = auto()
     CODESHELL = auto()
     ORION = auto()
@@ -428,6 +429,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.QWEN2VL: "qwen2vl",
     MODEL_ARCH.PHI2: "phi2",
     MODEL_ARCH.PHI3: "phi3",
+    MODEL_ARCH.PHIMOE: "phimoe",
     MODEL_ARCH.PLAMO: "plamo",
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
@@ -940,6 +942,24 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PHIMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FACTORS_LONG,
+        MODEL_TENSOR.ROPE_FACTORS_SHORT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.CODESHELL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,

gguf-py/gguf/tensor_mapping.py

+20 -17
@@ -55,7 +55,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
+            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
             "output", # llama-pth bloom internlm2
             "word_embeddings_for_head", # persimmon
             "lm_head.linear", # phi2
@@ -68,7 +68,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
             "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
-            "model.norm", # llama-hf baichuan internlm2 olmoe olmo2
+            "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
             "norm", # llama-pth
             "transformer.norm_f", # mpt dbrx
             "ln_f", # refact bloom qwen gpt2
@@ -108,7 +108,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
             "transformer.h.{bid}.ln_mlp", # falcon40b
-            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
+            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
             "layers.{bid}.attention_norm", # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
             "model.layers.{bid}.ln1", # yi
@@ -152,7 +152,7 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
             "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wq", # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
@@ -165,7 +165,7 @@ class TensorNameMap:
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
             "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
@@ -179,7 +179,7 @@ class TensorNameMap:
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.h.{bid}.attn.v_proj", # gpt-j
@@ -197,7 +197,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj", # mpt
             "transformer.h.{bid}.self_attention.dense", # falcon
             "h.{bid}.self_attention.dense", # bloom
-            "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
             "model.layers.{bid}.self_attn.linear_attn", # deci
             "layers.{bid}.attention.wo", # llama-pth
             "encoder.layer.{bid}.attention.output.dense", # bert
@@ -242,7 +242,7 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
             "h.{bid}.post_attention_layernorm", # bloom
             "transformer.blocks.{bid}.norm_2", # mpt
-            "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
+            "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
             "layers.{bid}.ffn_norm", # llama-pth
             "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
             "model.layers.{bid}.ln2", # yi
@@ -265,7 +265,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate", # mixtral
-            "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+            "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
             "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
             "transformer.decoder_layer.{bid}.router", # Grok
             "transformer.blocks.{bid}.ffn.router.layer", # dbrx
@@ -310,10 +310,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.w3",         # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.v1",  # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj",       # qwen2moe olmoe (merged)
+            "layers.{bid}.feed_forward.experts.w3",           # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",   # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",    # dbrx
+            "model.layers.{bid}.mlp.experts.up_proj",         # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -342,10 +343,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1",         # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear",   # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",     # qwen2moe olmoe (merged)
+            "layers.{bid}.feed_forward.experts.w1",           # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",     # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",    # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",       # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -387,6 +389,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
             "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
             "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
+            "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
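
To make the effect of the new mapping entries concrete, here is a rough sketch of how a Hugging Face tensor name could be resolved to its GGUF name. It is not the `TensorNameMap` implementation: `HF_TO_GGUF` and `map_name` are hypothetical, the table holds only a few of the templates touched by this commit, and the GGUF base names are assumed to match the `blk.%d.*` strings registered in `src/llama-arch.cpp`.

```python
import re
from typing import Optional

# Subset of Hugging Face name templates from this commit, mapped to GGUF base
# names as listed for LLM_ARCH_PHIMOE in src/llama-arch.cpp (assumed to match).
HF_TO_GGUF = {
    "model.layers.{bid}.block_sparse_moe.gate": "blk.{bid}.ffn_gate_inp",
    "model.layers.{bid}.block_sparse_moe.experts.w1": "blk.{bid}.ffn_gate_exps",
    "model.layers.{bid}.block_sparse_moe.experts.w2": "blk.{bid}.ffn_down_exps",
    "model.layers.{bid}.block_sparse_moe.experts.w3": "blk.{bid}.ffn_up_exps",
    "model.norm": "output_norm",
    "lm_head": "output",
}


def map_name(hf_name: str, suffixes=(".weight", ".bias")) -> Optional[str]:
    """Resolve a Hugging Face tensor name to a GGUF name, or None if unknown."""
    # Split off the suffix so the lookup works on the bare tensor name.
    suffix = ""
    for s in suffixes:
        if hf_name.endswith(s):
            hf_name, suffix = hf_name[: -len(s)], s
            break

    # Per-layer tensors carry a block index that becomes the {bid} placeholder.
    m = re.match(r"model\.layers\.(\d+)\.", hf_name)
    if m:
        bid = m.group(1)
        template = hf_name.replace(f".{bid}.", ".{bid}.", 1)
        target = HF_TO_GGUF.get(template)
        return None if target is None else target.format(bid=bid) + suffix

    target = HF_TO_GGUF.get(hf_name)
    return None if target is None else target + suffix


print(map_name("model.layers.3.block_sparse_moe.experts.w1.weight"))  # blk.3.ffn_gate_exps.weight
print(map_name("lm_head.weight"))  # output.weight
```

Note the pairing the diff establishes for the merged expert weights: `w1` is the gate projection, `w3` the up projection, and `w2` the down projection.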

src/llama-arch.cpp

+22
@@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2VL, "qwen2vl" },
     { LLM_ARCH_PHI2, "phi2" },
     { LLM_ARCH_PHI3, "phi3" },
+    { LLM_ARCH_PHIMOE, "phimoe" },
     { LLM_ARCH_PLAMO, "plamo" },
     { LLM_ARCH_CODESHELL, "codeshell" },
     { LLM_ARCH_ORION, "orion" },
@@ -584,6 +585,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PHIMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
+            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PLAMO,
         {
src/llama-arch.h

+1
@@ -31,6 +31,7 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
+    LLM_ARCH_PHIMOE,
     LLM_ARCH_PLAMO,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,

src/llama-model.cpp

+11
@@ -76,6 +76,7 @@ const char * llm_type_name(llm_type type) {
         case MODEL_8x7B: return "8x7B";
         case MODEL_8x22B: return "8x22B";
         case MODEL_16x12B: return "16x12B";
+        case MODEL_16x3_8B: return "16x3.8B";
         case MODEL_10B_128x3_66B: return "10B+128x3.66B";
         case MODEL_57B_A14B: return "57B.A14B";
         case MODEL_27B: return "27B";
@@ -661,6 +662,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                     throw std::runtime_error("invalid value for sliding_window");
                 }
             } break;
+        case LLM_ARCH_PHIMOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_16x3_8B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_PLAMO:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2094,6 +2104,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:

src/llama-model.h

+1
@@ -73,6 +73,7 @@ enum llm_type {
     MODEL_8x7B,
     MODEL_8x22B,
     MODEL_16x12B,
+    MODEL_16x3_8B,
     MODEL_10B_128x3_66B,
     MODEL_57B_A14B,
     MODEL_27B,
