from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor, nn


def scaled_dot_product_gqa(
query: Tensor,
key: Tensor,
value: Tensor,
dropout: float = 0.0,
scale: Optional[float] = None,
mask: Optional[Tensor] = None,
is_causal: Optional[bool] = None,
need_weights: bool = False,
average_attn_weights: bool = False,
force_grouped: bool = False,
):
"""Scaled dot product attention with support for grouped queries.
Einstein notation:
- b: batch size
- n / s: sequence length
- h: number of heads
- g: number of groups
- d: dimension of query/key/value
Args:
query: Query tensor of shape (b, n, h, d)
key: Key tensor of shape (b, s, h, d)
value: Value tensor of shape (b, s, h, d)
dropout: Dropout probability (default: 0.0)
scale: Scale factor for query (default: d_query ** 0.5)
mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is
applied to all 'n' rows of the attention matrix. (default: None)
force_grouped: If True, apply grouped-query attention even if the number of
heads is equal for query, key, and value. (default: False)
Returns:
2-tuple of:
- Attention output with shape (b, n, h, d)
- (Optional) Attention weights with shape (b, h, n, s). Only returned if
'need_weights' is True.
"""
if (mask is not None) and (is_causal is not None):
raise ValueError(
"Only one of 'mask' and 'is_causal' should be provided, but got both."
)
elif not query.ndim == key.ndim == value.ndim == 4:
raise ValueError(
f"Expected query, key, and value to be 4-dimensional, but got shapes "
f"{query.shape}, {key.shape}, and {value.shape}."
)
# Move sequence length dimension to axis 2.
# This makes the attention operations below *much* faster.
query = rearrange(query, "b n h d -> b h n d")
key = rearrange(key, "b s h d -> b h s d")
value = rearrange(value, "b s h d -> b h s d")
bq, hq, nq, dq = query.shape
bk, hk, nk, dk = key.shape
bv, hv, nv, dv = value.shape
if not (bq == bk == bv and dq == dk == dv):
raise ValueError(
"Expected query, key, and value to have the same batch size (dim=0) and "
f"embedding dimension (dim=3), but got query: {query.shape}, "
f"key: {key.shape}, and value: {value.shape}."
)
elif (hk != hv) or (nk != nv):
raise ValueError(
"Expected key and value to have the same size in dimensions 1 and 2, but "
f"got key: {key.shape} and value: {value.shape}."
)
elif hq % hk != 0:
raise ValueError(
"Expected query heads to be a multiple of key/value heads, but got "
f"query: {query.shape} and key/value: {key.shape}."
)
if scale is None:
scale = query.size(-1) ** 0.5
query = query / scale
num_head_groups = hq // hk
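    # Split the query heads into 'num_head_groups' groups, so that each key/value
    # head is shared by a whole group of query heads. The einsum below then computes
    # the similarity scores for all groups in parallel.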
query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")
if is_causal:
# Mask out the upper triangular portion of the attention matrix. This prevents
# the model from attending to tokens in the future.
mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()
if mask is not None:
# Expand mask to match the shape of the attention matrix.
# If mask is 2D, assume that it is applied to the key/value sequence dimension.
# Else if mask is 3D, assume that it is applied to the query/key/value sequence
# dimension for all attention heads.
#
# Users could also provide a 4D mask, which is applied to the query/key/value
# sequence dimension for each attention head (though I don't have a particular
# use case in mind for that).
if mask.ndim == 2:
mask = rearrange(mask, "b s -> b () () () s")
elif mask.ndim == 3:
mask = rearrange(mask, "b n s -> b () () n s")
# Mask similarity values by setting them to negative infinity. This guarantees
# that they will not contribute to the softmax computation below.
similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)
attention = F.softmax(similarity, dim=-1)
if dropout > 0.0:
attention = F.dropout(attention, p=dropout)
# Apply attention matrix to the value Tensor.
out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
# Move head dimension back to axis 2
out = rearrange(out, "b g h n d -> b n (h g) d")
attn_weights: Optional[Tensor] = None
if need_weights:
# Move the sequence dimensions back to positions 1, 2. Move the head dimension
# to position 3. This more closely matches the return shape of the attention
# output: (b, n, h, d).
attn_weights = rearrange(attention, "b g h n s -> b n s (h g)")
        if average_attn_weights:
            # Average over the head dimension (the last axis after the rearrange
            # above), matching the behavior of 'nn.MultiheadAttention'.
            attn_weights = attn_weights.mean(dim=-1)
return out, attn_weights
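

# Example usage of 'scaled_dot_product_gqa' (a minimal sketch, not part of the
# original module; the tensor sizes below are illustrative only):
#
#     query = torch.randn(1, 256, 8, 64)   # (b, n, h, d) with 8 query heads
#     key = torch.randn(1, 256, 2, 64)     # (b, s, h, d) with 2 key/value heads
#     value = torch.randn(1, 256, 2, 64)
#     out, _ = scaled_dot_product_gqa(query, key, value, is_causal=True)
#     assert out.shape == (1, 256, 8, 64)

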
class MultiheadGQA(nn.Module):
"""Multi-head grouped query attention (GQA) layer.
Reference:
"GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
https://arxiv.org/pdf/2305.13245v1.pdf
GQA is a variant of multihead attention (MHA) that uses fewer write heads
(key / value) than query heads. GQA can be viewed as a generalization of
multi-query attention (MQA), which uses a single write head. GQA and MQA give
significant speedups over standard MHA in decoder layers, with minimal loss in
accuracy. In the paper, GQA is shown to be more accurate than MQA, while still
having a significant speedup over MHA.
NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model
from MHA to GQA. As a result, they do not mention parameter initialization or
layer normalization strategies. I follow the best practices laid out in the
MAGNETO paper, which improves Transformer performance through better parameter
initialization and layer norm placement. See:
https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
"""
def __init__(
self,
embed_dim: int,
query_heads: int,
kv_heads: int,
dropout: float = 0.0,
bias: bool = True,
layer_norm: bool = True,
layer_norm_eps: float = 1e-5,
gamma_init: float = 1.0,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.query_heads = query_heads
self.kv_heads = kv_heads
self.dropout = dropout
self.layer_norm = layer_norm
self.gamma_init = gamma_init
if self.query_heads % self.kv_heads != 0:
raise ValueError(
f"query_heads ({query_heads}) must be divisible by "
f"kv_heads ({kv_heads})"
)
elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0):
raise ValueError(
f"embed_dim ({embed_dim}) must be divisible by "
f"query_heads ({query_heads}) and kv_heads ({kv_heads})"
)
head_dim = embed_dim // query_heads
if not head_dim % 8 == 0:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8"
)
if not head_dim <= 128:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
)
# Query projection layer is the same as in vanilla MHA.
self.q_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
        # Key/value projection layers have a smaller output dimension, so that
        # we have fewer key/value attention heads after reshaping.
        kv_embed_dim = embed_dim // query_heads * kv_heads
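        # For example (hypothetical sizes): embed_dim=512, query_heads=8, kv_heads=2
        # gives kv_embed_dim = 512 // 8 * 2 = 128, i.e. 2 key/value heads of size 64.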
self.k_proj = nn.Linear(
embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
)
self.v_proj = nn.Linear(
embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
)
self.norm: Optional[nn.LayerNorm] = None
if layer_norm:
self.norm = nn.LayerNorm(
embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
)
        # The grouped attention output is folded back into all 'query_heads' heads of
        # size 'head_dim' (see 'scaled_dot_product_gqa'), so the output projection
        # maps 'embed_dim' back to 'embed_dim', as in vanilla MHA.
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
        self._reset_parameters()

    def _reset_parameters(self):
nn.init.xavier_normal_(self.q_proj.weight)
if self.q_proj.bias is not None:
nn.init.constant_(self.q_proj.bias, 0)
nn.init.xavier_normal_(self.k_proj.weight)
if self.k_proj.bias is not None:
nn.init.constant_(self.k_proj.bias, 0)
# NOTE: We follow the initialization strategy from MAGNETO. See:
# https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
# Gain (self.gamma_init) should be provided as a keyword argument when
# initializing the larger Transformer model, since it requires knowledge
# of the number of encoder/decoder layers in the model.
nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
if self.v_proj.bias is not None:
nn.init.constant_(self.v_proj.bias, 0)
nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init)
if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0)

    def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
need_weights: bool = False,
# TODO
# attn_mask: Optional[Tensor] = None,
is_causal: bool = False,
average_attn_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
# Notation:
# b - batch size
# n - sequence length
# h - number of heads
# d - embedding dimension
#
# Input shape: (b, n, d)
q: Tensor = self.q_proj(query)
k: Tensor = self.k_proj(key)
v: Tensor = self.v_proj(value)
# Unfold 'd' dimension into 'h' separate attention heads.
q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads)
k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads)
v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads)
# Apply attention, then fold 'h' attention heads back into 'd'.
x, attn = scaled_dot_product_gqa(
query=q,
key=k,
value=v,
# TODO
# mask=attn_mask,
is_causal=is_causal,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
force_grouped=False,
)
x = rearrange(x, "b n h d -> b n (h d)")
# NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO
# architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra
# layer norm before the linear output projection. The cross-attention layer in
# the MAGNETO decoder does not include this layer norm, so users have the
# option to disable it (layer_norm=False).
if self.layer_norm:
assert self.norm is not None
x = self.norm(x)
# Linear projection on attention outputs.
x = self.out_proj(x)
return x, attn
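

if __name__ == "__main__":
    # Minimal usage sketch / smoke test (not part of the original module). The
    # hyperparameters below are arbitrary, but satisfy the constraints checked in
    # 'MultiheadGQA.__init__' (head_dim divisible by 8 and <= 128).
    batch_size, seq_len, embed_dim = 2, 128, 512
    mhgqa = MultiheadGQA(embed_dim=embed_dim, query_heads=8, kv_heads=2)
    x = torch.randn(batch_size, seq_len, embed_dim)
    # Self-attention: the same tensor is used for query, key, and value.
    out, attn = mhgqa(x, x, x, is_causal=True, need_weights=True)
    print(out.shape)   # torch.Size([2, 128, 512])
    print(attn.shape)  # torch.Size([2, 128, 128, 8])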