Skip to content

Commit

Permalink
perf: slight optimization on f16->f8 fragment layout swizzling (#453)
Browse files Browse the repository at this point in the history
swap after dequantize.
  • Loading branch information
yzh119 authored Aug 18, 2024
1 parent fa38b5e commit 0d61871
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,8 @@ __device__ __forceinline__ void compute_sfm_v(smem_t<swizzle_mode>* v_smem,
}
b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]);
b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]);
bfly_exch(b_frag_f8[0], b_frag_f8[1]);
vec_cast<DTypeQ, DTypeKV>::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8);
swap(b_frag[1], b_frag[2]);
} else {
v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag);
}
Expand Down
7 changes: 0 additions & 7 deletions include/flashinfer/frag_layout_swizzle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,4 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t
return x;
}

// Exchanges 16-bit halves between two 32-bit registers, in place.
//
// After the call:
//   a = { low 16 bits of old a (low half),  low 16 bits of old b (high half) }
//   b = { high 16 bits of old a (low half), high 16 bits of old b (high half) }
//
// __byte_perm(x, y, s) builds each result byte from the 8-byte pool {y, x}
// (bytes 0-3 come from x, bytes 4-7 from y), one selector nibble per byte.
__device__ __forceinline__ void bfly_exch(uint32_t& a, uint32_t& b) {
  // Selector 0x5410: bytes a0, a1, b0, b1 -> the low halves of both inputs.
  const uint32_t low_halves = __byte_perm(a, b, 0x5410);
  // Selector 0x7632: bytes a2, a3, b2, b3 -> the high halves of both inputs.
  // Both operands are still unmodified here, so b can be overwritten directly.
  b = __byte_perm(a, b, 0x7632);
  a = low_halves;
}

#endif // FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_
6 changes: 6 additions & 0 deletions include/flashinfer/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t
return (x > y) ? x - y : 0U;
}

// Exchanges the values held by two 32-bit unsigned references.
// If both references name the same object, its value is left unchanged.
__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
  const uint32_t saved_b = b;  // keep b's value before it is overwritten
  b = a;
  a = saved_b;
}

} // namespace flashinfer

#endif // FLASHINFER_UTILS_CUH_

0 comments on commit 0d61871

Please sign in to comment.