Add qk norm optionally before attention calculation #8820

Merged 1 commit on Mar 6, 2025
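In brief: when ModelArgs.use_qk_norm is set, the attention module builds two RMSNorm instances of size head_dim and applies them to the query and key projections before attention scores are computed; RMSNorm itself moves out of llama_transformer.py into the new norm.py. A minimal sketch of the idea, assuming q and k shaped [bsz, n_heads, seqlen, head_dim] and an arbitrary eps (the committed code uses args.norm_eps):

import torch
from executorch.examples.models.llama.norm import RMSNorm

head_dim = 64
q_norm_fn = RMSNorm(head_dim, eps=1e-5)
k_norm_fn = RMSNorm(head_dim, eps=1e-5)

q = torch.randn(1, 8, 16, head_dim)  # [bsz, n_heads, seqlen, head_dim]
k = torch.randn(1, 2, 16, head_dim)  # [bsz, n_kv_heads, seqlen, head_dim]

# Each query/key vector is RMS-normalized over head_dim and rescaled by a
# learnable per-channel weight before attention scores are formed.
q, k = q_norm_fn(q), k_norm_fn(k)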
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -16,6 +16,7 @@ runtime.python_library(
"rope.py",
"attention.py",
"model_args.py",
"norm.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
13 changes: 13 additions & 0 deletions examples/models/llama/attention.py
@@ -5,6 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope


@@ -176,6 +177,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.max_context_len = args.max_context_len
self.dim = args.dim
self.attention_qkv_bias = args.attention_qkv_bias
self.use_qk_norm = args.use_qk_norm

if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
@@ -241,6 +250,10 @@ def forward(
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if self.use_qk_norm:
q = self.q_norm_fn(q)
k = self.k_norm_fn(k)

if self.use_kv_cache:
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
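One detail from the forward hunk above: the norm runs before self.kv_cache.update(...), so the keys written to the cache are already normalized. A standalone sketch of that ordering, with a plain tensor standing in for the cache (the real KV-cache API is not reproduced here):

import torch
from executorch.examples.models.llama.norm import RMSNorm

head_dim, seqlen = 64, 4
k_norm_fn = RMSNorm(head_dim)
k = torch.randn(1, 2, seqlen, head_dim)    # [bsz, n_kv_heads, seqlen, head_dim]

k = k_norm_fn(k)                           # 1) normalize ...
k_cache = torch.zeros(1, 2, 16, head_dim)  # stand-in for the cache buffer
k_cache[:, :, :seqlen, :] = k              # 2) ... then the normalized keys are cached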
50 changes: 1 addition & 49 deletions examples/models/llama/llama_transformer.py
@@ -18,59 +18,11 @@
)

from executorch.examples.models.llama.model_args import ModelArgs

from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope

from torch import nn


class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x.float()).type_as(x)
return output * self.weight


class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
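Since the class moves out of this module, any code importing RMSNorm from llama_transformer needs the new path (whether a backwards-compatible re-export is kept is not visible in this diff):

# old location, removed by this PR:
#   from executorch.examples.models.llama.llama_transformer import RMSNorm
# new location:
from executorch.examples.models.llama.norm import RMSNorm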
1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
@@ -37,6 +37,7 @@ class ModelArgs:
output_prune_map: Optional[Dict[int, int]] = None
apply_embedding: bool = True # Use embedding inside the transformer
apply_output: bool = True # Use output layer (unembedding) inside the transformer
use_qk_norm: bool = False # apply normalization to q and k in the attention
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
partial_rotary_factor: float = 1.0
rope_theta: Optional[float] = (
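A sketch of how the new flag is meant to be switched on from the config side, assuming the remaining ModelArgs fields are left at defaults or overridden to match the target checkpoint:

from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs(use_qk_norm=True)
# Attention layers built from these args construct q_norm_fn / k_norm_fn
# (RMSNorm over head_dim, eps=args.norm_eps) and apply them in forward,
# as shown in attention.py above.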
51 changes: 51 additions & 0 deletions examples/models/llama/norm.py
@@ -0,0 +1,51 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import torch
from torch import nn


class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x.float()).type_as(x)
return output * self.weight
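A quick sanity check of the module above (shapes and tolerance are arbitrary): with the default all-ones weight, each normalized vector has root-mean-square close to 1 along the last dimension.

import torch
from executorch.examples.models.llama.norm import RMSNorm

norm = RMSNorm(dim=64, eps=1e-6)
x = torch.randn(2, 8, 64)
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt())  # values close to 1.0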
3 changes: 2 additions & 1 deletion examples/models/llama/static_attention.py
@@ -209,7 +209,9 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
self.head_dim = config.head_dim
self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
self.attention_qkv_bias = config.attention_qkv_bias
self.use_qk_norm = config.use_qk_norm

assert not self.use_qk_norm, "QK norm not supported in static attention yet"
self.wqs = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -258,7 +260,6 @@ def forward(
new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]

all_ks = []
all_vs = []
for i in range(self.n_kv_heads):
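The assert above means the flag cannot yet be combined with static attention; a hedged illustration (StaticAttention is assumed to be the class name defined in this file, and Rope is assumed to be constructible from ModelArgs alone):

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import StaticAttention

config = ModelArgs(use_qk_norm=True)
try:
    StaticAttention(config, layer_id=0, rope=Rope(config))
except AssertionError as e:
    print(e)  # "QK norm not supported in static attention yet"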