Skip to content

Commit

Permalink
bugfix: Fix invalid kernel configuration for sm86 (#385)
Browse files Browse the repository at this point in the history
Related issue: vllm-project/vllm#6395
  • Loading branch information
yzh119 authored Jul 20, 2024
1 parent 457a0ae commit cdac577
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1804,7 +1804,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
// we expect each sm execute two threadblocks
- const int max_smem_per_threadblock = max_smem_per_sm / 2;
+ const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
+ const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

constexpr uint32_t num_warps_x = get_num_warps_x<WARP_LAYOUT>();
constexpr uint32_t num_warps_z = get_num_warps_z<WARP_LAYOUT>();
Expand Down Expand Up @@ -1949,7 +1950,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
// we expect each sm execute two threadblocks
- const int max_smem_per_threadblock = max_smem_per_sm / 2;
+ const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
+ const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

const uint32_t max_num_frags_z_reg =
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
Expand Down Expand Up @@ -2089,7 +2091,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
// we expect each sm execute two threadblocks
- const int max_smem_per_threadblock = max_smem_per_sm / 2;
+ const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
+ const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

const uint32_t max_num_frags_z_reg =
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
Expand Down

0 comments on commit cdac577

Please sign in to comment.