feat: support bmm fp8 (#469)
`torch.bmm` doesn't support fp8 and `torch._scaled_mm` doesn't support 3D inputs, so I wrote this one. @yzh119 cc @merrymercy @Ying1123 @ispobock

Thanks @yzh119 for assisting with debugging.

AType: fp8 e4m3, fp8 e5m2
BType: fp8 e4m3, fp8 e5m2
DType: bf16, fp16

The case where both AType and BType are fp8 e5m2 is not supported; see
https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
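
For reference, a minimal usage sketch. The Python wrapper's exact signature lives in `python/flashinfer/gemm.py` (not shown in this diff), so the call below assumes it mirrors the underlying op's `(A, B, D, A_scale, B_scale)` convention with a caller-allocated output and per-tensor float32 scales:

```python
import torch
import flashinfer

# Hypothetical example; shapes and dtypes follow the checks in python/csrc/bmm_fp8.cu.
batch, m, k, n = 2, 16, 32, 64
A_f32 = torch.randn(batch, m, k, device="cuda")
B_f32 = torch.randn(batch, k, n, device="cuda")

# Per-tensor quantization to fp8 e4m3 with float32 scales (both inputs e5m2 is unsupported).
A_scale = A_f32.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
B_scale = B_f32.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
A = (A_f32 / A_scale).to(torch.float8_e4m3fn)
B = (B_f32 / B_scale).to(torch.float8_e4m3fn)

# Output must be a 3D bf16 or fp16 tensor; allocated by the caller in this sketch.
D = torch.empty(batch, m, n, dtype=torch.bfloat16, device="cuda")
# NOTE: assumed call convention; check flashinfer.bmm_fp8's actual signature.
flashinfer.bmm_fp8(A, B, D, A_scale.float(), B_scale.float())

# Rough check against the fp32 reference (fp8 quantization error is expected).
ref = (A_f32 @ B_f32).flatten()
cos = torch.nn.functional.cosine_similarity(D.float().flatten(), ref, dim=0)
assert cos > 0.99
```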

```bash
pytest python/tests/test_bmm_fp8.py
```

Works on H100:
```
=================================================================================== test session starts ===================================================================================
platform linux -- Python 3.12.4, pytest-8.3.2, pluggy-1.5.0
rootdir: /flashinfer
collected 8 items

python/tests/test_bmm_fp8.py ...s...s                                                                                                                                                                       [100%]

============================================================================== 6 passed, 2 skipped in 2.16s ===============================================================================
```

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
zhyncs and yzh119 authored Aug 26, 2024
1 parent 2ba3f1c commit f1c0b68
Showing 10 changed files with 371 additions and 9 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -33,7 +33,7 @@ FlashInfer is a library for Large Language Models that provides high-performance
api/python/sparse
api/python/page
api/python/sampling
api/python/group_gemm
api/python/gemm
api/python/norm
api/python/rope
api/python/quantization
200 changes: 200 additions & 0 deletions include/flashinfer/bmm_fp8.cuh
@@ -0,0 +1,200 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_BMM_FP8_CUH_
#define FLASHINFER_BMM_FP8_CUH_

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <torch/extension.h>

#include <stdexcept>
#include <type_traits>

namespace flashinfer {

namespace bmm_fp8 {

template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  T* descriptor() const { return descriptor_.get(); }
  T* descriptor() { return descriptor_.get(); }

 protected:
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

class CuBlasLtMatmulDescriptor
    : public CuBlasLtDescriptor<cublasLtMatmulDescOpaque_t, &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(cublasComputeType_t compute_type, cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

class CuBlasLtMatrixLayout
    : public CuBlasLtDescriptor<cublasLtMatrixLayoutOpaque_t, &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols, int64_t ld,
                       bool t = false) {
    cublasLtMatrixLayout_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<cublasLtMatmulPreferenceOpaque_t,
                                                           &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

template <typename T>
cudaDataType_t get_cuda_data_type() {
  if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
    return CUDA_R_8F_E4M3;
  } else if constexpr (std::is_same_v<T, __nv_fp8_e5m2>) {
    return CUDA_R_8F_E5M2;
  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
    return CUDA_R_16BF;
  } else if constexpr (std::is_same_v<T, half>) {
    return CUDA_R_16F;
  } else {
    throw std::runtime_error("Unsupported type");
  }
}

template <typename AT, typename BT, typename DT>
void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, int k,
                               const float* A_scale, const float* B_scale) {
  const void* A_scale_ptr = static_cast<const void*>(A_scale);
  const void* B_scale_ptr = static_cast<const void*>(B_scale);
  auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
  int8_t fast_accum = 1;
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum);

  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr);

  cudaDataType_t a_type = get_cuda_data_type<AT>();
  cudaDataType_t b_type = get_cuda_data_type<BT>();
  cudaDataType_t d_type = get_cuda_data_type<DT>();
  if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
    throw std::runtime_error("Unsupported combination: both A and B are e5m2");
  }

  auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
  auto b_desp = CuBlasLtMatrixLayout(b_type, k, n, k);
  auto d_desp = CuBlasLtMatrixLayout(d_type, m, n, m);

  if (batch_size > 1) {
    int64_t stride_a = m * k;
    int64_t stride_b = k * n;
    int64_t stride_d = m * n;
    a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a);
    b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b);
    d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d);
  }

  CuBlasLtMatmulPreference preference;
  size_t workspace_size = 1024 * 1024;  // 1 MiB
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspace_size);
  cublasLtMatmulHeuristicResult_t heuristic_result = {};
  int returned_result = 0;
  auto lt_handle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(),
      d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result,
      &returned_result));
  if (returned_result == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasStatus_t status = cublasLtMatmul(
      lt_handle, matmul_desp.descriptor(), &alpha, A, a_desp.descriptor(), B, b_desp.descriptor(),
      &beta, nullptr, d_desp.descriptor(), D, d_desp.descriptor(), &heuristic_result.algo,
      workspace.mutable_get(), workspace_size, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status));
}

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
    const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

} // namespace bmm_fp8
} // namespace flashinfer

#endif // FLASHINFER_BMM_FP8_CUH_
68 changes: 68 additions & 0 deletions python/csrc/bmm_fp8.cu
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <flashinfer/bmm_fp8.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
             torch::Tensor& A_scale, torch::Tensor& B_scale) {
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
  TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A");
  TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B");
  TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D");
  TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match");
  TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
  TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2),
              "Result tensor has incorrect shape");
  TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn || A.scalar_type() == torch::kFloat8_e5m2,
              "A must be Float8_e4m3fn or Float8_e5m2");
  TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn || B.scalar_type() == torch::kFloat8_e5m2,
              "B must be Float8_e4m3fn or Float8_e5m2");
  TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf,
              "D must be BFloat16 or Half");

  TORCH_CHECK(A_scale.scalar_type() == torch::kFloat32 && B_scale.scalar_type() == torch::kFloat32,
              "A_scale and B_scale must be Float32");

  auto batch_size = A.size(0);
  auto m = A.size(1);
  auto k = A.size(2);
  auto n = B.size(2);

  // PyTorch tensors are row major by default; cuBLASLt assumes column major.
  // We need D in row major. Using A^T * B = D, so D^T = B^T * A, which is why
  // B and A (and their scales) are swapped in the call below.
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
        flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
            static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
            static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
            static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()));
        return true;
      });
    });
  });
}
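
The general identity behind the layout comment above (D = A·B implies Dᵀ = Bᵀ·Aᵀ, and row-major D shares its memory layout with column-major Dᵀ) can be illustrated with a short NumPy sketch; this is illustrative only and not part of this change:

```python
import numpy as np

m, n, k = 4, 6, 3
A = np.random.randn(m, k).astype(np.float32)
B = np.random.randn(k, n).astype(np.float32)

D = A @ B        # desired row-major result, shape (m, n)
Dt = B.T @ A.T   # what a column-major library computes with the operands swapped

# Row-major D and column-major D^T occupy the same memory layout,
# so producing D^T in column-major yields row-major D "for free".
assert np.allclose(Dt.T, D)
```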
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
@@ -48,6 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
  m.def("packbits", &packbits, "GPU packbits operator");
  m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
  m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
  py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
      .def(py::init<torch::Tensor>())
      .def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer)
3 changes: 3 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -104,6 +104,9 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
                               torch::Tensor output_indptr, const std::string& bitorder);

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
             torch::Tensor& A_scale, torch::Tensor& B_scale);

class CutlassSegmentGEMMPyTorchWrapper {
public:
void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer);
12 changes: 6 additions & 6 deletions python/flashinfer/__init__.py
@@ -14,10 +14,11 @@
limitations under the License.
"""

from .activation import gelu_tanh_and_mul, silu_and_mul
from .cascade import (
    MultiLevelCascadeAttentionWrapper,
    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    MultiLevelCascadeAttentionWrapper,
    merge_state,
    merge_state_in_place,
    merge_states,
@@ -27,8 +28,7 @@
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    single_decode_with_kv_cache,
)
from .activation import gelu_tanh_and_mul, silu_and_mul
from .group_gemm import SegmentGEMMWrapper
from .gemm import SegmentGEMMWrapper, bmm_fp8
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (
@@ -46,15 +46,15 @@
)
from .sampling import (
    chain_speculative_sampling,
    min_p_sampling_from_probs,
    sampling_from_probs,
    top_k_renorm_prob,
    top_k_mask_logits,
    top_k_renorm_prob,
    top_k_sampling_from_probs,
    top_k_top_p_sampling_from_probs,
    top_k_top_p_sampling_from_logits,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
    min_p_sampling_from_probs,
)
from .sparse import BlockSparseAttentionWrapper

