From 9430a8a618b2e65cd94d06f075509d78b0ca580f Mon Sep 17 00:00:00 2001 From: Shanli Xing Date: Thu, 7 Nov 2024 19:45:12 -0800 Subject: [PATCH] misc: refactor group gemm API (#545) - [x] Rewrite GEMM arg preparation kernels with Triton - [ ] Support customized `x_strides`, `weight_strides` input - [ ] Move input checks to Python --------- Signed-off-by: xsling --- include/flashinfer/gemm/group_gemm.cuh | 38 +- .../flashinfer/gemm/group_gemm_cutlass.cuh | 35 -- include/flashinfer/gemm/group_gemm_sm90.cuh | 251 +++++------- python/aot_setup.py | 2 +- python/csrc/flashinfer_gemm_ops.cu | 8 +- python/csrc/flashinfer_gemm_sm90_ops.cu | 11 +- python/csrc/group_gemm.cu | 44 +- python/csrc/group_gemm_sm90.cu | 51 +-- python/csrc_aot/flashinfer_ops.cu | 8 +- python/csrc_aot/flashinfer_sm90_ops.cu | 27 -- python/flashinfer/gemm.py | 377 +++++++++++++++--- 11 files changed, 467 insertions(+), 385 deletions(-) delete mode 100644 python/csrc_aot/flashinfer_sm90_ops.cu diff --git a/include/flashinfer/gemm/group_gemm.cuh b/include/flashinfer/gemm/group_gemm.cuh index b6d1c4cc..2fe458e6 100644 --- a/include/flashinfer/gemm/group_gemm.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -36,33 +36,9 @@ namespace group_gemm { template cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes, - DType* x, DType* w, DType* y, int64_t* xy_indptr_d, - int64_t* w_indices_d, unsigned int batch_size, unsigned int d_in, - unsigned int d_out, bool weight_column_major, - cudaStream_t stream) { - AlignedAllocator allocator(workspace_buffer, workspace_buffer_size_in_bytes); - cutlass::gemm::GemmCoord* problem_sizes_device = - allocator.aligned_alloc( - batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device"); - DType** x_data = allocator.aligned_alloc(batch_size * sizeof(DType*), 16, "x_data"); - DType** w_data = allocator.aligned_alloc(batch_size * sizeof(DType*), 16, "w_data"); - DType** y_data = allocator.aligned_alloc(batch_size * sizeof(DType*), 16, "y_data"); - int64_t* ld_x = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_x"); - int64_t* ld_w = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_w"); - int64_t* ld_y = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_y"); - - // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API, - // so I just use the kernel function directly, need to investigate more. 
- auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args; - compute_args_kernel<<>>( - problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w, - (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - std::cerr << "Failed to launch kernel: " << cudaGetErrorString(err) << std::endl; - return err; - } - + void* all_problems, unsigned int batch_size, void* x, void* w, + void* y, void* x_ld, void* w_ld, void* y_ld, + bool weight_column_major, cudaStream_t stream) { using cutlass::epilogue::thread::LinearCombination; using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { @@ -91,8 +67,12 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); using GemmGrouped = cutlass::gemm::device::GemmGrouped; - typename GemmGrouped::Arguments args(problem_sizes_device, batch_size, 4, epilogue_op, x_data, - w_data, y_data, y_data, ld_x, ld_w, ld_y, ld_y); + typename GemmGrouped::Arguments args( + reinterpret_cast(all_problems), (int)batch_size, + /*threadblock_count=*/4, epilogue_op, static_cast(x), static_cast(w), + static_cast(y), static_cast(y), reinterpret_cast(x_ld), + reinterpret_cast(w_ld), reinterpret_cast(y_ld), + reinterpret_cast(y_ld)); GemmGrouped gemm; auto status = gemm.initialize(args, nullptr, stream); diff --git a/include/flashinfer/gemm/group_gemm_cutlass.cuh b/include/flashinfer/gemm/group_gemm_cutlass.cuh index 0f71fa3d..5a9af0c1 100644 --- a/include/flashinfer/gemm/group_gemm_cutlass.cuh +++ b/include/flashinfer/gemm/group_gemm_cutlass.cuh @@ -56,41 +56,6 @@ struct cutlass_dtype<__nv_fp8_e5m2> { using type = cutlass::float_e5m2_t; }; -template -__global__ void compute_sm80_cutlass_group_gemm_args( - cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr, - int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y, - int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) { - int i = blockIdx.x; - int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; - all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); - w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n; - x_ptr[i] = x + xy_indptr[i] * k; - y_ptr[i] = y + xy_indptr[i] * n; - x_ld[i] = k; // m * k - w_ld[i] = w_column_major ? k : n; // k * n if column major, n * k if row major - y_ld[i] = n; // m * n -} - -template -__global__ void compute_sm90_cutlass_group_gemm_args( - ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr, - StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y, - int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) { - int i = blockIdx.x; - int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; - all_problems[i] = ProblemShape(m, n, k); - w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n; - x_ptr[i] = x + xy_indptr[i] * k; - y_ptr[i] = y + xy_indptr[i] * n; - - x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); - w_stride[i] = w_column_major ? 
cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1}) - : cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); - y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1}); -} - } // namespace group_gemm } // namespace flashinfer diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index 23901288..c446910a 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -76,176 +76,111 @@ using namespace cute; template cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_size_in_bytes, - void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x, - DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, - int64_t* w_indices_d, unsigned int batch_size, - unsigned int d_in, unsigned int d_out, + void* int_buffer, size_t int_buffer_size_in_bytes, + void* all_problems, unsigned int batch_size, void* x, void* w, + void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, cudaStream_t stream) { auto compute_capacity = GetCudaComputeCapability(); if (compute_capacity.first < 9) { std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0" << std::endl; return cudaErrorNotSupported; - } else { - // Compute capability >= 9.0 - // Reference implementation - // - - // https://github.com/NVIDIA/cutlass/blob/f7b19de32c5d1f3cedfc735c2849f12b537522ee/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu - using ProblemShape = - cutlass::gemm::GroupProblemShape>; // per group - using ElementA = DTypeIn; // Element type for A matrix operand - using ElementB = DTypeIn; // Element type for B matrix operand - using ElementC = DTypeOut; // Element type for C and D matrix operands - - DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { - if constexpr (std::is_same_v && - sizeof(DTypeIn) == 1) { - std::ostringstream err_msg; - err_msg << "Row-major layout is not supported for fp8 data type"; - throw std::runtime_error(err_msg.str()); - } else { - using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand - constexpr int AlignmentA = - 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of - // elements (up to 16 bytes) - - // B matrix configuration - using LayoutB = WEIGHT_LAYOUT; // Layout type for B matrix operand - constexpr int AlignmentB = - 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of - // elements (up to 16 bytes) - - // C/D matrix configuration - using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands - constexpr int AlignmentC = - 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of - // elements (up to 16 bytes) - - constexpr bool is_fp8 = sizeof(DTypeIn) == 1; - // Core kernel configurations - using ElementAccumulator = float; // Element type for internal accumulation - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the - // intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = - typename std::conditional, - Shape<_128, _64, _64>>::type; // Threadblock-level tile size - using ClusterShape = - typename std::conditional, Shape<_2, _1, _1>>:: - type; // Shape of the threadblocks in a cluster - using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized - // based on the tile size - using KernelSchedule = typename std::conditional< - is_fp8, 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type; // Kernel to launch - using EpilogueSchedule = - cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, - ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, - EpilogueSchedule>::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, - ElementAccumulator, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - // Reference device GEMM implementation type - using DeviceGemmReference = - cutlass::reference::device::Gemm; - - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - AlignedAllocator allocator(int_buffer, int_buffer_size_in_bytes); - ProblemShape::UnderlyingProblemShape* problem_sizes_device = - allocator.aligned_alloc( - batch_size * sizeof(ProblemShape::UnderlyingProblemShape), 16, - "problem_sizes_device"); - DTypeIn** x_data = - allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "x_data"); - DTypeIn** w_data = - allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "w_data"); - DTypeOut** y_data = - allocator.aligned_alloc(batch_size * sizeof(DTypeOut*), 16, "y_data"); - StrideA* x_stride = - allocator.aligned_alloc(batch_size * sizeof(StrideA), 16, "x_stride"); - StrideB* w_stride = - allocator.aligned_alloc(batch_size * sizeof(StrideB), 16, "w_stride"); - StrideC* y_stride = - allocator.aligned_alloc(batch_size * sizeof(StrideC), 16, "y_stride"); - - cutlass::KernelHardwareInfo hw_info; - cudaGetDevice(&hw_info.device_id); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - - typename Gemm::EpilogueOutputOp::Params params; - // TODO(Zihao): support block alpha and beta - params = typename Gemm::EpilogueOutputOp::Params(/*alpha=*/ElementAccumulator(1.f), - /*beta=*/ElementAccumulator(0.f)); - - typename Gemm::Arguments arguments; - - arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {int(batch_size), problem_sizes_device, nullptr}, - {const_cast(x_data), x_stride, const_cast(w_data), - w_stride}, - {params, const_cast(y_data), y_stride, y_data, y_stride}, - hw_info}; - - compute_sm90_cutlass_group_gemm_args<<>>( - problem_sizes_device, x_data, w_data, y_data, x_stride, w_stride, y_stride, (DTypeIn*)x, - (DTypeIn*)w, (DTypeOut*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - std::cerr << "Failed to launch compute_sm90_cutlass_group_gemm_args kernel: " - << cudaGetErrorString(err) << std::endl; - return err; - } - - // Initialize the gemm kernel - Gemm gemm; - - // Using the arguments, query for extra 
workspace required for matrix multiplication - // computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes); - auto workspace_ptr = float_allocator.aligned_alloc(workspace_size, 64, - "sm90_group_gemm_float_workspace"); - - // Check if the problem size is supported or not - CUTLASS_CHECK(gemm.can_implement(arguments)); - - // Initialize CUTLASS kernel with arguments and workspace pointer - CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr)); - - // Correctness / Warmup iteration - CUTLASS_CHECK(gemm.run()); // Warmup - } - }); } + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using ElementA = DTypeIn; + using ElementB = DTypeIn; + using ElementC = DTypeOut; + + DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { + if constexpr (std::is_same_v && + sizeof(DTypeIn) == 1) { + std::ostringstream err_msg; + err_msg << "Row-major layout is not supported for fp8 data type"; + throw std::runtime_error(err_msg.str()); + } else { + using LayoutA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = WEIGHT_LAYOUT; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using LayoutC = cutlass::layout::RowMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + constexpr bool is_fp8 = sizeof(DTypeIn) == 1; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = + typename std::conditional, Shape<_128, _64, _64>>::type; + using ClusterShape = + typename std::conditional, Shape<_2, _1, _1>>::type; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = typename std::conditional< + is_fp8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type; + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::EpilogueOutputOp::Params params; + params = + typename Gemm::EpilogueOutputOp::Params(ElementAccumulator(1.f), ElementAccumulator(0.f)); + + typename 
Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {int(batch_size), reinterpret_cast(all_problems), + nullptr}, + {static_cast(x), reinterpret_cast(x_stride), + static_cast(w), reinterpret_cast(w_stride)}, + {params, static_cast(y), reinterpret_cast(y_stride), + static_cast(y), reinterpret_cast(y_stride)}, + hw_info}; + + Gemm gemm; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes); + auto workspace_ptr = float_allocator.aligned_alloc(workspace_size, 64, + "sm90_group_gemm_float_workspace"); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr)); + CUTLASS_CHECK(gemm.run(stream)); + } + }); + return cudaSuccess; } } // namespace group_gemm - } // namespace flashinfer #endif // FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ diff --git a/python/aot_setup.py b/python/aot_setup.py index ee9cfa2a..670c2cb8 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -441,7 +441,7 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: name="flashinfer._kernels_sm90", sources=[ "csrc/group_gemm_sm90.cu", - "csrc_aot/flashinfer_sm90_ops.cu", + "csrc/flashinfer_gemm_sm90_ops.cu", ], include_dirs=include_dirs, extra_compile_args=extra_compile_args_sm90, diff --git a/python/csrc/flashinfer_gemm_ops.cu b/python/csrc/flashinfer_gemm_ops.cu index 69bd47c4..2e31ec0c 100644 --- a/python/csrc/flashinfer_gemm_ops.cu +++ b/python/csrc/flashinfer_gemm_ops.cu @@ -18,10 +18,10 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, torch::Tensor& A_scale, torch::Tensor& B_scale); -torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); +void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_problems, + torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, + torch::Tensor x_ld, torch::Tensor w_ld, torch::Tensor y_ld, + torch::Tensor empty_x_data, bool weight_column_major); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM"); diff --git a/python/csrc/flashinfer_gemm_sm90_ops.cu b/python/csrc/flashinfer_gemm_sm90_ops.cu index 2ebe32e5..10d90984 100644 --- a/python/csrc/flashinfer_gemm_sm90_ops.cu +++ b/python/csrc/flashinfer_gemm_sm90_ops.cu @@ -15,11 +15,12 @@ */ #include -torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); +void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor all_problems, + torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, + torch::Tensor x_stride, torch::Tensor weight_stride, + torch::Tensor y_stride, torch::Tensor empty_x_data, + bool weight_column_major); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index ab200e8b..5d07dd14 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -19,47 +19,23 @@ using namespace flashinfer::group_gemm; -torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor 
seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major) { - // TODO(Zihao): Add more checks here - CHECK_INPUT(seg_indptr); - CHECK_INPUT(x); - CHECK_INPUT(weight); - auto device = x.device(); - CHECK_EQ(seg_indptr.device(), device); - CHECK_EQ(weight.device(), device); - CHECK_DIM(2, x); // x: [sum(m_i), d_in] - CHECK_DIM(3, weight); // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights, - // d_in, d_out] otherwise - int64_t cumulative_batch_size = x.size(0); - int64_t d_out = weight_column_major ? weight.size(1) : weight.size(2); - int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1); - CHECK_EQ(x.size(1), d_in); - auto y = torch::zeros({cumulative_batch_size, d_out}, x.options()); +void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_problems, + torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, + torch::Tensor x_ld, torch::Tensor w_ld, torch::Tensor y_ld, + torch::Tensor empty_x_data, bool weight_column_major) { + unsigned int batch_size = x_ptr.size(0); + auto device = workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - seg_indptr = seg_indptr.to(torch::kInt64); - bool weight_indices_defined = weight_indices.numel() > 0; - if (weight_indices_defined) { - CHECK_INPUT(weight_indices); - CHECK_EQ(weight_indices.device(), device); - weight_indices = weight_indices.to(torch::kInt64); - } - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; auto status = CutlassSegmentGEMMRun( workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), - static_cast(x.data_ptr()), static_cast(weight.data_ptr()), - static_cast(y.data_ptr()), static_cast(seg_indptr.data_ptr()), - weight_indices_defined ? static_cast(weight_indices.data_ptr()) : nullptr, - batch_size, d_in, d_out, weight_column_major, torch_current_stream); + all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), + x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, + torch_current_stream); TORCH_CHECK(status == cudaSuccess, "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); return true; }); - - return y; } diff --git a/python/csrc/group_gemm_sm90.cu b/python/csrc/group_gemm_sm90.cu index 7873e133..a6484a9f 100644 --- a/python/csrc/group_gemm_sm90.cu +++ b/python/csrc/group_gemm_sm90.cu @@ -19,52 +19,27 @@ using namespace flashinfer::group_gemm; -torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major) { - // TODO(Zihao): Add more checks here - CHECK_INPUT(seg_indptr); - CHECK_INPUT(x); - CHECK_INPUT(weight); - auto device = x.device(); - CHECK_EQ(seg_indptr.device(), device); - CHECK_EQ(weight.device(), device); - CHECK_DIM(2, x); // x: [sum(m_i), d_in] - CHECK_DIM(3, weight); // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights, - // d_in, d_out] otherwise - int64_t cumulative_batch_size = x.size(0); - int64_t d_out = weight_column_major ? weight.size(1) : weight.size(2); - int64_t d_in = weight_column_major ? 
weight.size(2) : weight.size(1); - CHECK_EQ(x.size(1), d_in); - auto y = torch::zeros({cumulative_batch_size, d_out}, x.options()); +void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor all_problems, + torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, + torch::Tensor x_stride, torch::Tensor weight_stride, + torch::Tensor y_stride, torch::Tensor empty_x_data, + bool weight_column_major) { + unsigned int batch_size = x_ptr.size(0); + auto device = float_workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - seg_indptr = seg_indptr.to(torch::kInt64); - bool weight_indices_defined = weight_indices.numel() > 0; - if (weight_indices_defined) { - CHECK_INPUT(weight_indices); - CHECK_EQ(weight_indices.device(), device); - weight_indices = weight_indices.to(torch::kInt64); - } - - // TODO(Zihao): add fp8 support - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; - auto status = CutlassSegmentGEMMSM90Run( + auto status = CutlassSegmentGEMMSM90Run( float_workspace_buffer.data_ptr(), float_workspace_buffer.element_size() * float_workspace_buffer.size(0), int_workspace_buffer.data_ptr(), - int_workspace_buffer.element_size() * int_workspace_buffer.size(0), - static_cast(x.data_ptr()), static_cast(weight.data_ptr()), - static_cast(y.data_ptr()), static_cast(seg_indptr.data_ptr()), - weight_indices_defined ? static_cast(weight_indices.data_ptr()) : nullptr, - batch_size, d_in, d_out, weight_column_major, torch_current_stream); + int_workspace_buffer.element_size() * int_workspace_buffer.size(0), all_problems.data_ptr(), + batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), x_stride.data_ptr(), + weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); return true; }); - - return y; } diff --git a/python/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu index 80039260..cc9f3935 100644 --- a/python/csrc_aot/flashinfer_ops.cu +++ b/python/csrc_aot/flashinfer_ops.cu @@ -60,10 +60,10 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun( void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, torch::Tensor& A_scale, torch::Tensor& B_scale); -torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); +void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_problems, + torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, + torch::Tensor x_ld, torch::Tensor w_ld, torch::Tensor y_ld, + torch::Tensor empty_x_data, bool weight_column_major); //========== norm ========== diff --git a/python/csrc_aot/flashinfer_sm90_ops.cu b/python/csrc_aot/flashinfer_sm90_ops.cu deleted file mode 100644 index 2ebe32e5..00000000 --- a/python/csrc_aot/flashinfer_sm90_ops.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, - "Cutlass Segment GEMM operator for SM90"); -} diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 4c88ac2e..f3be0928 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -18,6 +18,8 @@ from typing import Optional import torch +import triton +import triton.language as tl from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops from .utils import ( @@ -72,37 +74,48 @@ def _fake_bmm_fp8( # torch library for cutlass_segment_gemm - @register_custom_op("flashinfer::cutlass_segment_gemm", mutates_args=()) + @register_custom_op("flashinfer::cutlass_segment_gemm", mutates_args=("y")) def cutlass_segment_gemm( workspace_buffer: torch.Tensor, - seg_indptr: torch.Tensor, - weight_indices: torch.Tensor, - x: torch.Tensor, - weights: torch.Tensor, - batch_size: int, + all_problems: torch.Tensor, + x_data: torch.Tensor, + w_data: torch.Tensor, + y_data: torch.Tensor, + x_ld: torch.Tensor, + w_ld: torch.Tensor, + y_ld: torch.Tensor, + y: torch.Tensor, + empty_x_data: torch.Tensor, weight_column_major: bool, - ) -> torch.Tensor: - return module.cutlass_segment_gemm( + ) -> None: + module.cutlass_segment_gemm( workspace_buffer, - seg_indptr, - weight_indices, - x, - weights, - batch_size, + all_problems, + x_data, + w_data, + y_data, + x_ld, + w_ld, + y_ld, + empty_x_data, weight_column_major, ) @register_fake_op("flashinfer::cutlass_segment_gemm") def _fake_cutlass_segment_gemm( workspace_buffer: torch.Tensor, - seg_indptr: torch.Tensor, - weight_indices: torch.Tensor, - x: torch.Tensor, - weights: torch.Tensor, - batch_size: int, + all_problems: torch.Tensor, + x_data: torch.Tensor, + w_data: torch.Tensor, + y_data: torch.Tensor, + x_ld: torch.Tensor, + w_ld: torch.Tensor, + y_ld: torch.Tensor, + y: torch.Tensor, + empty_x_data: torch.Tensor, weight_column_major: bool, - ) -> torch.Tensor: - return torch.empty_like(x) + ) -> None: + pass # Register the module _gemm_module = SimpleNamespace( @@ -114,7 +127,6 @@ def _fake_cutlass_segment_gemm( def get_gemm_sm90_module(): - print("get_gemm_sm90_module") global _gemm_module_sm90 if _gemm_module_sm90 is None: if has_prebuilt_ops: @@ -133,25 +145,32 @@ def get_gemm_sm90_module(): # torch library for cutlass_segment_gemm_sm90 - @register_custom_op("flashinfer::cutlass_segment_gemm_sm90", mutates_args=()) + @register_custom_op("flashinfer::cutlass_segment_gemm_sm90", mutates_args=("y")) def cutlass_segment_gemm_sm90( workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, - seg_indptr: torch.Tensor, - weight_indices: torch.Tensor, - x: torch.Tensor, - weights: torch.Tensor, - batch_size: int, + all_problems: torch.Tensor, + x_data: torch.Tensor, + w_data: torch.Tensor, + y_data: torch.Tensor, + 
x_stride: torch.Tensor, + w_stride: torch.Tensor, + y_stride: torch.Tensor, + y: torch.Tensor, + empty_x_data: torch.Tensor, weight_column_major: bool, - ) -> torch.Tensor: - return module.cutlass_segment_gemm_sm90( + ) -> None: + module.cutlass_segment_gemm_sm90( workspace_buffer, int_workspace_buffer, - seg_indptr, - weight_indices, - x, - weights, - batch_size, + all_problems, + x_data, + w_data, + y_data, + x_stride, + w_stride, + y_stride, + empty_x_data, weight_column_major, ) @@ -159,14 +178,18 @@ def cutlass_segment_gemm_sm90( def _fake_cutlass_segment_gemm_sm90( workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, - seg_indptr: torch.Tensor, - weight_indices: torch.Tensor, - x: torch.Tensor, - weights: torch.Tensor, - batch_size: int, + all_problems: torch.Tensor, + x_data: torch.Tensor, + w_data: torch.Tensor, + y_data: torch.Tensor, + x_stride: torch.Tensor, + w_stride: torch.Tensor, + y_stride: torch.Tensor, + y: torch.Tensor, + empty_x_data: torch.Tensor, weight_column_major: bool, - ) -> torch.Tensor: - return torch.empty_like(x) + ) -> None: + pass # Register the module _gemm_module_sm90 = SimpleNamespace( @@ -176,6 +199,212 @@ def _fake_cutlass_segment_gemm_sm90( return _gemm_module_sm90 +@triton.jit +def compute_sm80_group_gemm_args( + all_problems_ptr, + x_ptr, + w_ptr, + y_ptr, + x_ld_ptr, + w_ld_ptr, + y_ld_ptr, + x, + w, + y, + xy_indptr, + w_indices, + d_in, + d_out, + w_column_major, +): + + pid = tl.program_id(0) + + m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid) + k, n = d_in, d_out + + tl.store(all_problems_ptr + pid * 3, m) + tl.store(all_problems_ptr + pid * 3 + 1, n) + tl.store(all_problems_ptr + pid * 3 + 2, k) + + w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64) + w_curr_ptr = w + w_i * k * n + tl.store(w_ptr + pid, w_curr_ptr) + + x_curr_ptr = x + tl.load(xy_indptr + pid) * k + tl.store(x_ptr + pid, x_curr_ptr) + + y_curr_ptr = y + tl.load(xy_indptr + pid) * n + tl.store(y_ptr + pid, y_curr_ptr) + + tl.store(x_ld_ptr + pid, k) + tl.store(w_ld_ptr + pid, k if w_column_major else n) + tl.store(y_ld_ptr + pid, n) + + +@triton.jit +def compute_sm90_group_gemm_args( + all_problems_ptr, + x_ptr, + w_ptr, + y_ptr, + x_stride_ptr, + w_stride_ptr, + y_stride_ptr, + x, + w, + y, + xy_indptr, + w_indices, + d_in, + d_out, + w_column_major, +): + + pid = tl.program_id(0) + + m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid) + k, n = d_in, d_out + + tl.store(all_problems_ptr + pid * 3, m) + tl.store(all_problems_ptr + pid * 3 + 1, n) + tl.store(all_problems_ptr + pid * 3 + 2, k) + + w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64) + w_curr_ptr = w + w_i * k * n + tl.store(w_ptr + pid, w_curr_ptr) + + x_curr_ptr = x + tl.load(xy_indptr + pid) * k + tl.store(x_ptr + pid, x_curr_ptr) + + y_curr_ptr = y + tl.load(xy_indptr + pid) * n + tl.store(y_ptr + pid, y_curr_ptr) + + tl.store(x_stride_ptr + pid, k) + tl.store(w_stride_ptr + pid, k if w_column_major else n) + tl.store(y_stride_ptr + pid, n) + + +def launch_compute_sm80_group_gemm_args( + x: torch.Tensor, + weights: torch.Tensor, + y: torch.Tensor, + w_column_major: bool, + batch_size: int, + seg_indptr: torch.Tensor, + weight_indices: Optional[torch.Tensor] = None, +): + device = x.device + prob_type = torch.int32 # problem sizes -> int + ptr_type = torch.int64 # pointers -> int64_t + ld_type = torch.int64 # strides -> int64_t + + seg_indptr = seg_indptr.to(ptr_type) + if weight_indices is not None: + weight_indices = 
weight_indices.to(ptr_type) + + d_out = weights.size(1) if w_column_major else weights.size(2) + d_in = weights.size(2) if w_column_major else weights.size(1) + + all_problems = torch.empty((batch_size, 3), dtype=prob_type, device=device) + + x_data = torch.empty(batch_size, dtype=ptr_type, device=device) + w_data = torch.empty(batch_size, dtype=ptr_type, device=device) + y_data = torch.empty(batch_size, dtype=ptr_type, device=device) + + x_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) + w_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) + y_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) + + compute_sm80_group_gemm_args[(batch_size,)]( + all_problems, + x_data, + w_data, + y_data, + x_stride_data, + w_stride_data, + y_stride_data, + x, + weights, + y, + seg_indptr, + weight_indices, + d_in, + d_out, + w_column_major, + ) + + return ( + all_problems, + x_data, + w_data, + y_data, + x_stride_data, + w_stride_data, + y_stride_data, + ) + + +def launch_compute_sm90_group_gemm_args( + x: torch.Tensor, + weights: torch.Tensor, + y: torch.Tensor, + w_column_major: bool, + batch_size: int, + seg_indptr: torch.Tensor, + weight_indices: Optional[torch.Tensor] = None, +): + device = x.device + prob_type = torch.int32 # problem sizes -> int + ptr_type = torch.int64 # pointers -> int64_t + stride_type = torch.int64 # strides -> int64_t + + seg_indptr = seg_indptr.to(ptr_type) + if weight_indices is not None: + weight_indices = weight_indices.to(ptr_type) + + d_out = weights.size(1) if w_column_major else weights.size(2) + d_in = weights.size(2) if w_column_major else weights.size(1) + + all_problems = torch.empty((batch_size, 3), dtype=prob_type, device=device) + + x_data = torch.empty(batch_size, dtype=ptr_type, device=device) + w_data = torch.empty(batch_size, dtype=ptr_type, device=device) + y_data = torch.empty(batch_size, dtype=ptr_type, device=device) + + x_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) + w_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) + y_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) + + compute_sm90_group_gemm_args[(batch_size,)]( + all_problems, + x_data, + w_data, + y_data, + x_stride_data, + w_stride_data, + y_stride_data, + x, + weights, + y, + seg_indptr, + weight_indices, + d_in, + d_out, + w_column_major, + ) + + return ( + all_problems, + x_data, + w_data, + y_data, + x_stride_data, + w_stride_data, + y_stride_data, + ) + + class SegmentGEMMWrapper: r"""Wrapper for segment GEMM kernels. 
@@ -332,27 +561,75 @@ def run( # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) major, _ = get_compute_capability(x.device) + cumulative_batch_size = x.size(0) + d_out = weights.size(1) if weight_column_major else weights.size(2) + y = torch.zeros((cumulative_batch_size, d_out), dtype=x.dtype, device=x.device) + empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device) + if major >= 9: - return get_gemm_sm90_module().cutlass_segment_gemm_sm90( - self._float_workspace_buffer, - self._int_workspace_buffer, - seg_indptr, - weight_indices, + ( + all_problems, + x_data, + w_data, + y_data, + x_stride_data, + w_stride_data, + y_stride_data, + ) = launch_compute_sm90_group_gemm_args( x, weights, + y, + weight_column_major, batch_size, + seg_indptr, + weight_indices, + ) + get_gemm_sm90_module().cutlass_segment_gemm_sm90( + self._float_workspace_buffer, + self._int_workspace_buffer, + all_problems, + x_data, + w_data, + y_data, + x_stride_data, + w_stride_data, + y_stride_data, + y, # for torch compile mutates_args + empty_x_data, # for kernel type dispatch weight_column_major, ) else: - return get_gemm_module().cutlass_segment_gemm( - self._int_workspace_buffer, - seg_indptr, - weight_indices, + ( + all_problems, + x_data, + w_data, + y_data, + x_ld_data, + w_ld_data, + y_ld_data, + ) = launch_compute_sm80_group_gemm_args( x, weights, + y, + weight_column_major, batch_size, + seg_indptr, + weight_indices, + ) + get_gemm_module().cutlass_segment_gemm( + self._int_workspace_buffer, + all_problems, + x_data, + w_data, + y_data, + x_ld_data, + w_ld_data, + y_ld_data, + y, + empty_x_data, weight_column_major, ) + return y forward = run
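
Reviewer note: below is a minimal sanity-check sketch of the new Triton argument-preparation path, not part of the patch itself. It calls only the helper added here (`launch_compute_sm80_group_gemm_args`, an internal helper of `flashinfer.gemm`; the SM90 variant is symmetric) plus plain PyTorch. The test shapes are arbitrary, and the pointer check relies on my reading that Triton stores pointer values into the int64 buffers as byte addresses (pointer arithmetic in elements); treat the assertions as illustrative rather than authoritative.

```python
import torch

from flashinfer.gemm import launch_compute_sm80_group_gemm_args

batch_size, d_in, d_out = 4, 128, 256
seg_lens = torch.tensor([32, 0, 64, 16], device="cuda")
seg_indptr = torch.cat([seg_lens.new_zeros(1), seg_lens.cumsum(0)])  # [0, 32, 32, 96, 112]

x = torch.randn(int(seg_indptr[-1]), d_in, dtype=torch.float16, device="cuda")
# weight_column_major layout: [num_weights, d_out, d_in]
weights = torch.randn(batch_size, d_out, d_in, dtype=torch.float16, device="cuda")
y = torch.zeros(x.size(0), d_out, dtype=torch.float16, device="cuda")

(all_problems, x_ptr, w_ptr, y_ptr, x_ld, w_ld, y_ld) = launch_compute_sm80_group_gemm_args(
    x, weights, y, True, batch_size, seg_indptr
)

# Problem sizes: one (m_i, n, k) int32 triplet per segment.
m_ref = (seg_indptr[1:] - seg_indptr[:-1]).to(torch.int32)
assert torch.equal(all_problems[:, 0], m_ref)
assert (all_problems[:, 1] == d_out).all() and (all_problems[:, 2] == d_in).all()

# Pointer arrays hold byte addresses: base + row_offset * d_in * element_size for x.
x_ptr_ref = x.data_ptr() + seg_indptr[:-1] * d_in * x.element_size()
assert torch.equal(x_ptr, x_ptr_ref)

# Leading dimensions: ld_x = k, ld_w = k (column-major) or n (row-major), ld_y = n.
assert (x_ld == d_in).all() and (w_ld == d_in).all() and (y_ld == d_out).all()
```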
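
Also for reviewers: the numerical contract the refactor is expected to preserve, written as a plain-PyTorch per-segment loop. `segment_gemm_reference` and its argument names are illustrative only (not added by this patch); it mirrors the `d_in`/`d_out` selection and the zero-initialized output now done in `run()`.

```python
from typing import Optional

import torch


def segment_gemm_reference(
    x: torch.Tensor,           # [sum(m_i), d_in]
    weights: torch.Tensor,     # [num_weights, d_out, d_in] if column major, [num_weights, d_in, d_out] otherwise
    seg_indptr: torch.Tensor,  # [batch_size + 1], cumulative row offsets into x / y
    weight_column_major: bool,
    weight_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    d_out = weights.size(1) if weight_column_major else weights.size(2)
    # run() zero-initializes y, so zero-length segments simply stay zero.
    y = torch.zeros(x.size(0), d_out, dtype=x.dtype, device=x.device)
    for i in range(seg_indptr.numel() - 1):
        s, e = int(seg_indptr[i]), int(seg_indptr[i + 1])
        if s == e:
            continue
        w_i = weights[int(weight_indices[i])] if weight_indices is not None else weights[i]
        # Column-major weights are stored as [d_out, d_in]; transpose to [d_in, d_out].
        w_i = w_i.t() if weight_column_major else w_i
        y[s:e] = x[s:e] @ w_i
    return y
```

Either backend's output (SM80 or SM90 path) can be compared against this reference with `torch.allclose` under an fp16/bf16-appropriate tolerance.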