[XPU] llama swiglu uses Paddle's native swiglu
dynamicheart committed Nov 12, 2024
1 parent 10a62c7 commit 2b35515
Showing 1 changed file with 0 additions and 14 deletions.
paddlenlp/transformers/llama/modeling.py: 0 additions, 14 deletions
@@ -647,20 +647,6 @@ def __init__(self, config):
 
     def forward(self, x):
         if self.fuse_attention_ffn:
-            # FIXME(yangjianbang): use paddle's native swiglu
-            if get_env_device() == "xpu":
-                try:
-                    import paddle_xpu_nn  # noqa: F821
-
-                    out = self.gate_up_fused_proj(x)
-                    out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
-                    out = self.down_proj(out)
-                    return out
-                except ImportError:
-                    gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-                    out = self.down_proj(F.silu(gate_out) * up_out)
-                    return out
-
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
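For context on the calls that remain after this deletion: both branches of `forward` now go through a single `swiglu` helper instead of the hand-rolled XPU path. Below is a minimal sketch of that native-swiglu pattern; the import path `paddle.incubate.nn.functional.swiglu` is an assumption not confirmed by this diff, and the pure-Python fallback simply mirrors the chunk + SiLU * up math that the removed XPU fallback branch performed.

```python
import paddle
import paddle.nn.functional as F

# Sketch only: prefer Paddle's fused swiglu kernel if it is available,
# otherwise fall back to the equivalent chunk + silu * up computation.
# The import path below is an assumption, not taken from this commit.
try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        # Fused form: a single projection output holds the [gate, up] halves.
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        # SwiGLU: SiLU(gate) * up
        return F.silu(x) * y


# Usage matching the two call sites left in forward():
#   x = swiglu(self.gate_up_fused_proj(x))          # fused gate/up projection
#   x = swiglu(self.gate_proj(x), self.up_proj(x))  # separate projections
```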
