
Commit: fit hybrid
FeixLiu committed Aug 22, 2022
1 parent c897602 commit bee465e
Showing 2 changed files with 35 additions and 30 deletions.
10 changes: 3 additions & 7 deletions fleetx/models/gpt_model/modeling.py
@@ -168,11 +168,7 @@ def gen_cache(self, key, value=None, type=Cache):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)

-    def core_attn(self, q, k, v, transpose=False):
-        if transpose:
-            q = paddle.transpose(q, [0, 2, 1, 3])
-            k = paddle.transpose(k, [0, 2, 1, 3])
-            v = paddle.transpose(v, [0, 2, 1, 3])
+    def core_attn(self, q, k, v):
         # scale dot product attention
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)
@@ -222,9 +218,9 @@ def forward(self,
                                                cache)

         if self.use_recompute and self.recompute_granularity == "core_attn":
-            out, weights = recompute(self.core_attn, q, k, v, self.fuse)
+            out, weights = recompute(self.core_attn, q, k, v)
         else:
-            out, weights = self.core_attn(q, k, v, self.fuse)
+            out, weights = self.core_attn(q, k, v)

         # project to output
         out = self.out_proj(out)
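The change above drops the unused transpose and self.fuse arguments so that core_attn(q, k, v) has the same signature here and in modeling_hybrid.py below, and can be handed directly to Paddle's recompute utility when recompute_granularity is "core_attn". A minimal sketch of that pattern follows; attribute names beyond those visible in the diff are assumptions, not the repo's code.

    # Minimal sketch (assumed names, not the repo's code): wrapping core_attn
    # with Paddle's recompute so its intermediates are recomputed in backward
    # instead of stored, which is what recompute_granularity == "core_attn"
    # selects in the diff above.
    import paddle
    import paddle.nn.functional as F
    from paddle.distributed.fleet.utils import recompute

    class AttnSketch(paddle.nn.Layer):
        def __init__(self, use_recompute=True, recompute_granularity="core_attn"):
            super().__init__()
            self.use_recompute = use_recompute
            self.recompute_granularity = recompute_granularity

        def core_attn(self, q, k, v):
            # stand-in for the scaled dot-product attention in the diff
            weights = F.softmax(paddle.matmul(q, k, transpose_y=True))
            return paddle.matmul(weights, v), weights

        def forward(self, q, k, v):
            if self.use_recompute and self.recompute_granularity == "core_attn":
                # forward runs core_attn without keeping its activations; they
                # are recomputed during backward, trading compute for memory
                out, weights = recompute(self.core_attn, q, k, v)
            else:
                out, weights = self.core_attn(q, k, v)
            return out

With recompute, the tensors produced inside core_attn (the attention probabilities in particular) are not kept for backward; they are recomputed from q, k, v, which cuts activation memory at the cost of one extra forward pass through the attention math.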
55 changes: 32 additions & 23 deletions fleetx/models/gpt_model/modeling_hybrid.py
@@ -223,6 +223,34 @@ def gen_cache(self, key, value=None, type=Cache):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)

+    def core_attn(self, q, k, v):
+        # scale dot product attention
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim ** -0.5)
+
+        # if attn_mask is not None:
+        #     product = product + attn_mask
+
+        # weights = F.softmax(product)
+
+        weights = incubate.softmax_mask_fuse_upper_triangle(product)
+
+        if self.dropout:
+            with get_rng_state_tracker().rng_state('local_seed'):
+                weights = F.dropout(
+                    weights,
+                    self.dropout,
+                    training=self.training,
+                    mode="upscale_in_train")
+
+        out = tensor.matmul(weights, v)
+
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        return out, weights
+
     def forward(self,
                 query,
                 key,
@@ -246,30 +274,11 @@ def forward(self,
         else:
             q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                                cache)
-        # scale dot product attention
-        product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
-
-        # if attn_mask is not None:
-        #     product = product + attn_mask
-
-        # weights = F.softmax(product)
-
-        weights = incubate.softmax_mask_fuse_upper_triangle(product)
-
-        if self.dropout:
-            with get_rng_state_tracker().rng_state('local_seed'):
-                weights = F.dropout(
-                    weights,
-                    self.dropout,
-                    training=self.training,
-                    mode="upscale_in_train")
-
-        out = tensor.matmul(weights, v)
-
-        # combine heads
-        out = tensor.transpose(out, perm=[0, 2, 1, 3])
-        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+        if self.use_recompute and self.recompute_granularity == "core_attn":
+            out, weights = recompute(self.core_attn, q, k, v)
+        else:
+            out, weights = self.core_attn(q, k, v)

         # project to output
         out = self.out_proj(out)
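In modeling_hybrid.py the attention body now lives in core_attn, so the hybrid-parallel model gets the same core_attn-granularity recompute switch as modeling.py. The method keeps incubate.softmax_mask_fuse_upper_triangle(product), a fused GPU-only kernel that applies the causal (upper-triangular) mask and the softmax in one step, replacing the commented-out explicit mask-add plus F.softmax. A rough unfused reference, only a sketch with illustrative names (not code from the repo):

    import paddle
    import paddle.nn.functional as F

    def causal_softmax_unfused(product):
        # product: [batch, num_heads, seq_len, seq_len] attention scores.
        # Add a large negative bias to strictly upper-triangular (future)
        # positions, then softmax over the last axis -- roughly what the
        # fused incubate.softmax_mask_fuse_upper_triangle kernel computes.
        seq_len = product.shape[-1]
        mask = paddle.triu(
            paddle.full([seq_len, seq_len], -1e4, dtype=product.dtype),
            diagonal=1)
        return F.softmax(product + mask)

Because the fused kernel already encodes the causal GPT mask, the generic attn_mask addition can stay commented out in this code path.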

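One more detail carried into core_attn: the attention dropout runs under get_rng_state_tracker().rng_state('local_seed'), so that in hybrid (tensor-parallel) training each rank draws its own dropout mask for its shard of the heads rather than reusing one global RNG stream. A hedged sketch of how such a tracker is typically seeded; the seed values and rank handling are illustrative assumptions, and a GPU build of Paddle is assumed.

    import paddle
    import paddle.nn.functional as F
    from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

    rank = paddle.distributed.get_rank()

    # Register one RNG state shared by all ranks and one that differs per
    # model-parallel rank; the names and seed values here are assumptions.
    tracker = get_rng_state_tracker()
    tracker.add('global_seed', 2022)
    tracker.add('local_seed', 1024 + rank)

    # Dropout drawn under 'local_seed' differs across model-parallel ranks,
    # matching the usage inside core_attn above.
    x = paddle.rand([2, 16, 128, 128])
    with tracker.rng_state('local_seed'):
        x = F.dropout(x, p=0.1, training=True)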