
Commit 352416e

Add qk norm optionally before attention calculation
Differential Revision: D70355802
Pull Request resolved: #8820
1 parent: bdafb22

File tree

6 files changed (+69, -50 lines)


examples/models/llama/TARGETS (+1)

@@ -16,6 +16,7 @@ runtime.python_library(
         "rope.py",
         "attention.py",
         "model_args.py",
+        "norm.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",

examples/models/llama/attention.py (+13)

@@ -5,6 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope


@@ -176,6 +177,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_context_len = args.max_context_len
         self.dim = args.dim
         self.attention_qkv_bias = args.attention_qkv_bias
+        self.use_qk_norm = args.use_qk_norm
+
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
         self.wq = nn.Linear(
             self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
         )
@@ -241,6 +250,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

+        if self.use_qk_norm:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
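With the flag enabled, q and k are normalized after the heads are split and transposed, so RMSNorm reduces over the head dimension of each (bsz, n_heads, seqlen, head_dim) tensor. A minimal sketch of that per-head normalization, using hypothetical sizes and an eps value standing in for args.norm_eps:

import torch

from executorch.examples.models.llama.norm import RMSNorm

# Hypothetical shapes; the layout matches the point in forward() where the
# diff applies the norm (after transpose(1, 2)).
bsz, n_heads, seqlen, head_dim = 1, 8, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)

q_norm_fn = RMSNorm(head_dim, eps=1e-5)  # eps would come from args.norm_eps in the module
q = q_norm_fn(q)  # reduces over the last dim, i.e. each head's query vector separately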

examples/models/llama/llama_transformer.py (+1, -49)

@@ -18,59 +18,11 @@
 )

 from executorch.examples.models.llama.model_args import ModelArgs
-
+from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
-
 from torch import nn


-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        """
-        Initialize the RMSNorm normalization layer.
-
-        Args:
-            dim (int): The dimension of the input tensor.
-            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-        Attributes:
-            eps (float): A small value added to the denominator for numerical stability.
-            weight (nn.Parameter): Learnable scaling parameter.
-
-        """
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The normalized tensor.
-
-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-
-        """
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
 class FeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
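The inline RMSNorm definition is removed in favor of the shared module, so llama_transformer.py and attention.py now resolve the same class. A small sketch, just to illustrate where the class lives after this change:

from executorch.examples.models.llama.norm import RMSNorm

# Both llama_transformer.py and attention.py now reference this single
# definition instead of a copy defined inline in llama_transformer.py.
norm = RMSNorm(dim=16)
print(type(norm).__module__)  # executorch.examples.models.llama.norm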

examples/models/llama/model_args.py (+1)

@@ -37,6 +37,7 @@ class ModelArgs:
     output_prune_map: Optional[Dict[int, int]] = None
     apply_embedding: bool = True # Use embedding inside the transformer
     apply_output: bool = True # Use output layer (unembedding) inside the transformer
+    use_qk_norm: bool = False # apply normalization to q and k in the attention
     use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
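For illustration, the option is switched on through ModelArgs; a sketch assuming the remaining dataclass fields keep their defaults:

from executorch.examples.models.llama.model_args import ModelArgs

# Sketch only: enable QK norm; every other field is left at its default.
args = ModelArgs(use_qk_norm=True)

# Attention modules built from these args would then construct the RMSNorm
# instances (q_norm_fn / k_norm_fn) shown in the attention.py change above.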

examples/models/llama/norm.py (new file, +51)

@@ -0,0 +1,51 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import torch
+from torch import nn
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
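The module computes the RMS over the last dimension in float32, rescales the input to unit RMS, casts back to the input dtype, and then applies the learnable per-channel weight (initialized to ones). A quick standalone sketch with assumed shapes:

import torch

from executorch.examples.models.llama.norm import RMSNorm

norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(2, 4, 8)
y = norm(x)

# With the weight still at its initial value of ones, every output vector
# has an RMS close to 1 along the last dimension.
print(y.pow(2).mean(dim=-1).sqrt())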

examples/models/llama/static_attention.py (+2, -1)

@@ -209,7 +209,9 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.head_dim = config.head_dim
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
+        self.use_qk_norm = config.use_qk_norm

+        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -258,7 +260,6 @@ def forward(
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
-
         all_ks = []
         all_vs = []
         for i in range(self.n_kv_heads):
