diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 0a9e501a..006d9753 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -802,8 +802,8 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - bfly_exch(b_frag_f8[0], b_frag_f8[1]); vec_cast::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); + swap(b_frag[1], b_frag[2]); } else { v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index ab62498c..3dbfdb9d 100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -39,11 +39,4 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t return x; } -__device__ __forceinline__ void bfly_exch(uint32_t& a, uint32_t& b) { - uint32_t tmp = __byte_perm(a, b, 0x5410); - uint32_t tmp2 = __byte_perm(a, b, 0x7632); - a = tmp; - b = tmp2; -} - #endif // FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index c604fc2e..c887cd44 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -253,6 +253,12 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t return (x > y) ? x - y : 0U; } +__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) { + uint32_t tmp = a; + a = b; + b = tmp; +} + } // namespace flashinfer #endif // FLASHINFER_UTILS_CUH_