recompute core attn
FeixLiu committed Aug 22, 2022
1 parent 0ea6f2f commit 1794910
Showing 5 changed files with 74 additions and 53 deletions.
2 changes: 1 addition & 1 deletion examples/gpt/hybrid_parallel/README.md
@@ -76,7 +76,7 @@ Data:
| type_vocab_size | Size of the token-type vocabulary |
| initializer_range | Range used for parameter initialization |
| use_recompute | Whether to train with recompute |
| recompute_granularity | Granularity of recompute training; one of `full` or `only_attn`. `full` recomputes the whole transformer, `only_attn` recomputes only the self attention part |
| recompute_granularity | Granularity of recompute training; one of `full`, `full_attn` or `core_attn`. `full` recomputes the whole transformer, `full_attn` recomputes only the entire self attention part, `core_attn` recomputes only the `softmax(qkT)v` part |
| fused_linear | Whether to replace the standard Linear with fused_linear to speed up training. Note: this requires Paddle compiled with CUDA 11.6 or later. |


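The new granularity is selected through the example's training YAML. A hypothetical fragment for illustration only (the key names follow the table above; the exact section name and nesting in the shipped configs may differ):

```yaml
# Hypothetical config fragment; section name and nesting are assumptions.
Model:
  use_recompute: True
  recompute_granularity: core_attn   # one of: full, full_attn, core_attn
  fused_linear: False
```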
2 changes: 1 addition & 1 deletion examples/gpt/single/README.md
@@ -68,7 +68,7 @@ Data:
| type_vocab_size | Size of the token-type vocabulary |
| initializer_range | Range used for parameter initialization |
| use_recompute | Whether to train with recompute |
| recompute_granularity | Granularity of recompute training; one of `full` or `only_attn`. `full` recomputes the whole transformer, `only_attn` recomputes only the self attention part |
| recompute_granularity | Granularity of recompute training; one of `full`, `full_attn` or `core_attn`. `full` recomputes the whole transformer, `full_attn` recomputes only the entire self attention part, `core_attn` recomputes only the `softmax(qkT)v` part |
| fused_linear | Whether to replace the standard Linear with fused_linear to speed up training. Note: this requires Paddle compiled with CUDA 11.6 or later. |

### 优化器
4 changes: 2 additions & 2 deletions examples/gpt/tools.py
@@ -113,9 +113,9 @@ def process_model_configs(yaml_dict):
configs['ffn_hidden_size'] = 4 * configs['hidden_size']

if configs['use_recompute']:
assert configs['recompute_granularity'] in ["full", "only_attn"], \
assert configs['recompute_granularity'] in ["full", "full_attn", "core_attn"], \
"recompute_granularity can be only chosen from " \
"'full' or 'only_attn', but received '{}'".format(configs['recompute_granularity'])
"'full', 'full_attn' or 'core_attn', but received '{}'".format(configs['recompute_granularity'])

if configs['fused_linear'] and not is_fused_matmul_bias_supported():
configs['fused_linear'] = False
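For illustration, the stricter check above now rejects the old value. A hypothetical snippet (keys and values made up, message taken from the assert in the diff):

```python
# Hypothetical configs dict; only the two relevant keys are shown.
configs = {"use_recompute": True, "recompute_granularity": "only_attn"}

# With this commit, process_model_configs(...) fails on such a config:
# AssertionError: recompute_granularity can be only chosen from
#                 'full', 'full_attn' or 'core_attn', but received 'only_attn'
```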
85 changes: 50 additions & 35 deletions fleetx/models/gpt_model/modeling.py
@@ -54,8 +54,10 @@ def __init__(self,
need_weights=False,
weight_attr=None,
bias_attr=None,
fuse=False,
fused_linear=False):
fuse=True,
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -64,6 +66,8 @@ def __init__(self,
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
@@ -89,7 +93,6 @@ def _fuse_prepare_qkv(self, query):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
return q, k, v

@@ -164,6 +167,35 @@ def gen_cache(self, key, value=None, type=Cache):
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)

def core_attn(self, q, k, v, transpose=False):
if transpose:
q = paddle.transpose(q, [0, 2, 1, 3])
k = paddle.transpose(k, [0, 2, 1, 3])
v = paddle.transpose(v, [0, 2, 1, 3])
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)

# TODO(liuyuang): support softmax_mask_fuse_upper_triangle for generation task
weights = F.softmax(product)

# weights = incubate.softmax_mask_fuse_upper_triangle(product)

if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")

out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

return out, weights

def forward(self,
query,
key,
@@ -187,30 +219,11 @@ def forward(self,
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

# if attn_mask is not None:
# product = product + attn_mask

# TODO(liuyuang): support softmax_mask_fuse_upper_triangle for generation task
weights = F.softmax(product)

# weights = incubate.softmax_mask_fuse_upper_triangle(product)

if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")

out = tensor.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
if self.use_recompute and self.recompute_granularity == "core_attn":
out, weights = recompute(self.core_attn, q, k, v, self.fuse)
else:
out, weights = self.core_attn(q, k, v, self.fuse)

# project to output
out = self.out_proj(out)
@@ -323,7 +336,8 @@ def __init__(self,
weight_attr=None,
bias_attr=None,
fused_linear=False,
recompute_attn=False):
use_recompute=False,
recompute_granularity="full"):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -332,7 +346,8 @@ def __init__(self,
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.recompute_attn = recompute_attn
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -345,7 +360,9 @@ def __init__(self,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
fused_linear=fused_linear)
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)
self.linear1 = Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
self.linear2 = Linear(
@@ -364,9 +381,8 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
tgt = self.norm1(tgt)

if use_cache is False:
if self.recompute_attn:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
use_cache, cache)
if self.use_recompute and self.recompute_granularity == "full_attn":
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
@@ -454,8 +470,6 @@ def __init__(self,

super(GPTModel, self).__init__()

recompute_attn = use_recompute and recompute_granularity == "only_attn"

self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
@@ -480,7 +494,8 @@ def __init__(self,
mean=0.0, std=self.initializer_range)),
bias_attr=None,
fused_linear=fused_linear,
recompute_attn=recompute_attn))
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))

self.decoder = TransformerDecoder(
decoder_layers,
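For reference, a minimal standalone sketch (not from the commit) of what the `core_attn` granularity does: only the `softmax(qkT)v` block is wrapped in Paddle's recompute utility, so its intermediate activations are dropped in the forward pass and rebuilt during backward. The import path, tensor shapes and the dropout-free core below are assumptions for brevity.

```python
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute  # path may vary across Paddle versions


def core_attn(q, k, v):
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    scale = q.shape[-1] ** -0.5
    product = paddle.matmul(q, k, transpose_y=True) * scale
    weights = F.softmax(product)
    return paddle.matmul(weights, v)


q = paddle.randn([2, 8, 128, 64])
k = paddle.randn([2, 8, 128, 64])
v = paddle.randn([2, 8, 128, 64])
for t in (q, k, v):
    t.stop_gradient = False

# The softmax/matmul activations inside core_attn are recomputed during the
# backward pass instead of being kept alive, trading compute for memory.
out = recompute(core_attn, q, k, v)
out.sum().backward()
```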
34 changes: 20 additions & 14 deletions fleetx/models/gpt_model/modeling_hybrid.py
@@ -79,7 +79,9 @@ def __init__(self,
bias_attr=None,
fuse=True,
num_partitions=1,
fused_linear=False):
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -88,6 +90,8 @@ def __init__(self,
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
@@ -380,7 +384,9 @@ def __init__(self,
bias_attr=None,
num_partitions=1,
fused_linear=False,
recompute_attn=False):
recompute_attn=False,
use_recompute=False,
recompute_granularity="full"):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -389,7 +395,8 @@ def __init__(self,
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.recompute_attn = recompute_attn
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -401,7 +408,9 @@ def __init__(self,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
num_partitions=num_partitions,
fused_linear=fused_linear)
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)

self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
d_model,
@@ -437,9 +446,8 @@ def forward(self,
tgt = self.norm1(tgt)

if use_cache is False:
if self.recompute_attn:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
use_cache, cache)
if self.use_recompute and self.recompute_granularity == "full_attn":
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
@@ -535,8 +543,6 @@ def __init__(self,

super(GPTModel, self).__init__()

recompute_attn = use_recompute and recompute_granularity == "only_attn"

self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
@@ -562,7 +568,8 @@ def __init__(self,
bias_attr=None,
num_partitions=num_partitions,
fused_linear=fused_linear,
recompute_attn=recompute_attn))
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))

self.decoder = TransformerDecoder(
decoder_layers,
@@ -765,8 +772,6 @@ def __init__(self,
fused_linear=False,
recompute_granularity="full"):

recompute_attn = use_recompute and recompute_granularity == "only_attn"

# forward desc
self.descs = []

@@ -799,7 +804,8 @@ def __init__(self,
bias_attr=None,
num_partitions=num_partitions,
fused_linear=fused_linear,
recompute_attn=recompute_attn))
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))

self.descs.append(
LayerDesc(
@@ -822,7 +828,7 @@ def _logits_helper(embedding, output):
initializer_range=0.02))

recompute_interval = 0
if recompute and not recompute_attn:
if recompute and recompute_granularity == "full":
recompute_interval = 1

super().__init__(
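One note on the pipeline-parallel change above (illustrative, not part of the commit): `recompute_interval = 1` makes the pipeline engine recompute whole decoder layers, so it is now enabled only for the `full` granularity; `full_attn` and `core_attn` are already recomputed inside the attention/decoder layers, and enabling both paths would recompute the same activations twice.

```python
# Sketch of the dispatch implied by the change above; variable names are
# illustrative and may not match the surrounding scope exactly.
recompute_interval = 0
if use_recompute and recompute_granularity == "full":
    recompute_interval = 1   # pipeline engine recomputes whole layers
# "full_attn" / "core_attn" keep the interval at 0; recompute happens inside
# MultiHeadAttention / TransformerDecoderLayer instead.
```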
