Add Phi-4-mini-instruct #8856

Merged: 8 commits, Mar 5, 2025
8 changes: 8 additions & 0 deletions .ci/scripts/test_model.sh
@@ -100,6 +100,14 @@ test_model() {
rm "./${MODEL_NAME}.pte"
return # Skip running with portable executor runner since portable doesn't support Qwen's biased linears.
fi
if [[ "${MODEL_NAME}" == "phi4_mini" ]]; then
# Install requirements for export_llama
bash examples/models/llama/install_requirements.sh
# Test export_llama script: python3 -m examples.models.llama.export_llama.
"${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi-4-mini/config.json
run_portable_executor_runner
rm "./${MODEL_NAME}.pte"
fi

# Export a basic .pte and run the model.
"${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -35,6 +35,7 @@
"llava": ("llava", "LlavaModel"),
"efficient_sam": ("efficient_sam", "EfficientSAM"),
"qwen2_5": ("qwen2_5", "Qwen2_5Model"),
"phi4_mini": ("phi4_mini", "Phi4MiniModel"),
}
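For orientation: the dict above is the examples' model registry, which the export scripts use to resolve a model-name string into a (module, class) pair. A minimal lookup sketch, assuming the dict is the MODEL_NAME_TO_MODEL table in examples/models/__init__.py (its name is not visible in this hunk):

from examples.models import MODEL_NAME_TO_MODEL  # name assumed, not shown in the hunk

# The value is (module name, model class name); export tooling uses it to
# locate and instantiate the eager model registered for "phi4_mini".
module_name, class_name = MODEL_NAME_TO_MODEL["phi4_mini"]
print(module_name, class_name)  # phi4_mini Phi4MiniModel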

__all__ = [
1 change: 1 addition & 0 deletions examples/models/llama/export_llama_lib.py
@@ -93,6 +93,7 @@
"llama3_2",
"static_llama",
"qwen2_5",
"phi4_mini",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]

1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
@@ -38,6 +38,7 @@ class ModelArgs:
apply_embedding: bool = True # Use embedding inside the transformer
apply_output: bool = True # Use output layer (unembedding) inside the transformer
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
partial_rotary_factor: float = 1.0
rope_theta: Optional[float] = (
None # The official name to override self.rope_freq_base.
)
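A minimal sketch of the new field in use, assuming ModelArgs is importable from examples.models.llama.model_args (as the file path suggests) and that its remaining fields all have defaults:

from examples.models.llama.model_args import ModelArgs  # import path assumed from the file location

# partial_rotary_factor defaults to 1.0 (rotate every head dimension);
# Phi-4-mini's config sets 0.75, so only the first 75% of each head is rotated.
args = ModelArgs(use_hf_rope=True, partial_rotary_factor=0.75)
print(args.use_hf_rope, args.partial_rotary_factor)  # True 0.75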
26 changes: 22 additions & 4 deletions examples/models/llama/rope.py
@@ -134,11 +134,21 @@ def forward(


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
# Currently only supports non-long RoPE.
def hf_precompute_freqs_cis(
dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
):
# Partial rotary embeddings.
dim = int(dim * partial_rotary_factor)

# Short factor scaling.
freqs = 1.0 / (
theta
** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
)
# TODO: support long factor scaling.

# pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
freqs # pyre-ignore
@@ -180,8 +190,13 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)

rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
return q_embed, k_embed
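To make the partial-rotary split concrete, here is a small standalone sketch (rotate_half is re-implemented locally for self-containment; the shapes and 0.75 factor mirror Phi-4-mini, but the tensors are made up) showing that only the first rotary_dim channels of each head are rotated while the rest pass through unchanged:

import torch

def rotate_half(x):
    # Same convention as the HF helper: concat(-second_half, first_half) along the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

head_dim, partial_rotary_factor = 128, 0.75          # Phi-4-mini-style values
rotary_dim = int(head_dim * partial_rotary_factor)   # 96

q = torch.randn(1, 1, 4, head_dim)                   # (batch, heads, seq, head_dim)
cos = torch.ones(1, 4, rotary_dim).unsqueeze(1)      # identity rotation, for illustration only
sin = torch.zeros(1, 4, rotary_dim).unsqueeze(1)

q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)

# With cos=1, sin=0 the rotation is the identity, and the last 32 channels
# are never touched regardless of cos/sin.
assert torch.equal(q_embed, q)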


@@ -217,7 +232,10 @@ def __init__(self, params: ModelArgs):

# Choose the appropriate RoPE implementation
if self.params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
self.precompute_freqs_cis = partial(
hf_precompute_freqs_cis,
partial_rotary_factor=self.params.partial_rotary_factor,
)
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.precompute_freqs_cis = partial(
15 changes: 15 additions & 0 deletions examples/models/phi-4-mini/config.json
@@ -0,0 +1,15 @@
{
"dim": 3072,
"ffn_dim_multiplier": 1,
"hidden_dim": 8192,
"n_heads": 24,
"n_kv_heads": 8,
"n_layers": 32,
"norm_eps": 1e-05,
"rope_theta": 10000.0,
"use_scaled_rope": false,
"vocab_size": 200064,
"use_hf_rope": true,
"partial_rotary_factor": 0.75,
"attention_qkv_bias": false
}
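A quick sanity check of the numbers above (a sketch, not part of the PR): with dim=3072 and n_heads=24 each attention head has 128 dimensions, partial_rotary_factor=0.75 applies RoPE to the first 96 of them, and n_kv_heads=8 gives a grouped-query layout with 3 query heads per KV head:

dim, n_heads, n_kv_heads = 3072, 24, 8
partial_rotary_factor = 0.75

head_dim = dim // n_heads                            # 128
rotary_dim = int(head_dim * partial_rotary_factor)   # 96 rotary channels per head
gqa_group_size = n_heads // n_kv_heads               # 3 query heads share each KV head

print(head_dim, rotary_dim, gqa_group_size)          # 128 96 3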
88 changes: 88 additions & 0 deletions examples/models/phi-4-mini/convert_weights.py
@@ -0,0 +1,88 @@
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer


# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_PHI_4_FROM_META = {
"tok_embeddings.weight": "tok_embeddings.weight",
"norm.weight": "norm.scale",
"layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
"layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
"layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
"layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
"layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
"layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
"layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
"layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
"layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from torchtune's format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _PHI_4_FROM_META.items()}

for key, value in state_dict.items():
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value

# Input and output embeddings are tied.
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def main():
parser = argparse.ArgumentParser(
description="Convert Phi-4-mini weights to Meta format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()

checkpointer = FullModelHFCheckpointer(
checkpoint_dir=args.input_dir,
checkpoint_files=[
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
],
output_dir=".",
model_type="PHI3_MINI",
)

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()

print("Converting checkpoint...")
sd = phi_4_tune_to_meta(sd["model"])

torch.save(sd, args.output)
print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
main()
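The key renaming itself is delegated to torchtune's get_mapped_key; the following standalone sketch (a simplified re-implementation for illustration, not the code path the script uses) shows the effect on a few representative keys, including the layer-index templating and the tied output embedding:

import re

_PHI_4_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
}
inverted = {v: k for k, v in _PHI_4_FROM_META.items()}

def map_key(tune_key: str) -> str:
    # Swap the concrete layer index for "{}" to look up the template, then restore it.
    match = re.search(r"layers\.(\d+)\.", tune_key)
    template = re.sub(r"layers\.\d+\.", "layers.{}.", tune_key)
    meta_template = inverted[template]
    return meta_template.format(match.group(1)) if match else meta_template

tune_sd = {
    "tok_embeddings.weight": "embedding tensor",
    "norm.scale": "final-norm tensor",
    "layers.0.attn.q_proj.weight": "layer-0 Q tensor",
}
meta_sd = {map_key(k): v for k, v in tune_sd.items()}
meta_sd["output.weight"] = meta_sd["tok_embeddings.weight"]  # tied embeddings, as in the script

print(sorted(meta_sd))
# ['layers.0.attention.wq.weight', 'norm.weight', 'output.weight', 'tok_embeddings.weight']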