perf: faster fp8->fp16 dequantization for pre sm_90 arch (#439)
The hardware fp8->fp16 fast-conversion instruction is not available on
sm_80 and sm_89, which makes #420 slow on these architectures.

This PR uses Marlin's fast fp8->fp16x4 conversion algorithm (copied from
the vLLM project) to accelerate these cases.
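
For reference, the bias arithmetic behind the trick (standard fp8/fp16 layouts, nothing specific to this repo):

  fp16: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits
  e5m2: 1 sign bit, 5 exponent bits (bias 15),  2 mantissa bits -> an e5m2 byte copied into the
        high byte of an fp16 already encodes the exact value, so this case is a pure byte shuffle
  e4m3: 1 sign bit, 4 exponent bits (bias 7),   3 mantissa bits -> after shifting into fp16
        position, the value must be rescaled by 2^(15 - 7) = 2^8 = 256, which is the
        BIAS_OFFSET multiply in vec_dtypes.cuh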

Co-authored-by: Antoni Baum <antoni@anyscale.com>
Co-authored-by: Cody Yu <cody@anyscale.com>
3 people authored Aug 11, 2024
1 parent adcf701 commit c93f647
Showing 4 changed files with 253 additions and 15 deletions.
13 changes: 13 additions & 0 deletions CMakeLists.txt
@@ -32,6 +32,8 @@ flashinfer_option(FLASHINFER_CASCADE "Whether to compile cascade kernel tests/be
flashinfer_option(FLASHINFER_SAMPLING "Whether to compile sampling kernel tests/benchmarks or not." OFF)
flashinfer_option(FLASHINFER_NORM "Whether to compile normalization kernel tests/benchmarks or not." OFF)
flashinfer_option(FLASHINFER_DISTRIBUTED "Whether to compile distributed kernel tests/benchmarks or not." OFF)
flashinfer_option(FLASHINFER_FASTDIV_TEST "Whether to compile fastdiv kernel tests or not." OFF)
flashinfer_option(FLASHINFER_FASTDEQUANT_TEST "Whether to compile fast dequant kernel tests or not." OFF)
flashinfer_option(FLASHINFER_TVM_BINDING "Whether to compile tvm binding or not." OFF)
flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm binding." "")

@@ -477,6 +479,17 @@ if(FLASHINFER_FASTDIV_TEST)
target_link_libraries(test_fastdiv PRIVATE gtest gtest_main)
endif(FLASHINFER_FASTDIV_TEST)

if(FLASHINFER_FASTDEQUANT_TEST)
message(STATUS "Compile fast dequant test.")
file(GLOB_RECURSE TEST_FAST_DEQUANT_SRCS ${PROJECT_SOURCE_DIR}/src/test_fast_dequant.cu)
add_executable(test_fast_dequant ${TEST_FAST_DEQUANT_SRCS})
target_include_directories(test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main)
endif(FLASHINFER_FASTDEQUANT_TEST)



if (FLASHINFER_DISTRIBUTED)
find_package(MPI REQUIRED)

2 changes: 2 additions & 0 deletions cmake/config.cmake
@@ -18,6 +18,8 @@ set(FLASHINFER_SAMPLING ON)
set(FLASHINFER_NORMALIZATION ON)
# Whether to compile fastdiv tests
set(FLASHINFER_FASTDIV_TEST ON)
# Whether to compile fastdequant tests
set(FLASHINFER_FASTDEQUANT_TEST ON)
# Whether to compile distributed tests
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
182 changes: 167 additions & 15 deletions include/flashinfer/vec_dtypes.cuh
@@ -16,19 +16,19 @@
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_

#ifdef FLASHINFER_ENABLE_BF16
#include <cuda_bf16.h>
#endif
#include <cuda_fp16.h>
#ifdef FLASHINFER_ENABLE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>

#include <type_traits>

namespace flashinfer {

#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
#define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
#endif

#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__

/******************* vec_t type cast *******************/
@@ -74,11 +74,130 @@ struct vec_cast<half, float> {
}
};

#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
template <typename T>
constexpr FLASHINFER_INLINE int get_exponent_bits() {
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
return 4;
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
return 5;
} else if constexpr (std::is_same<T, half>::value) {
return 5;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
return 8;
}
}

template <typename T>
constexpr FLASHINFER_INLINE int get_mantissa_bits() {
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
return 3;
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
return 2;
} else if constexpr (std::is_same<T, half>::value) {
return 11;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
return 7;
}
}

/*!
* \brief Fallback to software fast dequant implementation if hardware dequantization is not
* available.
 * \note Inspired by Marlin's fast dequantization, but here we don't have to permute
 * the weight order.
* \ref
* https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120
*/
template <typename fp8_dtype, typename fp16_dtype>
__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
uint32_t q = *input;
if constexpr (std::is_same<fp8_dtype, __nv_fp8_e5m2>::value &&
std::is_same<fp16_dtype, half>::value) {
output->x = __byte_perm(0U, q, 0x5140);
output->y = __byte_perm(0U, q, 0x7362);
} else {
constexpr int FP8_EXPONENT = get_exponent_bits<fp8_dtype>();
constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_dtype>();
constexpr int FP16_EXPONENT = get_exponent_bits<fp16_dtype>();

constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
q = __byte_perm(q, q, 0x1302);

// Extract and shift FP8 values to FP16 format
uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);

constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Construct and apply exponent bias
if (std::is_same<fp16_dtype, half>::value) {
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

// Convert to half2 and apply bias
*(half2*)&(output->x) = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
*(half2*)&(output->y) = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
} else {
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
*(nv_bfloat162*)&(output->x) =
__hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
*(nv_bfloat162*)&(output->y) =
__hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
}
}
}
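
// Sanity check (worked by hand, not part of the committed code) of the constants above for
// the e4m3 -> half instantiation, matching the "Final MASK value: 0x7F007F00" comment:
//   RIGHT_SHIFT = 5 - 4 = 1
//   MASK2 = 0x80000000 >> (4 + 3) = 0xFF000000 (arithmetic shift), MASK3 = 0x7F000000
//   MASK  = 0x7F000000 | (0x7F000000 >> 16) = 0x7F007F00
//   BIAS_OFFSET = (1 << 4) - (1 << 3) = 8, so the bias multiply is by 2^8 = 256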

template <>
struct vec_cast<nv_bfloat16, __nv_fp8_e4m3> {
template <size_t vec_size>
FLASHINFER_INLINE static void cast(nv_bfloat16* dst, const __nv_fp8_e4m3* src) {
if constexpr (vec_size == 1) {
dst[0] = nv_bfloat16(src[0]);
} else if constexpr (vec_size == 2) {
dst[0] = nv_bfloat16(src[0]);
dst[1] = nv_bfloat16(src[1]);
} else {
static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
for (uint32_t i = 0; i < vec_size / 4; ++i) {
fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>((uint32_t*)&src[i * 4],
(uint2*)&dst[i * 4]);
}
}
}
};

template <>
struct vec_cast<nv_bfloat16, __nv_fp8_e5m2> {
template <size_t vec_size>
FLASHINFER_INLINE static void cast(nv_bfloat16* dst, const __nv_fp8_e5m2* src) {
if constexpr (vec_size == 1) {
dst[0] = nv_bfloat16(src[0]);
} else if constexpr (vec_size == 2) {
dst[0] = nv_bfloat16(src[0]);
dst[1] = nv_bfloat16(src[1]);
} else {
static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
for (uint32_t i = 0; i < vec_size / 4; ++i) {
fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>((uint32_t*)&src[i * 4],
(uint2*)&dst[i * 4]);
}
}
}
};

template <>
struct vec_cast<__nv_fp8_e4m3, half> {
template <size_t vec_size>
FLASHINFER_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) {
#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
if constexpr (vec_size == 1) {
dst[0] = __nv_fp8_e4m3(src[0]);
} else {
@@ -90,13 +209,20 @@ struct vec_cast<__nv_fp8_e4m3, half> {
*(uint16_t*)&dst[i * 2] = y;
}
}
#else
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
dst[i] = __nv_fp8_e4m3(src[i]);
}
#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
}
};

template <>
struct vec_cast<__nv_fp8_e5m2, half> {
template <size_t vec_size>
FLASHINFER_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) {
#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
if constexpr (vec_size == 1) {
dst[0] = __nv_fp8_e5m2(src[0]);
} else {
@@ -108,13 +234,20 @@ struct vec_cast<__nv_fp8_e5m2, half> {
*(uint16_t*)&dst[i * 2] = y;
}
}
#else
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
dst[i] = __nv_fp8_e5m2(src[i]);
}
#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
}
};

template <>
struct vec_cast<half, __nv_fp8_e4m3> {
template <size_t vec_size>
FLASHINFER_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) {
#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
if constexpr (vec_size == 1) {
dst[0] = half(src[0]);
} else {
@@ -126,13 +259,28 @@ struct vec_cast<half, __nv_fp8_e4m3> {
*(uint32_t*)&dst[i * 2] = y;
}
}
#else
if constexpr (vec_size == 1) {
dst[0] = half(src[0]);
} else if constexpr (vec_size == 2) {
dst[0] = half(src[0]);
dst[1] = half(src[1]);
} else {
static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
for (uint32_t i = 0; i < vec_size / 4; ++i) {
fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]);
}
}
#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
}
};

template <>
struct vec_cast<half, __nv_fp8_e5m2> {
template <size_t vec_size>
FLASHINFER_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) {
#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
if constexpr (vec_size == 1) {
dst[0] = half(src[0]);
} else {
@@ -144,13 +292,23 @@ struct vec_cast<half, __nv_fp8_e5m2> {
*(uint32_t*)&dst[i * 2] = y;
}
}
#else
if constexpr (vec_size == 1) {
dst[0] = half(src[0]);
} else if constexpr (vec_size == 2) {
dst[0] = half(src[0]);
dst[1] = half(src[1]);
} else {
static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
for (uint32_t i = 0; i < vec_size / 4; ++i) {
fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]);
}
}
#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
}
};

#endif // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)
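
// A minimal usage sketch (hypothetical, not part of this commit): each thread dequantizes
// 8 packed e4m3 values through vec_cast; on sm_90+ this takes the hardware conversion path
// above, while on sm_80/sm_89 it falls back to fast_dequant_f8f16x4.
__global__ void dequant_e4m3_to_half_example(const __nv_fp8_e4m3* __restrict__ input,
                                             half* __restrict__ output, size_t n) {
  size_t base = (size_t(blockIdx.x) * blockDim.x + threadIdx.x) * 8;
  if (base + 8 <= n) {
    vec_cast<half, __nv_fp8_e4m3>::cast<8>(output + base, input + base);
  }
}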

#ifdef FLASHINFER_ENABLE_BF16

template <>
struct vec_cast<float, nv_bfloat16> {
template <size_t vec_size>
@@ -180,7 +338,6 @@ struct vec_cast<nv_bfloat16, float> {
}
}
};
#endif // FLASHINFER_ENABLE_BF16

template <typename float_t, size_t vec_size>
struct vec_t {
@@ -230,7 +387,6 @@ FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
}
}

#ifdef FLASHINFER_ENABLE_FP8
/******************* vec_t<__nv_fp8_e4m3> *******************/

// __nv_fp8_e4m3 x 1
@@ -724,7 +880,6 @@ struct vec_t<__nv_fp8_e5m2, vec_size> {
}
}
};
#endif

/******************* vec_t<half> *******************/

@@ -889,7 +1044,6 @@ struct vec_t<half, vec_size> {
}
};

#ifdef FLASHINFER_ENABLE_BF16
/******************* vec_t<nv_bfloat16> *******************/

// nv_bfloat16 x 1
@@ -1071,8 +1225,6 @@ struct vec_t<nv_bfloat16, vec_size> {
}
};

#endif

/******************* vec_t<float> *******************/

// float x 1
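The fourth changed file, referenced by the CMakeLists.txt hunk above as src/test_fast_dequant.cu, did not render in this view. For illustration only, a correctness check for the software fallback could look roughly like the sketch below; the kernel name, launch shape, and structure are assumptions, not the contents of the actual test file.

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include <cstdio>

#include <flashinfer/vec_dtypes.cuh>

// Hypothetical sketch: exhaustively compare fast_dequant_f8f16x4 against the reference
// per-element cast for every e4m3 bit pattern (NaN encodings excluded, since the fast
// path maps them to finite values).
__global__ void CheckE4M3ToHalf(int* num_mismatch) {
  for (uint32_t base = 4 * threadIdx.x; base < 256; base += 4 * blockDim.x) {
    uint32_t packed = 0;
    for (uint32_t j = 0; j < 4; ++j) {
      packed |= (base + j) << (8 * j);  // pack 4 consecutive fp8 bit patterns
    }
    uint2 out_raw;
    flashinfer::fast_dequant_f8f16x4<__nv_fp8_e4m3, half>(&packed, &out_raw);
    const half* out = reinterpret_cast<const half*>(&out_raw);
    for (uint32_t j = 0; j < 4; ++j) {
      uint8_t bits = static_cast<uint8_t>(base + j);
      if (bits == 0x7f || bits == 0xff) continue;  // e4m3 NaN encodings
      half ref = half(*reinterpret_cast<const __nv_fp8_e4m3*>(&bits));
      if (__half2float(out[j]) != __half2float(ref)) {
        atomicAdd(num_mismatch, 1);
      }
    }
  }
}

int main() {
  int* num_mismatch;
  cudaMallocManaged(&num_mismatch, sizeof(int));
  *num_mismatch = 0;
  CheckE4M3ToHalf<<<1, 64>>>(num_mismatch);
  cudaDeviceSynchronize();
  int mismatches = *num_mismatch;
  cudaFree(num_mismatch);
  printf("e4m3 -> half fast dequant mismatches: %d\n", mismatches);
  return mismatches == 0 ? 0 : 1;
}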
