Skip to content

Commit

Permalink
perf: slight optimization on fragment layout swizzle (#458)
Browse files Browse the repository at this point in the history
Fuse two `__byte_perm` calls into one.
  • Loading branch information
yzh119 authored Aug 21, 2024
1 parent 85b4c77 commit 7c397cb
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
3 changes: 1 addition & 2 deletions include/flashinfer/frag_layout_swizzle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) {
}

__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) {
x = __byte_perm(x, x, 0x3120);
uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x5410 : 0x3276);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175);
tmp = __shfl_xor_sync(0xffffffff, x, 0x8);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276);
tmp = __shfl_xor_sync(0xffffffff, x, 0x10);
Expand Down
1 change: 0 additions & 1 deletion include/flashinfer/vec_dtypes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
q = __byte_perm(q, q, 0x1302);

// Extract and shift FP8 values to FP16 format
Expand Down

0 comments on commit 7c397cb

Please sign in to comment.