bugfix: bugfix for blackwell mla split-k #1109

Merged: 1 commit, Jun 3, 2025
4 changes: 0 additions & 4 deletions flashinfer/mla.py
@@ -63,10 +63,6 @@ def gen_mla_module() -> JitSpec:
            jit_env.FLASHINFER_CSRC_DIR / "cutlass_mla.cu",
            jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mla_ops.cu",
        ],
-       extra_include_paths=[
-           jit_env.CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "77_blackwell_fmha",
-           jit_env.CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "common",
-       ],
        extra_cuda_cflags=sm100a_nvcc_flags,
    )

4 changes: 2 additions & 2 deletions include/flashinfer/attention/blackwell/device/sm100_mla.hpp
@@ -44,8 +44,8 @@
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

#include "kernel/sm100_fmha_mla_reduction.hpp"
#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"

////////////////////////////////////////////////////////////////////////////////

190 changes: 190 additions & 0 deletions include/flashinfer/attention/blackwell/kernel/gather_tensor.hpp
@@ -0,0 +1,190 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"

namespace example {

using namespace cute;

// Empty type used to disable gather/scatter for a GEMM argument
struct NoGather {
template <class... Ts>
NoGather(Ts...){};
};

/// Function object that applies an index to its argument
template <class Index>
struct IndexedGather {
CUTE_HOST_DEVICE constexpr IndexedGather(Index const* indices = {}) : indices_(indices) {}

template <typename I>
CUTE_HOST_DEVICE constexpr Index operator()(I i) const {
return indices_[i];
}

CUTE_HOST_DEVICE friend void print(IndexedGather const& s) { cute::print("Indexed"); }

Index const* indices_;
};

/// Function object that applies a stride to its argument
/// Example: StridedFunc<int,_2> gathers every other row/column
template <class Stride>
struct StridedGather {
CUTE_HOST_DEVICE constexpr StridedGather(Stride stride = {}) : stride_(stride) {}

template <class I>
CUTE_HOST_DEVICE constexpr auto operator()(I i) const {
return i * stride_;
}

CUTE_HOST_DEVICE friend void print(StridedGather const& s) {
cute::print("Strided{");
print(s.stride_);
cute::print("}");
}

Stride stride_;
};

/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride {
CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
: func_(func), stride_(stride) {}

template <class I>
CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) {
return s.func_(i) * s.stride_;
}

template <class I>
CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) {
return s.func_(i) * s.stride_;
}

CUTE_HOST_DEVICE friend void print(CustomStride const& s) {
cute::print("Custom{");
print(s.func_);
cute::print(",");
print(s.stride_);
cute::print("}");
}

template <class Div>
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) {
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_,
safe_div(s.stride_, div));
}

// Circumvent the requirement on make_layout that shape and stride are integral
template <class Shape>
CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape,
CustomStride const& stride) {
return Layout<Shape, CustomStride>(shape, stride);
}

Func func_;
Stride stride_;
};

template <class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) {
// Use a dummy shape and replace the first non-unit stride with a custom gather stride
auto idx = find_if(stride, [](auto x) { return not is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
return make_layout(repeat_like(stride, _1{}),
replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}

/// Helper function to optionally create a gather tensor
template <class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride,
Func&& func) {
if constexpr (not cutlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {
Layout matrix_layout = make_identity_layout(shape);
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
} else {
return make_tensor(iter, shape, stride);
}
}

} // namespace example

namespace cute {

template <int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) {
if constexpr (is_tuple<Shape>::value) {
return transform_layout(shape, stride,
[](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
} else if constexpr (is_scaled_basis<Stride>::value) {
if constexpr (Stride::mode() == I) {
return make_layout(ceil_div(shape, Int<N>{}), ceil_div(stride, Int<N>{}));
} else {
return make_layout(shape, stride);
}
} else {
return upcast<N>(shape, stride);
}

CUTE_GCC_UNREACHABLE;
}

template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(
ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout) {
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and
// offset
auto idx =
find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;

// Upcast the outer layout (works as expected)
auto outer = upcast<N>(layout.layout_a());

// Upcast the accumulated offset along stride-1 mode
auto offset =
as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));

// Upcast the inner layout's shape along stride-1 mode
auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());

return composition(outer, offset, inner);
}

} // namespace cute
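
The file above is vendored from CUTLASS's 77_blackwell_fmha example, which is why the extra example include paths could be dropped from flashinfer/mla.py. For orientation, here is a minimal sketch of how make_gather_tensor can be used; the pointer, index array, and shapes are hypothetical, not code from this PR:

// Sketch: gather rows of a row-major [num_rows, dim] matrix through an
// index array, using the helpers defined above.
#include "cute/tensor.hpp"
#include "gather_tensor.hpp"  // the header added by this PR

template <class T>
CUTE_HOST_DEVICE void gather_rows_sketch(T* base, int const* indices, int num_rows, int dim) {
  using namespace cute;
  auto gathered = example::make_gather_tensor(
      make_gmem_ptr(base),
      make_shape(num_rows, dim),
      make_stride(dim, _1{}),                 // row-major strides
      example::IndexedGather<int>{indices});  // indirect the row mode
  // gathered(i, j) now reads base[indices[i] * dim + j], so paged or
  // permuted KV rows can be addressed as if they were contiguous.
  T first = gathered(0, 0);  // == base[indices[0] * dim]
  (void)first;
}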
include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_reduction.hpp
@@ -140,7 +140,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
-      sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
+      sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}

CUTLASS_PRAGMA_UNROLL
@@ -152,7 +152,7 @@

ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse)
? std::numeric_limits<ElementAcc>::infinity()
-                               : logf(sum_lse) + params.scale * lse_max;
+                               : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
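
These two hunks are the heart of the fix: local_lse and lse_max are already on the same scaled footing when the split-k reduction runs, so multiplying lse_max by params.scale again double-applied the softmax scale. A standalone sketch of the corrected merge, with illustrative names rather than the kernel's actual API:

// Numerically stable merge of per-split log-sum-exp values into a
// global LSE, mirroring the corrected reduction above.
#include <cmath>
#include <limits>
#include <vector>

float merge_lse(std::vector<float> const& local_lse) {
  float lse_max = -std::numeric_limits<float>::infinity();
  for (float lse : local_lse) lse_max = std::fmax(lse_max, lse);
  float sum = 0.f;
  // Subtract lse_max as-is; rescaling it here is exactly the bug removed.
  for (float lse : local_lse) sum += std::exp(lse - lse_max);
  bool invalid = (sum == 0.f) || (sum != sum);  // no valid splits, or NaN
  return invalid ? std::numeric_limits<float>::infinity()
                 : std::log(sum) + lse_max;
}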
include/flashinfer/attention/blackwell/kernel/sm100_fmha_mla_tma_warpspecialized.hpp
@@ -31,7 +31,7 @@

#pragma once

#include "common/pow_2.hpp"
#include "../common/pow_2.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cute/tensor.hpp"
#include "cutlass/arch/arch.h"
6 changes: 3 additions & 3 deletions include/flashinfer/attention/cutlass_mla.cuh
@@ -22,8 +22,8 @@
#include "cutlass/kernel_hardware_info.h"

// From 3rdparty/cutlass/examples/77_blackwell_fmha
#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"
#include "blackwell/device/sm100_mla.hpp"
#include "blackwell/kernel/sm100_mla_tile_scheduler.hpp"

namespace flashinfer {

@@ -116,7 +116,7 @@ typename T::Fmha::Arguments args_from_options(void* out_ptr, void* lse_ptr, void
// static_cast<ElementAcc*>(lse.data_ptr()), stride_LSE},
static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info,
-      1,   // split_kv
+      -1,  // split_kv
nullptr, // is_var_split_kv=false
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
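
Passing split_kv = -1 asks the kernel to choose the number of KV splits itself (per the TODO above), where the hard-coded 1 forced a single split, so the split-k path and the reduction fixed above were effectively never exercised. The function below is only an illustration of the general shape of such a heuristic, not the policy this kernel implements:

// Assumed, illustrative split-kv policy: add KV splits only when
// batch-level parallelism cannot occupy all SMs on its own.
#include <algorithm>

int pick_split_kv(int batch_blocks, int sm_count, int kv_len,
                  int min_kv_per_split = 128) {
  if (batch_blocks >= sm_count) return 1;  // GPU already saturated
  int wanted = sm_count / std::max(batch_blocks, 1);
  int max_useful = std::max(kv_len / min_kv_per_split, 1);
  return std::min(wanted, max_useful);
}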
17 changes: 11 additions & 6 deletions tests/test_deepseek_mla.py
@@ -634,12 +634,17 @@ def test_cutlass_mla(batch_size, max_seq_len, page_size, dtype):
    head_dim_kpe = 64
    total_page_num = 8192

-   q_nope_pe = torch.randn(
-       batch_size,
-       num_local_heads,
-       head_dim_ckv + head_dim_kpe,
-       dtype=dtype,
-       device="cuda",
+   # NOTE(Zihao): use larger scale to detect bugs such as
+   # https://github.com/flashinfer-ai/flashinfer/pull/1055
+   q_nope_pe = (
+       torch.randn(
+           batch_size,
+           num_local_heads,
+           head_dim_ckv + head_dim_kpe,
+           dtype=dtype,
+           device="cuda",
+       )
+       * 100
    )
    ckv_kpe = torch.randn(
        total_page_num,