misc: refactor group gemm API (#545)
- [x] Rewrite GEMM arg preparation kernels with Triton (see the illustrative sketch after the file summary below)
- [ ] Support customized `x_strides`, `weight_strides` input
- [ ] Move input checks to Python

---------

Signed-off-by: xsling <me@xsl.sh>
xslingcn authored Nov 8, 2024
1 parent c7dc921 commit 9430a8a
Showing 11 changed files with 467 additions and 385 deletions.
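
For context on the first checklist item, here is a minimal Triton sketch of the per-segment argument-preparation step that the deleted CUDA kernels below used to perform: one program per segment fills the problem-shape, pointer, and leading-dimension arrays. The kernel name, the packed int32 (M, N, K) layout of the problem array, and the use of raw int64 device addresses are illustrative assumptions, not the actual kernel added by this PR.

```python
import triton
import triton.language as tl


@triton.jit
def compute_group_gemm_args(
    all_problems,             # int32[batch_size, 3]: packed (M, N, K) per segment
    x_ptrs, w_ptrs, y_ptrs,   # int64[batch_size]: raw device addresses
    ld_x, ld_w, ld_y,         # int64[batch_size]: leading dimensions
    x_base, w_base, y_base,   # base addresses of x, w, y (from .data_ptr())
    seg_indptr,               # int64[batch_size + 1]: row offsets into x / y
    w_indices,                # int64[batch_size]: weight index per segment
    d_in, d_out,              # K and N
    elem_size,                # bytes per element of x / w / y
    W_COLUMN_MAJOR: tl.constexpr,
    HAS_W_INDICES: tl.constexpr,
):
    # One program per segment, mirroring the old one-block-per-problem CUDA kernel.
    i = tl.program_id(0)
    row_begin = tl.load(seg_indptr + i)
    m = tl.load(seg_indptr + i + 1) - row_begin
    k = d_in
    n = d_out

    # Problem shape (M, N, K) for this segment, stored as three packed int32 values.
    tl.store(all_problems + i * 3 + 0, m.to(tl.int32))
    tl.store(all_problems + i * 3 + 1, n.to(tl.int32))
    tl.store(all_problems + i * 3 + 2, k.to(tl.int32))

    # Per-segment sub-matrix pointers, written as raw byte addresses.
    if HAS_W_INDICES:
        w_idx = tl.load(w_indices + i)
    else:
        w_idx = i.to(tl.int64)
    tl.store(x_ptrs + i, x_base + row_begin * k * elem_size)
    tl.store(w_ptrs + i, w_base + w_idx * k * n * elem_size)
    tl.store(y_ptrs + i, y_base + row_begin * n * elem_size)

    # Leading dimensions: x and y are row-major; w depends on its layout.
    tl.store(ld_x + i, k.to(tl.int64))
    if W_COLUMN_MAJOR:
        tl.store(ld_w + i, k.to(tl.int64))
    else:
        tl.store(ld_w + i, n.to(tl.int64))
    tl.store(ld_y + i, n.to(tl.int64))
```

Launched with a one-dimensional grid of `batch_size` programs, this produces the same arrays that the removed `compute_sm80_cutlass_group_gemm_args` CUDA kernel used to build inside the runner.
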
38 changes: 9 additions & 29 deletions include/flashinfer/gemm/group_gemm.cuh
@@ -36,33 +36,9 @@ namespace group_gemm {

template <typename DType>
cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes,
DType* x, DType* w, DType* y, int64_t* xy_indptr_d,
int64_t* w_indices_d, unsigned int batch_size, unsigned int d_in,
unsigned int d_out, bool weight_column_major,
cudaStream_t stream) {
AlignedAllocator allocator(workspace_buffer, workspace_buffer_size_in_bytes);
cutlass::gemm::GemmCoord* problem_sizes_device =
allocator.aligned_alloc<cutlass::gemm::GemmCoord>(
batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device");
DType** x_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "x_data");
DType** w_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "w_data");
DType** y_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "y_data");
int64_t* ld_x = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_x");
int64_t* ld_w = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_w");
int64_t* ld_y = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_y");

// NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API,
// so I just use the kernel function directly, need to investigate more.
auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args<DType, DType>;
compute_args_kernel<<<batch_size, 1, 0, stream>>>(
problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w,
(DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "Failed to launch kernel: " << cudaGetErrorString(err) << std::endl;
return err;
}

void* all_problems, unsigned int batch_size, void* x, void* w,
void* y, void* x_ld, void* w_ld, void* y_ld,
bool weight_column_major, cudaStream_t stream) {
using cutlass::epilogue::thread::LinearCombination;
using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
@@ -91,8 +67,12 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
typename GemmGrouped::Arguments args(problem_sizes_device, batch_size, 4, epilogue_op, x_data,
w_data, y_data, y_data, ld_x, ld_w, ld_y, ld_y);
typename GemmGrouped::Arguments args(
reinterpret_cast<cutlass::gemm::GemmCoord*>(all_problems), (int)batch_size,
/*threadblock_count=*/4, epilogue_op, static_cast<DType**>(x), static_cast<DType**>(w),
static_cast<DType**>(y), static_cast<DType**>(y), reinterpret_cast<int64_t*>(x_ld),
reinterpret_cast<int64_t*>(w_ld), reinterpret_cast<int64_t*>(y_ld),
reinterpret_cast<int64_t*>(y_ld));

GemmGrouped gemm;
auto status = gemm.initialize(args, nullptr, stream);
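
After the refactor, `CutlassSegmentGEMMRun` receives `all_problems`, the pointer arrays, and the leading-dimension arrays fully formed instead of building them from `xy_indptr`/`w_indices` itself. Below is a hedged Python sketch of host-side preparation that could feed this signature, reusing the Triton kernel sketched earlier; separate allocations stand in for the slices of the int workspace buffer that the real code would use, and the function name is illustrative.

```python
import torch


def prepare_group_gemm_args(x, w, y, seg_indptr, weight_indices,
                            d_in, d_out, weight_column_major):
    """Build the device-side argument arrays the refactored runner expects."""
    batch_size = seg_indptr.numel() - 1
    device = x.device

    all_problems = torch.empty(batch_size, 3, dtype=torch.int32, device=device)
    x_ptrs = torch.empty(batch_size, dtype=torch.int64, device=device)
    w_ptrs = torch.empty(batch_size, dtype=torch.int64, device=device)
    y_ptrs = torch.empty(batch_size, dtype=torch.int64, device=device)
    ld_x = torch.empty(batch_size, dtype=torch.int64, device=device)
    ld_w = torch.empty(batch_size, dtype=torch.int64, device=device)
    ld_y = torch.empty(batch_size, dtype=torch.int64, device=device)

    has_w_indices = weight_indices is not None
    compute_group_gemm_args[(batch_size,)](
        all_problems, x_ptrs, w_ptrs, y_ptrs, ld_x, ld_w, ld_y,
        x.data_ptr(), w.data_ptr(), y.data_ptr(),
        seg_indptr,
        weight_indices if has_w_indices else seg_indptr,  # dummy, never read
        d_in, d_out, x.element_size(),
        W_COLUMN_MAJOR=weight_column_major,
        HAS_W_INDICES=has_w_indices,
    )
    return all_problems, x_ptrs, w_ptrs, y_ptrs, ld_x, ld_w, ld_y
```

The returned tensors then map, as raw device pointers, onto the `all_problems`, `x`, `w`, `y`, `x_ld`, `w_ld`, and `y_ld` parameters of the new C++ signature.
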
35 changes: 0 additions & 35 deletions include/flashinfer/gemm/group_gemm_cutlass.cuh
@@ -56,41 +56,6 @@ struct cutlass_dtype<__nv_fp8_e5m2> {
using type = cutlass::float_e5m2_t;
};

template <typename DTypeIn, typename DTypeOut>
__global__ void compute_sm80_cutlass_group_gemm_args(
cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y,
int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
int i = blockIdx.x;
int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
x_ptr[i] = x + xy_indptr[i] * k;
y_ptr[i] = y + xy_indptr[i] * n;
x_ld[i] = k; // m * k
w_ld[i] = w_column_major ? k : n; // k * n if column major, n * k if row major
y_ld[i] = n; // m * n
}

template <typename DTypeIn, typename DTypeOut, typename ProblemShape, typename StrideA,
typename StrideB, typename StrideCD>
__global__ void compute_sm90_cutlass_group_gemm_args(
ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y,
int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
int i = blockIdx.x;
int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
all_problems[i] = ProblemShape(m, n, k);
w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
x_ptr[i] = x + xy_indptr[i] * k;
y_ptr[i] = y + xy_indptr[i] * n;

x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
w_stride[i] = w_column_major ? cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1})
: cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1});
}

} // namespace group_gemm

} // namespace flashinfer
251 changes: 93 additions & 158 deletions include/flashinfer/gemm/group_gemm_sm90.cuh
@@ -76,176 +76,111 @@ using namespace cute;

template <typename DTypeIn, typename DTypeOut>
cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_size_in_bytes,
void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x,
DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d,
int64_t* w_indices_d, unsigned int batch_size,
unsigned int d_in, unsigned int d_out,
void* int_buffer, size_t int_buffer_size_in_bytes,
void* all_problems, unsigned int batch_size, void* x, void* w,
void* y, void* x_stride, void* w_stride, void* y_stride,
bool weight_column_major, cudaStream_t stream) {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first < 9) {
std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0"
<< std::endl;
return cudaErrorNotSupported;
} else {
// Compute capability >= 9.0
// Reference implementation
// -
// https://github.com/NVIDIA/cutlass/blob/f7b19de32c5d1f3cedfc735c2849f12b537522ee/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int, int, int>>; // <M,N,K> per group
using ElementA = DTypeIn; // Element type for A matrix operand
using ElementB = DTypeIn; // Element type for B matrix operand
using ElementC = DTypeOut; // Element type for C and D matrix operands

DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
if constexpr (std::is_same_v<WEIGHT_LAYOUT, cutlass::layout::RowMajor> &&
sizeof(DTypeIn) == 1) {
std::ostringstream err_msg;
err_msg << "Row-major layout is not supported for fp8 data type";
throw std::runtime_error(err_msg.str());
} else {
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of
// elements (up to 16 bytes)

// B matrix configuration
using LayoutB = WEIGHT_LAYOUT; // Layout type for B matrix operand
constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of
// elements (up to 16 bytes)

// C/D matrix configuration
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of
// elements (up to 16 bytes)

constexpr bool is_fp8 = sizeof(DTypeIn) == 1;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the
// intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape =
typename std::conditional<is_fp8, Shape<_128, _128, _64>,
Shape<_128, _64, _64>>::type; // Threadblock-level tile size
using ClusterShape =
typename std::conditional<is_fp8, Shape<_2, _2, _1>, Shape<_2, _1, _1>>::
type; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized
// based on the tile size
using KernelSchedule = typename std::conditional<
is_fp8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type; // Kernel to launch
using EpilogueSchedule =
cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
EpilogueSchedule>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Reference device GEMM implementation type
using DeviceGemmReference =
cutlass::reference::device::Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC,
LayoutC, ElementAccumulator, ElementAccumulator>;

using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;

AlignedAllocator allocator(int_buffer, int_buffer_size_in_bytes);
ProblemShape::UnderlyingProblemShape* problem_sizes_device =
allocator.aligned_alloc<ProblemShape::UnderlyingProblemShape>(
batch_size * sizeof(ProblemShape::UnderlyingProblemShape), 16,
"problem_sizes_device");
DTypeIn** x_data =
allocator.aligned_alloc<DTypeIn*>(batch_size * sizeof(DTypeIn*), 16, "x_data");
DTypeIn** w_data =
allocator.aligned_alloc<DTypeIn*>(batch_size * sizeof(DTypeIn*), 16, "w_data");
DTypeOut** y_data =
allocator.aligned_alloc<DTypeOut*>(batch_size * sizeof(DTypeOut*), 16, "y_data");
StrideA* x_stride =
allocator.aligned_alloc<StrideA>(batch_size * sizeof(StrideA), 16, "x_stride");
StrideB* w_stride =
allocator.aligned_alloc<StrideB>(batch_size * sizeof(StrideB), 16, "w_stride");
StrideC* y_stride =
allocator.aligned_alloc<StrideC>(batch_size * sizeof(StrideC), 16, "y_stride");

cutlass::KernelHardwareInfo hw_info;
cudaGetDevice(&hw_info.device_id);
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

typename Gemm::EpilogueOutputOp::Params params;
// TODO(Zihao): support block alpha and beta
params = typename Gemm::EpilogueOutputOp::Params(/*alpha=*/ElementAccumulator(1.f),
/*beta=*/ElementAccumulator(0.f));

typename Gemm::Arguments arguments;

arguments = typename Gemm::Arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{int(batch_size), problem_sizes_device, nullptr},
{const_cast<const DTypeIn**>(x_data), x_stride, const_cast<const DTypeIn**>(w_data),
w_stride},
{params, const_cast<const DTypeIn**>(y_data), y_stride, y_data, y_stride},
hw_info};

compute_sm90_cutlass_group_gemm_args<<<batch_size, 1, 0, stream>>>(
problem_sizes_device, x_data, w_data, y_data, x_stride, w_stride, y_stride, (DTypeIn*)x,
(DTypeIn*)w, (DTypeOut*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "Failed to launch compute_sm90_cutlass_group_gemm_args kernel: "
<< cudaGetErrorString(err) << std::endl;
return err;
}

// Initialize the gemm kernel
Gemm gemm;

// Using the arguments, query for extra workspace required for matrix multiplication
// computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes);
auto workspace_ptr = float_allocator.aligned_alloc<void>(workspace_size, 64,
"sm90_group_gemm_float_workspace");

// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));

// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr));

// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run()); // Warmup
}
});
}

using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
using ElementA = DTypeIn;
using ElementB = DTypeIn;
using ElementC = DTypeOut;

DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
if constexpr (std::is_same_v<WEIGHT_LAYOUT, cutlass::layout::RowMajor> &&
sizeof(DTypeIn) == 1) {
std::ostringstream err_msg;
err_msg << "Row-major layout is not supported for fp8 data type";
throw std::runtime_error(err_msg.str());
} else {
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

using LayoutB = WEIGHT_LAYOUT;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

using LayoutC = cutlass::layout::RowMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

constexpr bool is_fp8 = sizeof(DTypeIn) == 1;

using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape =
typename std::conditional<is_fp8, Shape<_128, _128, _64>, Shape<_128, _64, _64>>::type;
using ClusterShape =
typename std::conditional<is_fp8, Shape<_2, _2, _1>, Shape<_2, _1, _1>>::type;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using KernelSchedule = typename std::conditional<
is_fp8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
EpilogueSchedule>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;

cutlass::KernelHardwareInfo hw_info;
cudaGetDevice(&hw_info.device_id);
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

typename Gemm::EpilogueOutputOp::Params params;
params =
typename Gemm::EpilogueOutputOp::Params(ElementAccumulator(1.f), ElementAccumulator(0.f));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{int(batch_size), reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(all_problems),
nullptr},
{static_cast<const DTypeIn**>(x), reinterpret_cast<StrideA*>(x_stride),
static_cast<const DTypeIn**>(w), reinterpret_cast<StrideB*>(w_stride)},
{params, static_cast<const DTypeOut**>(y), reinterpret_cast<StrideC*>(y_stride),
static_cast<DTypeOut**>(y), reinterpret_cast<StrideD*>(y_stride)},
hw_info};

Gemm gemm;

size_t workspace_size = Gemm::get_workspace_size(arguments);
AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes);
auto workspace_ptr = float_allocator.aligned_alloc<void>(workspace_size, 64,
"sm90_group_gemm_float_workspace");

CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr));
CUTLASS_CHECK(gemm.run(stream));
}
});

return cudaSuccess;
}

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
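
The remaining unchecked item, moving input checks to Python, would live in the Python wrapper rather than in these CUDA sources. A minimal sketch of such checks follows, assuming a (num_weights, d_out, d_in) layout for column-major weights and (num_weights, d_in, d_out) otherwise; the function name and messages are illustrative, not the actual flashinfer code.

```python
from typing import Optional

import torch


def check_segment_gemm_inputs(x: torch.Tensor, weights: torch.Tensor,
                              seg_indptr: torch.Tensor,
                              weight_indices: Optional[torch.Tensor],
                              weight_column_major: bool) -> None:
    """Validate segment GEMM inputs before dispatching to the CUDA runners."""
    if x.dtype != weights.dtype:
        raise ValueError(f"x and weights must share a dtype, got {x.dtype} vs {weights.dtype}")
    if not (x.is_cuda and weights.is_cuda and seg_indptr.is_cuda):
        raise ValueError("x, weights and seg_indptr must be CUDA tensors")
    if seg_indptr.dtype != torch.int64:
        raise ValueError("seg_indptr must be int64")
    if weights.dim() != 3:
        raise ValueError("weights must be 3-D: (num_weights, d_out, d_in) "
                         "for column-major or (num_weights, d_in, d_out) for row-major")
    d_in = weights.size(2) if weight_column_major else weights.size(1)
    if x.size(-1) != d_in:
        raise ValueError(f"x has inner dimension {x.size(-1)}, expected {d_in}")
    if weight_indices is not None and weight_indices.numel() != seg_indptr.numel() - 1:
        raise ValueError("weight_indices must have one entry per segment")
    # Mirrors the SM90 guard above: fp8 inputs require column-major weights.
    if x.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and not weight_column_major:
        raise ValueError("row-major weights are not supported for fp8 inputs")
```
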
2 changes: 1 addition & 1 deletion python/aot_setup.py
@@ -441,7 +441,7 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None:
name="flashinfer._kernels_sm90",
sources=[
"csrc/group_gemm_sm90.cu",
"csrc_aot/flashinfer_sm90_ops.cu",
"csrc/flashinfer_gemm_sm90_ops.cu",
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args_sm90,