feat: mma rowsum for fp8 #180

Merged · 1 commit · Mar 13, 2024
44 changes: 41 additions & 3 deletions include/flashinfer/mma.cuh
@@ -133,6 +133,7 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A,
uint32_t* B) {
  static_assert(sizeof(T) == 1, "T must be an 8-bit floating point data type");
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
if constexpr (mma_mode == MMAMode::kInit) {
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
@@ -216,7 +217,7 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin
}
}
#else
#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
#endif
}
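
For context, the FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED guard used above (and in the new rowsum below) is defined elsewhere in the header. A hedged sketch of what it presumably looks like, reconstructed only from the #error message (sm89, PTX 8.4+, CUDA 12.4+) and not taken from the repository:

// Hypothetical reconstruction of the feature-test macro; the real definition
// lives elsewhere in flashinfer's headers and may differ.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) && \
    (__CUDACC_VER_MAJOR__ > 12 ||                       \
     (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
#define FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED
#endif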

@@ -387,8 +388,45 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
#endif
}

/*!
 * \brief Use mma instructions to compute the row sum of an fp8 matrix fragment.
 * \tparam DType 8-bit floating point type (__nv_fp8_e4m3 or __nv_fp8_e5m2)
 * \param d pointer to the two per-thread row-sum accumulators, updated in place
 * \param s pointer to the per-thread fp8 input fragment
 */
template <typename DType>
__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
  static_assert(sizeof(DType) == 1, "DType must be an 8-bit floating point data type");
uint32_t* s_u32 = (uint32_t*)(s);
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
  if constexpr (std::is_same<DType, __nv_fp8_e4m3>::value) {
    // The B operand 943208504 == 0x38383838 packs four e4m3-encoded 1.0
    // values, so the mma multiplies s by a matrix of ones, i.e. row sums.
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"  // placeholder for the two unused accumulator lanes
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0, ph, %1, ph},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, 0., %9, 0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(943208504),
          "r"(943208504), "f"(d[0]), "f"(d[1]));
  } else {  // e5m2
    // The B operand 1010580540 == 0x3c3c3c3c packs four e5m2-encoded 1.0
    // values. The fp8 mma shape is m16n8k32 for e5m2 as well.
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"
        "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
        "{%0, ph, %1, ph},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, 0., %9, 0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1010580540),
          "r"(1010580540), "f"(d[0]), "f"(d[1]));
  }
#else
#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
#endif
}
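
The two magic B operands encode a fragment of fp8 ones, which is what turns the mma into a row reduction: 943208504 == 0x38383838 (four e4m3 bytes of 1.0, exponent bias 7) and 1010580540 == 0x3c3c3c3c (four e5m2 bytes of 1.0, exponent bias 15). A minimal host-side sketch that checks these encodings; it assumes a CUDA 11.8+ toolkit for <cuda_fp8.h> and is not part of the PR:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cuda_fp8.h>

int main() {
  __nv_fp8_e4m3 one_e4m3(1.0f);  // expected byte: 0x38 (0 0111 000)
  __nv_fp8_e5m2 one_e5m2(1.0f);  // expected byte: 0x3c (0 01111 00)
  uint8_t b4, b5;
  std::memcpy(&b4, &one_e4m3, 1);
  std::memcpy(&b5, &one_e5m2, 1);
  // Replicate each byte into all four lanes of a 32-bit register, as the
  // inline asm above does with its immediate B operands.
  printf("e4m3 ones: 0x%08x == %u\n", b4 * 0x01010101u, b4 * 0x01010101u);
  printf("e5m2 ones: 0x%08x == %u\n", b5 * 0x01010101u, b5 * 0x01010101u);
  return 0;
}

Compiled and run, this should print 943208504 and 1010580540, matching the constants in the asm.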

/*!
* \brief Use mma instructions to compute rowsum.