From 5b2587a44b1edf3068eb84bf36fb79e89a4ce34c Mon Sep 17 00:00:00 2001
From: Madhumitha Sridhara
Date: Thu, 6 Mar 2025 12:01:17 -0800
Subject: [PATCH] Add qk norm optionally before attention calculation (#8820)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/8820

Some of the new Llama checkpoints developed by GenAI use an additional qk_norm in the attention calculation. To run these models with ExecuTorch and keep parity with the server models, we add an optional QK norm to the ET attention.

RMSNorm is also refactored into a separate file (norm.py) so that there is no circular dependency between attention.py and llama_transformer.py.

Reviewed By: iseeyuan

Differential Revision: D70355802
---
 examples/models/llama/TARGETS              |  1 +
 examples/models/llama/attention.py         | 13 ++++++
 examples/models/llama/llama_transformer.py | 50 +--------------------
 examples/models/llama/model_args.py        |  1 +
 examples/models/llama/norm.py              | 51 ++++++++++++++++++++++
 examples/models/llama/static_attention.py  |  3 +-
 6 files changed, 69 insertions(+), 50 deletions(-)
 create mode 100644 examples/models/llama/norm.py

diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 46875b3412..48c48532f7 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -16,6 +16,7 @@ runtime.python_library(
         "rope.py",
         "attention.py",
         "model_args.py",
+        "norm.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",
diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 66eeb10989..54f738ba73 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -5,6 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
 
 
@@ -176,6 +177,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_context_len = args.max_context_len
         self.dim = args.dim
         self.attention_qkv_bias = args.attention_qkv_bias
+        self.use_qk_norm = args.use_qk_norm
+
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
         self.wq = nn.Linear(
             self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
         )
@@ -241,6 +250,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        if self.use_qk_norm:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 3536936e47..5c8db7f208 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -18,59 +18,11 @@
 )
 
 from executorch.examples.models.llama.model_args import ModelArgs
-
+from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
-
 from torch import nn
 
 
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        """
-        Initialize the RMSNorm normalization layer.
-
-        Args:
-            dim (int): The dimension of the input tensor.
-            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-        Attributes:
-            eps (float): A small value added to the denominator for numerical stability.
-            weight (nn.Parameter): Learnable scaling parameter.
-
-        """
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The normalized tensor.
-
-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-
-        """
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
 class FeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 714976e34f..75ee926f51 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -37,6 +37,7 @@ class ModelArgs:
     output_prune_map: Optional[Dict[int, int]] = None
     apply_embedding: bool = True  # Use embedding inside the transformer
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
+    use_qk_norm: bool = False  # apply normalization to q and k in the attention
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py
new file mode 100644
index 0000000000..5a63ad8f59
--- /dev/null
+++ b/examples/models/llama/norm.py
@@ -0,0 +1,51 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import torch
+from torch import nn
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 9bb5cee5b2..f35efa3815 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -209,7 +209,9 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.head_dim = config.head_dim
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
+        self.use_qk_norm = config.use_qk_norm
 
+        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -258,7 +260,6 @@ def forward(
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
-
         all_ks = []
         all_vs = []
         for i in range(self.n_kv_heads):
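
For readers who want to see the new option end to end, here is a minimal, self-contained sketch of the QK-norm path this patch enables. ToyQKNormAttention is a hypothetical module written only for illustration; it is not the ExecuTorch Attention class, and it omits RoPE, the KV cache, and masking. It applies a per-head RMSNorm to the query and key projections before scaled dot-product attention, mirroring the use_qk_norm branch added to attention.py, and its RMSNorm mirrors the implementation moved into norm.py.

import torch
import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
    """Root-mean-square norm; mirrors the RMSNorm moved into norm.py."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale by the reciprocal root mean square over the last dimension.
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Compute in float32 for stability, then cast back, as norm.py does.
        return self._norm(x.float()).type_as(x) * self.weight


class ToyQKNormAttention(nn.Module):
    """Illustration only: multi-head attention with optional per-head QK norm."""

    def __init__(self, dim: int, n_heads: int, use_qk_norm: bool = False, norm_eps: float = 1e-6):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.use_qk_norm = use_qk_norm
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        if self.use_qk_norm:
            # As in the patch, both norms are sized by head_dim and share the norm eps.
            self.q_norm_fn = RMSNorm(self.head_dim, eps=norm_eps)
            self.k_norm_fn = RMSNorm(self.head_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, dim = x.shape
        # Project and reshape to (batch, heads, seq, head_dim).
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            # Normalize q and k along head_dim before attention. attention.py does
            # this after RoPE and the transposes, right before the KV-cache update;
            # RoPE and the cache are omitted in this sketch.
            q = self.q_norm_fn(q)
            k = self.k_norm_fn(k)

        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, dim)
        return self.wo(out)


if __name__ == "__main__":
    attn = ToyQKNormAttention(dim=64, n_heads=4, use_qk_norm=True)
    print(attn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])

With use_qk_norm=False the module reduces to plain multi-head attention, which matches the new ModelArgs.use_qk_norm default of False, so existing checkpoints are unaffected unless the flag is enabled.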