Skip to content

Commit

Permalink
bugfix: fix the alignment of o_frag (#608)
Browse files Browse the repository at this point in the history
Since `o_frag` was not always aligned to a 16-byte boundary, the `memcpy`
implemented using 4x float moves crashed under `cuda-gdb` when the code
was compiled with `-G`.
  • Loading branch information
nandor authored Nov 13, 2024
1 parent 45e9273 commit 32d9510
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV
constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b<DTypeO>();

DTypeQKAccum s_frag[NUM_FRAGS_Q][NUM_FRAGS_KV][8];
float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8];
alignas(16) float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8];
DTypeQKAccum m[NUM_FRAGS_Q][2];
float d[NUM_FRAGS_Q][2];
float rope_freq[NUM_FRAGS_D / 2][4];
Expand Down Expand Up @@ -1579,7 +1579,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag
constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b<DTypeO>();

DTypeQKAccum s_frag[NUM_FRAGS_Q][NUM_FRAGS_KV][8];
float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8];
alignas(16) float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8];
DTypeQKAccum m[NUM_FRAGS_Q][2];
float d[NUM_FRAGS_Q][2];
float rope_freq[NUM_FRAGS_D / 2][4];
Expand Down Expand Up @@ -1866,7 +1866,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b<DTypeO>();

DTypeQKAccum s_frag[NUM_FRAGS_Q][NUM_FRAGS_KV][8];
float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8];
alignas(16) float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8];
DTypeQKAccum m[NUM_FRAGS_Q][2];
float d[NUM_FRAGS_Q][2];
float rope_freq[NUM_FRAGS_D / 2][4];
Expand Down

0 comments on commit 32d9510

Please sign in to comment.