diff --git a/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc
new file mode 100644
index 0000000000000..bbe6e22e752b8
--- /dev/null
+++ b/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc
@@ -0,0 +1,478 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.h"
+
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
+#include "paddle/fluid/pir/utils/general_functions.h"
+
+#include "paddle/pir/include/pass/pass.h"
+#include "paddle/pir/include/pass/pass_registry.h"
+
+namespace {
+
+class FlashAttnPatternQscale : public paddle::drr::DrrPatternBase {
+ private:
+  bool softmax_with_cast_;
+
+ public:
+  explicit FlashAttnPatternQscale(bool softmax_with_cast)
+      : softmax_with_cast_(softmax_with_cast) {}
+
+  std::string name() const override { return "FlashAttnPatternQscale"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern src = ctx->SourcePattern();
+    // check the transpose
+    // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale
+    const auto &transpose_q = src.Op("pd_op.transpose");
+    src.Tensor("q_transpose_out") = transpose_q(src.Tensor("q"));
+    // scale before matmul
+    const auto &scale_q = src.Op("pd_op.scale");
+    const auto &full_scale =
+        src.Op("pd_op.full", {{"value", src.Attr("scale_q_value")}});
+    src.Tensor("q_scale_out") =
+        scale_q(src.Tensor("q_transpose_out"), full_scale());
+    // k[b, s, head, head_dim] -> transpose -> k[b, head, s, head_dim]
+    // k[b, head, s, head_dim] -> transpose -> k[b, head, head_dim, s]
+    const auto &transpose_k = src.Op("pd_op.transpose");
+    src.Tensor("k_transpose_out") = transpose_k(src.Tensor("k"));
+    const auto &transpose_k2 = src.Op("pd_op.transpose");
+    src.Tensor("k_transpose2_out") =
+        transpose_k2(src.Tensor("k_transpose_out"));
+    // v[b, s, head, head_dim] -> transpose -> v[b, head, s, head_dim]
+    const auto &transpose_v = src.Op("pd_op.transpose");
+    src.Tensor("v_transpose_out") = transpose_v(src.Tensor("v"));
+    // qk
+    const auto &qk_matmul =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_qk_transpose_x")},
+                {"transpose_y", src.Attr("matmul_qk_transpose_y")}});
+    src.Tensor("qk_out") =
+        qk_matmul(src.Tensor("q_scale_out"), src.Tensor("k_transpose2_out"));
+
+    // mask
+    const auto &mask_add = src.Op("pd_op.add");
+    src.Tensor("mask_add_out") =
+        mask_add(src.Tensor("qk_out"), src.Tensor("mask"));
+
+    if (softmax_with_cast_) {
+      // cast + softmax + cast
+      const auto &softmax_cast1 = src.Op("pd_op.cast");
+      src.Tensor("softmax_cast1_out") =
+          softmax_cast1(src.Tensor("mask_add_out"));
+      const auto &softmax =
+          src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
+      src.Tensor("softmax_cast2_in") =
+          softmax(src.Tensor("softmax_cast1_out"));
+      const auto &softmax_cast2 = src.Op("pd_op.cast");
+      src.Tensor("softmax_out") =
+          softmax_cast2(src.Tensor("softmax_cast2_in"));
+    } else {
+      // softmax
+      const auto &softmax =
+          src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
+      src.Tensor("softmax_out") = softmax(src.Tensor("mask_add_out"));
+    }
+
+    // o
+    const auto &context_matmul =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("context_matmul_transpose_x")},
+                {"transpose_y", src.Attr("context_matmul_transpose_y")}});
+    src.Tensor("context_matmul_out") = context_matmul(
+        src.Tensor("softmax_out"), src.Tensor("v_transpose_out"));
+    const auto &o_transpose = src.Op("pd_op.transpose");
+    src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out"));
+
+    // Constraints
+    src.RequireNativeCall(
+        [](const paddle::drr::MatchContext &match_ctx) -> bool {
+          auto q_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("q"));
+          if (!q_dtype.isa<pir::Float16Type>() &&
+              !q_dtype.isa<pir::BFloat16Type>()) {
+            return false;
+          }
+          // softmax
+          const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
+          if (softmax_axis != -1 && softmax_axis != 3) return false;
+          // matmul transpose
+          bool matmul_qk_transpose_x =
+              match_ctx.Attr<bool>("matmul_qk_transpose_x");
+          bool matmul_qk_transpose_y =
+              match_ctx.Attr<bool>("matmul_qk_transpose_y");
+          if (matmul_qk_transpose_x || matmul_qk_transpose_y) return false;
+
+          bool matmul_o_transpose_x =
+              match_ctx.Attr<bool>("context_matmul_transpose_x");
+          bool matmul_o_transpose_y =
+              match_ctx.Attr<bool>("context_matmul_transpose_y");
+          if (matmul_o_transpose_x || matmul_o_transpose_y) return false;
+          // tensor shape
+          auto q_transpose_out =
+              pir::GetShapeFromValue(match_ctx.Tensor("q_transpose_out"));
+          auto k_transpose_out =
+              pir::GetShapeFromValue(match_ctx.Tensor("k_transpose_out"));
+          auto v_transpose_out =
+              pir::GetShapeFromValue(match_ctx.Tensor("v_transpose_out"));
+          if (q_transpose_out.size() != 4 || k_transpose_out.size() != 4 ||
+              v_transpose_out.size() != 4 ||
+              !(q_transpose_out.at(0) == k_transpose_out.at(0) &&
+                k_transpose_out.at(0) == v_transpose_out.at(0)) ||
+              !(q_transpose_out.at(1) == k_transpose_out.at(1) &&
+                k_transpose_out.at(1) == v_transpose_out.at(1)) ||
+              !(q_transpose_out.at(3) == k_transpose_out.at(3) &&
+                k_transpose_out.at(3) == v_transpose_out.at(3))) {
+            return false;
+          }
+          // add shape
+          auto mask_add = pir::GetShapeFromValue(match_ctx.Tensor("mask"));
+          if (mask_add.size() != 4) {
+            return false;
+          }
+
+          return true;
+        });
+
+    //
+    // Result Pattern.
+    //
+    paddle::drr::ResultPattern res = src.ResultPattern();
+    const auto &flash_attn = res.Op("pd_op.flash_attn",
+                                    {{{"dropout", res.Float32Attr(0.0)},
+                                      {"causal", res.BoolAttr(false)},
+                                      {"return_softmax", res.BoolAttr(false)},
+                                      {"is_test", res.BoolAttr(true)},
+                                      {"rng_name", res.StrAttr("")}}});
+    flash_attn({&res.Tensor("q"),
+                &res.Tensor("k"),
+                &res.Tensor("v"),
+                &res.InputNoneTensor(),
+                &res.Tensor("mask")},
+               {&res.Tensor("out"),
+                &res.Tensor("softmax"),
+                &res.Tensor("softmax_lse"),
+                &res.Tensor("seed_offset")});
+  }
+};
+
+// 1. scale after matmul
+// 2. cast before and after softmax
+class FlashAttnPatternOutscale : public paddle::drr::DrrPatternBase {
+ private:
+  bool softmax_with_cast_;
+
+ public:
+  explicit FlashAttnPatternOutscale(bool softmax_with_cast)
+      : softmax_with_cast_(softmax_with_cast) {}
+
+  std::string name() const override { return "FlashAttnPatternOutscale"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern src = ctx->SourcePattern();
+    // check the transpose,
+    // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale
+    const auto &transpose_q = src.Op("pd_op.transpose");
+    src.Tensor("q_transpose_out") = transpose_q(src.Tensor("q"));
+    // k[b, s, head, head_dim] -> transpose -> k[b, head, s, head_dim]
+    // k[b, head, s, head_dim] -> transpose -> k[b, head, head_dim, s]
+    const auto &transpose_k = src.Op("pd_op.transpose");
+    src.Tensor("k_transpose_out") = transpose_k(src.Tensor("k"));
+    const auto &transpose_k2 = src.Op("pd_op.transpose");
+    src.Tensor("k_transpose2_out") =
+        transpose_k2(src.Tensor("k_transpose_out"));
+    // v[b, s, head, head_dim] -> transpose -> v[b, head, s, head_dim]
+    const auto &transpose_v = src.Op("pd_op.transpose");
+    src.Tensor("v_transpose_out") = transpose_v(src.Tensor("v"));
+    // qk
+    const auto &qk_matmul =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_qk_transpose_x")},
+                {"transpose_y", src.Attr("matmul_qk_transpose_y")}});
+    src.Tensor("qk_out") = qk_matmul(src.Tensor("q_transpose_out"),
+                                     src.Tensor("k_transpose2_out"));
+    const auto &scale_out = src.Op("pd_op.scale");
+    const auto &full_scale =
+        src.Op("pd_op.full", {{"value", src.Attr("scale_out_value")}});
+    src.Tensor("qk_scale_out") = scale_out(src.Tensor("qk_out"), full_scale());
+
+    // mask
+    const auto &mask_add = src.Op("pd_op.add");
+    src.Tensor("mask_add_out") =
+        mask_add(src.Tensor("qk_scale_out"), src.Tensor("mask"));
+
+    if (softmax_with_cast_) {
+      // cast + softmax + cast
+      const auto &softmax_cast1 = src.Op("pd_op.cast");
+      src.Tensor("softmax_cast1_out") =
+          softmax_cast1(src.Tensor("mask_add_out"));
+      const auto &softmax =
+          src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
+      src.Tensor("softmax_cast2_in") =
+          softmax(src.Tensor("softmax_cast1_out"));
+      const auto &softmax_cast2 = src.Op("pd_op.cast");
+      src.Tensor("softmax_out") =
+          softmax_cast2(src.Tensor("softmax_cast2_in"));
+    } else {
+      // softmax
+      const auto &softmax =
+          src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
+      src.Tensor("softmax_out") = softmax(src.Tensor("mask_add_out"));
+    }
+
+    // o
+    const auto &context_matmul =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("context_matmul_transpose_x")},
+                {"transpose_y", src.Attr("context_matmul_transpose_y")}});
+    src.Tensor("context_matmul_out") = context_matmul(
+        src.Tensor("softmax_out"), src.Tensor("v_transpose_out"));
+    const auto &o_transpose = src.Op("pd_op.transpose");
+    src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out"));
+
+    // Constraints
+    src.RequireNativeCall(
+        [](const paddle::drr::MatchContext &match_ctx) -> bool {
+          auto q_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("q"));
+          if (!q_dtype.isa<pir::Float16Type>() &&
+              !q_dtype.isa<pir::BFloat16Type>()) {
+            return false;
+          }
+          // softmax
+          const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
+          if (softmax_axis != -1 && softmax_axis != 3) return false;
+          // matmul transpose
+          bool matmul_qk_transpose_x =
+              match_ctx.Attr<bool>("matmul_qk_transpose_x");
+          bool matmul_qk_transpose_y =
+              match_ctx.Attr<bool>("matmul_qk_transpose_y");
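+          // The fused flash_attn kernel computes QK^T itself, so the
+          // pattern only applies when both matmul transpose flags are false.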
+          if (matmul_qk_transpose_x || matmul_qk_transpose_y) return false;
+
+          bool matmul_o_transpose_x =
+              match_ctx.Attr<bool>("context_matmul_transpose_x");
+          bool matmul_o_transpose_y =
+              match_ctx.Attr<bool>("context_matmul_transpose_y");
+          if (matmul_o_transpose_x || matmul_o_transpose_y) return false;
+          // tensor shape
+          auto q_transpose_out =
+              pir::GetShapeFromValue(match_ctx.Tensor("q_transpose_out"));
+          auto k_transpose_out =
+              pir::GetShapeFromValue(match_ctx.Tensor("k_transpose_out"));
+          auto v_transpose_out =
+              pir::GetShapeFromValue(match_ctx.Tensor("v_transpose_out"));
+          if (q_transpose_out.size() != 4 || k_transpose_out.size() != 4 ||
+              v_transpose_out.size() != 4 ||
+              !(q_transpose_out.at(0) == k_transpose_out.at(0) &&
+                k_transpose_out.at(0) == v_transpose_out.at(0)) ||
+              !(q_transpose_out.at(1) == k_transpose_out.at(1) &&
+                k_transpose_out.at(1) == v_transpose_out.at(1)) ||
+              !(q_transpose_out.at(3) == k_transpose_out.at(3) &&
+                k_transpose_out.at(3) == v_transpose_out.at(3))) {
+            return false;
+          }
+          // add shape
+          auto mask_add = pir::GetShapeFromValue(match_ctx.Tensor("mask"));
+          if (mask_add.size() != 4) {
+            return false;
+          }
+
+          return true;
+        });
+
+    //
+    // Result Pattern.
+    //
+    paddle::drr::ResultPattern res = src.ResultPattern();
+    const auto &flash_attn = res.Op("pd_op.flash_attn",
+                                    {{{"dropout", res.Float32Attr(0.0)},
+                                      {"causal", res.BoolAttr(false)},
+                                      {"return_softmax", res.BoolAttr(false)},
+                                      {"is_test", res.BoolAttr(true)},
+                                      {"rng_name", res.StrAttr("")}}});
+    flash_attn({&res.Tensor("q"),
+                &res.Tensor("k"),
+                &res.Tensor("v"),
+                &res.InputNoneTensor(),
+                &res.Tensor("mask")},
+               {&res.Tensor("out"),
+                &res.Tensor("softmax"),
+                &res.Tensor("softmax_lse"),
+                &res.Tensor("seed_offset")});
+  }
+};
+
+// slice qkv
+class TransposeSliceFlashAttnPattern : public paddle::drr::DrrPatternBase {
+ public:
+  std::string name() const override {
+    return "TransposeSliceFlashAttnPattern";
+  }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern src = ctx->SourcePattern();
+    // transpose
+    const auto &transpose_qkv =
+        src.Op("pd_op.transpose", {{"perm", src.Attr("perm")}});
+    src.Tensor("qkv_transpose") = transpose_qkv(src.Tensor("qkv"));
+    // slice q -> [b, head, s, head_dim]
+    const auto &slice_q =
+        src.Op(paddle::dialect::SliceOp::name(),
+               {{"axes", src.Attr("axes_q")},
+                {"infer_flags", src.Attr("infer_flags_q")},
+                {"decrease_axis", src.Attr("decrease_axis_q")}});
+    const auto &full_int_array_q1 = src.Op("pd_op.full_int_array");
+    const auto &full_int_array_q2 = src.Op("pd_op.full_int_array");
+    src.Tensor("q") = slice_q(
+        src.Tensor("qkv_transpose"), full_int_array_q1(), full_int_array_q2());
+    // slice k -> [b, head, s, head_dim]
+    const auto &slice_k =
+        src.Op(paddle::dialect::SliceOp::name(),
+               {{"axes", src.Attr("axes_k")},
+                {"infer_flags", src.Attr("infer_flags_k")},
+                {"decrease_axis", src.Attr("decrease_axis_k")}});
+    const auto &full_int_array_k1 = src.Op("pd_op.full_int_array");
+    const auto &full_int_array_k2 = src.Op("pd_op.full_int_array");
+    src.Tensor("k") = slice_k(
+        src.Tensor("qkv_transpose"), full_int_array_k1(), full_int_array_k2());
+    // slice v -> [b, head, s, head_dim]
+    const auto &slice_v =
+        src.Op(paddle::dialect::SliceOp::name(),
+               {{"axes", src.Attr("axes_v")},
+                {"infer_flags", src.Attr("infer_flags_v")},
+                {"decrease_axis", src.Attr("decrease_axis_v")}});
+    const auto &full_int_array_v1 = src.Op("pd_op.full_int_array");
+    const auto &full_int_array_v2 = src.Op("pd_op.full_int_array");
src.Tensor("v") = slice_v( + src.Tensor("qkv_transpose"), full_int_array_v1(), full_int_array_v2()); + + // k[b, head, s, head_dim] -> transpose -> k[b, head, head_dim, s] + const auto &transpose_k = src.Op("pd_op.transpose"); + src.Tensor("k_transpose_out") = transpose_k(src.Tensor("k")); + // qk + const auto &qk_matmul = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_qk_transpose_x")}, + {"transpose_y", src.Attr("matmul_qk_transpose_y")}}); + src.Tensor("qk_out") = + qk_matmul(src.Tensor("q"), src.Tensor("k_transpose_out")); + // scale + const auto &scale_out = src.Op("pd_op.scale"); + const auto &full_scale = + src.Op("pd_op.full", {{"value", src.Attr("scale_out_value")}}); + src.Tensor("qk_scale_out") = scale_out(src.Tensor("qk_out"), full_scale()); + + // mask + const auto &mask_add = src.Op("pd_op.add"); + src.Tensor("mask_add_out") = + mask_add(src.Tensor("qk_scale_out"), src.Tensor("mask")); + + // softmax + const auto &softmax = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}}); + src.Tensor("softmax_out") = softmax(src.Tensor("mask_add_out")); + // o + const auto &context_matmul = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("context_matmul_transpose_x")}, + {"transpose_y", src.Attr("context_matmul_transpose_y")}}); + src.Tensor("context_matmul_out") = + context_matmul(src.Tensor("softmax_out"), src.Tensor("v")); + // [b, head, s, head_dim] -> [b, s, head, head_dim] + const auto &o_transpose = src.Op("pd_op.transpose"); + src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out")); + + // Constraints + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + auto q_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("q")); + if (!q_dtype.isa() && + !q_dtype.isa()) { + return false; + } + // softmax + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + // matmul transpose + bool matmul_qk_transpose_x = + match_ctx.Attr("matmul_qk_transpose_x"); + bool matmul_qk_transpose_y = + match_ctx.Attr("matmul_qk_transpose_y"); + if (matmul_qk_transpose_x || matmul_qk_transpose_y) return false; + + bool matmul_o_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool matmul_o_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (matmul_o_transpose_x || matmul_o_transpose_y) return false; + // tensor shape + auto q = pir::GetShapeFromValue(match_ctx.Tensor("q")); + auto k = pir::GetShapeFromValue(match_ctx.Tensor("k")); + auto v = pir::GetShapeFromValue(match_ctx.Tensor("v")); + if (q.size() != 4 || k.size() != 4 || v.size() != 4 || + !(q.at(0) == k.at(0) && k.at(0) == v.at(0)) || + !(q.at(1) == k.at(1) && k.at(1) == v.at(1)) || + !(q.at(3) == k.at(3) && k.at(3) == v.at(3))) { + return false; + } + // add shape + auto mask_add = pir::GetShapeFromValue(match_ctx.Tensor("mask")); + if (mask_add.size() != 4) { + return false; + } + + return true; + }); + + // + // Result Pattern. 
+    //
+    paddle::drr::ResultPattern res = src.ResultPattern();
+    const auto &flash_attn = res.Op("pd_op.flash_attn",
+                                    {{{"dropout", res.Float32Attr(0.0)},
+                                      {"causal", res.BoolAttr(false)},
+                                      {"return_softmax", res.BoolAttr(false)},
+                                      {"is_test", res.BoolAttr(true)},
+                                      {"rng_name", res.StrAttr("")}}});
+    flash_attn({&res.Tensor("q"),
+                &res.Tensor("k"),
+                &res.Tensor("v"),
+                &res.InputNoneTensor(),
+                &res.Tensor("mask")},
+               {&res.Tensor("out"),
+                &res.Tensor("softmax"),
+                &res.Tensor("softmax_lse"),
+                &res.Tensor("seed_offset")});
+  }
+};
+
+class FusedFlashAttnPass : public pir::PatternRewritePass {
+ public:
+  FusedFlashAttnPass() : pir::PatternRewritePass("fused_flash_attn_pass", 2) {}
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add(paddle::drr::Create<FlashAttnPatternQscale>(context, true));
+    ps.Add(paddle::drr::Create<FlashAttnPatternQscale>(context, false));
+    ps.Add(paddle::drr::Create<FlashAttnPatternOutscale>(context, true));
+    ps.Add(paddle::drr::Create<FlashAttnPatternOutscale>(context, false));
+    ps.Add(paddle::drr::Create<TransposeSliceFlashAttnPattern>(context));
+    return ps;
+  }
+};
+
+}  // namespace
+
+namespace pir {
+std::unique_ptr<Pass> CreateFusedFlashAttnPass() {
+  return std::make_unique<FusedFlashAttnPass>();
+}
+}  // namespace pir
+
+REGISTER_IR_PASS(fused_flash_attn_pass, FusedFlashAttnPass);
diff --git a/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.h b/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.h
new file mode 100644
index 0000000000000..14183174760bc
--- /dev/null
+++ b/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/include/core/dll_decl.h"
+
+namespace pir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateFusedFlashAttnPass();
+
+}  // namespace pir
diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h
index d22577adf874f..c73d695e46626 100644
--- a/paddle/fluid/pir/transforms/passes.h
+++ b/paddle/fluid/pir/transforms/passes.h
@@ -38,6 +38,7 @@ USE_PIR_PASS(conv2d_add_act_fuse_pass);
 USE_PIR_PASS(embedding_eltwise_layernorm_fuse_pass);
 USE_PIR_PASS(add_norm_fuse_pass);
 USE_PIR_PASS(fused_dot_product_attention_pass);
+USE_PIR_PASS(fused_flash_attn_pass);
 
 #ifdef PADDLE_WITH_DNNL
 USE_PIR_PASS(batch_norm_act_fuse_pass);
diff --git a/test/ir/pir/fused_pass/test_fused_flash_attn_pass.py b/test/ir/pir/fused_pass/test_fused_flash_attn_pass.py
new file mode 100644
index 0000000000000..0ded90f26ad1f
--- /dev/null
+++ b/test/ir/pir/fused_pass/test_fused_flash_attn_pass.py
@@ -0,0 +1,655 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+
+import paddle
+from paddle.base import core
+
+np.random.seed(42)
+paddle.enable_static()
+
+
+def get_cuda_version():
+    result = os.popen("nvcc --version").read()
+    regex = r'release (\S+),'
+    match = re.search(regex, result)
+    if match:
+        num = str(match.group(1))
+        integer, decimal = num.split('.')
+        return int(integer) * 1000 + int(float(decimal) * 10)
+    else:
+        return -1
+
+
+is_sm_supported = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] >= 8
+    and paddle.device.cuda.get_device_capability()[1] >= 0
+)
+
+
+def is_flashattn_supported():
+    if (
+        not core.is_compiled_with_cuda()
+        or get_cuda_version() < 11040
+        or not is_sm_supported
+    ):
+        return False
+    return True
+
+
+@unittest.skipIf(
+    not is_flashattn_supported(),
+    "core must be compiled with CUDA, the CUDA version must be at least 11.4, "
+    "and the device's compute capability must be >= 8.x",
+)
+class TestFlashAttnPatternQscaleCast(PassTest):
+    r"""
+         Q          K           V
+         |          |           |
+     transpose  transpose   transpose
+         |          |           |
+       scale    transpose       |
+         |          |           |
+          -- matmul--           |
+              |                 |
+    mask --- add                |
+              |                 |
+             cast               |
+              |                 |
+           softmax              |
+              |                 |
+             cast               |
+              |                 |
+               ------matmul------
+                       |
+                      out
+
+         Q   K   V   None   mask
+         |   |   |     |      |
+          -------flash_attn------
+                     |
+                    out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        for bs in [1]:
+            for seq_len in [128]:
+                for head_dim in [64]:
+                    for num_heads in [8]:
+                        with paddle.pir_utils.IrGuard():
+                            main_prog = paddle.static.Program()
+                            start_prog = paddle.static.Program()
+                            with paddle.pir.core.program_guard(
+                                main_prog, start_prog
+                            ):
+                                mask_shape = (bs, num_heads, seq_len, seq_len)
+                                Q = paddle.static.data(
+                                    name='Q',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                K = paddle.static.data(
+                                    name='K',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                V = paddle.static.data(
+                                    name='V',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                mask = paddle.static.data(
+                                    name='mask',
+                                    shape=mask_shape,
+                                    dtype='float16',
+                                )
+                                qt = paddle.transpose(Q, [0, 2, 1, 3])
+                                q_scale = paddle.scale(
+                                    qt, scale=0.125, bias=0.0
+                                )
+                                kt = paddle.transpose(K, [0, 2, 1, 3])
+                                kt = paddle.transpose(kt, [0, 1, 3, 2])
+                                vt = paddle.transpose(V, [0, 2, 1, 3])
+                                score = paddle.matmul(q_scale, kt)
+                                score = paddle.add(score, mask)
+                                cast_out = paddle.cast(score, 'float16')
+                                softmax_out = paddle.nn.functional.softmax(
+                                    cast_out
+                                )
+                                softmax_out = paddle.cast(
+                                    softmax_out, 'float16'
+                                )
+                                attention_out = paddle.matmul(softmax_out, vt)
+                                attention_out = paddle.transpose(
+                                    attention_out, [0, 2, 1, 3]
+                                )
+                                out = paddle.assign(attention_out)
+                                self.pass_list = ['fused_flash_attn_pass']
+                                self.feeds = {
+                                    "Q": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "K": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "V": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "mask": np.random.random(
+                                        (bs, num_heads, seq_len, seq_len)
+                                    ).astype("float16"),
+                                }
+                                self.fetch_list = [out]
+                                self.valid_op_map = {
+                                    "pd_op.flash_attn": 1,
+                                }
+                                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct(atol=1e-3, rtol=1e-3)
+
+
+@unittest.skipIf(
+    not is_flashattn_supported(),
+    "core must be compiled with CUDA, the CUDA version must be at least 11.4, "
+    "and the device's compute capability must be >= 8.x",
+)
+class TestFlashAttnPatternQscaleNoCast(PassTest):
+    r"""
+         Q          K           V
+         |          |           |
+     transpose  transpose   transpose
+         |          |           |
+       scale    transpose       |
+         |          |           |
+          -- matmul--           |
+              |                 |
+    mask --- add                |
+              |                 |
+              |                 |
+           softmax              |
+              |                 |
+              |                 |
+               ------matmul------
+                       |
+                      out
+
+         Q   K   V   None   mask
+         |   |   |     |      |
+          -------flash_attn------
+                     |
+                    out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        for bs in [1]:
+            for seq_len in [128]:
+                for head_dim in [64]:
+                    for num_heads in [8]:
+                        with paddle.pir_utils.IrGuard():
+                            main_prog = paddle.static.Program()
+                            start_prog = paddle.static.Program()
+                            with paddle.pir.core.program_guard(
+                                main_prog, start_prog
+                            ):
+                                mask_shape = (bs, num_heads, seq_len, seq_len)
+                                Q = paddle.static.data(
+                                    name='Q',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                K = paddle.static.data(
+                                    name='K',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                V = paddle.static.data(
+                                    name='V',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                mask = paddle.static.data(
+                                    name='mask',
+                                    shape=mask_shape,
+                                    dtype='float16',
+                                )
+                                qt = paddle.transpose(Q, [0, 2, 1, 3])
+                                q_scale = paddle.scale(
+                                    qt, scale=0.125, bias=0.0
+                                )
+                                kt = paddle.transpose(K, [0, 2, 1, 3])
+                                kt = paddle.transpose(kt, [0, 1, 3, 2])
+                                vt = paddle.transpose(V, [0, 2, 1, 3])
+                                score = paddle.matmul(q_scale, kt)
+                                score = paddle.add(score, mask)
+                                softmax_out = paddle.nn.functional.softmax(
+                                    score
+                                )
+                                attention_out = paddle.matmul(softmax_out, vt)
+                                attention_out = paddle.transpose(
+                                    attention_out, [0, 2, 1, 3]
+                                )
+                                out = paddle.assign(attention_out)
+                                self.pass_list = ['fused_flash_attn_pass']
+                                self.feeds = {
+                                    "Q": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "K": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "V": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "mask": np.random.random(
+                                        (bs, num_heads, seq_len, seq_len)
+                                    ).astype("float16"),
+                                }
+                                self.fetch_list = [out]
+                                self.valid_op_map = {
+                                    "pd_op.flash_attn": 1,
+                                }
+                                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct(atol=1e-3, rtol=1e-3)
+
+
+@unittest.skipIf(
+    not is_flashattn_supported(),
+    "core must be compiled with CUDA, the CUDA version must be at least 11.4, "
+    "and the device's compute capability must be >= 8.x",
+)
+class TestFlashAttnPatternOutscaleCast(PassTest):
+    r"""
+         Q          K           V
+         |          |           |
+     transpose  transpose   transpose
+         |          |           |
+         |      transpose       |
+         |          |           |
+          -- matmul--           |
+              |                 |
+            scale               |
+              |                 |
+    mask --- add                |
+              |                 |
+             cast               |
+              |                 |
+           softmax              |
+              |                 |
+             cast               |
+              |                 |
+               ------matmul------
+                       |
+                      out
+
+         Q   K   V   None   mask
+         |   |   |     |      |
+          -------flash_attn------
+                     |
+                    out
+    """
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        for bs in [1]:
+            for seq_len in [128]:
+                for head_dim in [64]:
+                    for num_heads in [8]:
+                        with paddle.pir_utils.IrGuard():
+                            main_prog = paddle.static.Program()
+                            start_prog = paddle.static.Program()
+                            with paddle.pir.core.program_guard(
+                                main_prog, start_prog
+                            ):
+                                mask_shape = (bs, num_heads, seq_len, seq_len)
+                                Q = paddle.static.data(
+                                    name='Q',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                K = paddle.static.data(
+                                    name='K',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                V = paddle.static.data(
+                                    name='V',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                mask = paddle.static.data(
+                                    name='mask',
+                                    shape=mask_shape,
+                                    dtype='float16',
+                                )
+                                qt = paddle.transpose(Q, [0, 2, 1, 3])
+                                kt = paddle.transpose(K, [0, 2, 1, 3])
+                                kt = paddle.transpose(kt, [0, 1, 3, 2])
+                                vt = paddle.transpose(V, [0, 2, 1, 3])
+
+                                score = paddle.matmul(qt, kt)
+                                score_scale = paddle.scale(
+                                    score, scale=0.125, bias=0.0
+                                )
+                                score = paddle.add(score_scale, mask)
+                                cast_out = paddle.cast(score, 'float16')
+                                softmax_out = paddle.nn.functional.softmax(
+                                    cast_out
+                                )
+                                softmax_out = paddle.cast(
+                                    softmax_out, 'float16'
+                                )
+                                attention_out = paddle.matmul(softmax_out, vt)
+                                attention_out = paddle.transpose(
+                                    attention_out, [0, 2, 1, 3]
+                                )
+                                out = paddle.assign(attention_out)
+                                self.pass_list = ['fused_flash_attn_pass']
+                                self.feeds = {
+                                    "Q": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "K": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "V": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "mask": np.random.random(
+                                        (bs, num_heads, seq_len, seq_len)
+                                    ).astype("float16"),
+                                }
+                                self.fetch_list = [out]
+                                self.valid_op_map = {
+                                    "pd_op.flash_attn": 1,
+                                }
+                                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct(atol=1e-3, rtol=1e-3)
+
+
+@unittest.skipIf(
+    not is_flashattn_supported(),
+    "core must be compiled with CUDA, the CUDA version must be at least 11.4, "
+    "and the device's compute capability must be >= 8.x",
+)
+class TestFlashAttnPatternOutscaleNoCast(PassTest):
+    r"""
+         Q          K           V
+         |          |           |
+     transpose  transpose   transpose
+         |          |           |
+         |      transpose       |
+         |          |           |
+          -- matmul--           |
+              |                 |
+            scale               |
+              |                 |
+    mask --- add                |
+              |                 |
+              |                 |
+           softmax              |
+              |                 |
+              |                 |
+               ------matmul------
+                       |
+                      out
+
+         Q   K   V   None   mask
+         |   |   |     |      |
+          -------flash_attn------
+                     |
+                    out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        for bs in [1]:
+            for seq_len in [128]:
+                for head_dim in [64]:
+                    for num_heads in [8]:
+                        with paddle.pir_utils.IrGuard():
+                            main_prog = paddle.static.Program()
+                            start_prog = paddle.static.Program()
+                            with paddle.pir.core.program_guard(
+                                main_prog, start_prog
+                            ):
+                                mask_shape = (bs, num_heads, seq_len, seq_len)
+                                Q = paddle.static.data(
+                                    name='Q',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                K = paddle.static.data(
+                                    name='K',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                V = paddle.static.data(
+                                    name='V',
+                                    shape=[bs, seq_len, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                mask = paddle.static.data(
+                                    name='mask',
+                                    shape=mask_shape,
+                                    dtype='float16',
+                                )
+                                qt = paddle.transpose(Q, [0, 2, 1, 3])
+                                kt = paddle.transpose(K, [0, 2, 1, 3])
+                                kt = paddle.transpose(kt, [0, 1, 3, 2])
+                                vt = paddle.transpose(V, [0, 2, 1, 3])
+
+                                score = paddle.matmul(qt, kt)
+                                score_scale = paddle.scale(
+                                    score, scale=0.125, bias=0.0
+                                )
+                                score = paddle.add(score_scale, mask)
+                                softmax_out = paddle.nn.functional.softmax(
+                                    score
+                                )
+                                attention_out = paddle.matmul(softmax_out, vt)
+                                attention_out = paddle.transpose(
+                                    attention_out, [0, 2, 1, 3]
+                                )
+                                out = paddle.assign(attention_out)
+                                self.pass_list = ['fused_flash_attn_pass']
+                                self.feeds = {
+                                    "Q": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "K": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "V": np.random.random(
+                                        (bs, seq_len, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "mask": np.random.random(
+                                        (bs, num_heads, seq_len, seq_len)
+                                    ).astype("float16"),
+                                }
+                                self.fetch_list = [out]
+                                self.valid_op_map = {
+                                    "pd_op.flash_attn": 1,
+                                }
+                                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct(atol=1e-3, rtol=1e-3)
+
+
+@unittest.skipIf(
+    not is_flashattn_supported(),
+    "core must be compiled with CUDA, the CUDA version must be at least 11.4, "
+    "and the device's compute capability must be >= 8.x",
+)
+class TestTransposeSliceFlashAttnPattern(PassTest):
+    r"""
+             transpose
+                 |
+      -----------+-----------
+      |          |          |
+    slice      slice      slice
+      |          |          |
+      Q          K          V
+      |          |          |
+      |      transpose      |
+      |          |          |
+       -- matmul--          |
+            |               |
+          scale             |
+            |               |
+    mask --- add            |
+            |               |
+         softmax            |
+            |               |
+             ------matmul----
+                    |
+                transpose
+                    |
+                   out
+
+             transpose
+                 |
+            ------+------
+            |     |     |
+          slice slice slice
+            |     |     |
+            Q     K     V    mask
+            |     |     |     |
+             ------flash_attn------
+                       |
+                      out
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_program(self):
+        for bs in [1]:
+            for seq_len in [128]:
+                for head_dim in [64]:
+                    for num_heads in [8]:
+                        with paddle.pir_utils.IrGuard():
+                            main_prog = paddle.static.Program()
+                            start_prog = paddle.static.Program()
+                            with paddle.pir.core.program_guard(
+                                main_prog, start_prog
+                            ):
+                                x = paddle.static.data(
+                                    name='x',
+                                    shape=[bs, seq_len, 3, num_heads, head_dim],
+                                    dtype='float16',
+                                )
+                                mask_shape = (bs, num_heads, seq_len, seq_len)
+                                mask = paddle.static.data(
+                                    name='mask',
+                                    shape=mask_shape,
+                                    dtype='float16',
+                                )
+                                xt = paddle.transpose(x, [2, 0, 3, 1, 4])
+                                q = xt[0, :, :, :, :]
+                                k = xt[1, :, :, :, :]
+                                v = xt[2, :, :, :, :]
+                                kt = paddle.transpose(k, [0, 1, 3, 2])
+
+                                score = paddle.matmul(q, kt)
+                                score_scale = paddle.scale(
+                                    score, scale=0.125, bias=0.0
+                                )
+                                score_add = paddle.add(score_scale, mask)
+                                softmax_out = paddle.nn.functional.softmax(
+                                    score_add
+                                )
+                                attention_out = paddle.matmul(softmax_out, v)
+                                attention_out = paddle.transpose(
+                                    attention_out, [0, 2, 1, 3]
+                                )
+                                out = paddle.assign(attention_out)
+                                self.pass_list = ['fused_flash_attn_pass']
+                                self.feeds = {
+                                    "x": np.random.random(
+                                        (bs, seq_len, 3, num_heads, head_dim)
+                                    ).astype("float16"),
+                                    "mask": np.random.random(
+                                        (bs, num_heads, seq_len, seq_len)
+                                    ).astype("float16"),
+                                }
+                                self.fetch_list = [out]
+                                self.valid_op_map = {
+                                    "pd_op.flash_attn": 1,
+                                }
+                                return [main_prog, start_prog]
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
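+    # Relaxed tolerances: the fused fp16 kernel is not expected to
+    # bit-match the unfused reference graph.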
+    def test_check_output(self):
+        self.check_pass_correct(atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    unittest.main()