feat: append attention kernels for fp8 kv-cache (#420)
This implementation does not rely on fp8 tensor cores; it uses fp16
tensor cores instead, and the fp8 kv-cache is dequantized on-the-fly.

sm_89 and sm_90 append attention kernels that use native fp8 tensor
cores will be available in later PRs.
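
Concretely, "dequantized on-the-fly" means each fp8 kv value is widened to fp16 right after it is loaded, before any matrix math touches it, so the matmuls see only fp16 operands. Below is a minimal sketch of that idea, not this PR's actual kernel: a scalar q·k loop stands in for the fp16 tensor-core MMA, and the kernel name and per-tensor `scale` parameter are illustrative assumptions.

// Minimal sketch of on-the-fly fp8 -> fp16 dequantization (illustration only;
// not FlashInfer's kernel). q is fp16, the kv-cache entry k is fp8 (e4m3);
// each k element is widened to fp16 right after the load, so the
// multiply-accumulate runs on fp16 operands only. `scale` is an assumed
// per-tensor dequantization factor. Requires CUDA 11.8+ for <cuda_fp8.h>.
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cstdio>

__global__ void qk_scores_fp8_kv(const __half* __restrict__ q,        // [head_dim]
                                 const __nv_fp8_e4m3* __restrict__ k, // [kv_len, head_dim]
                                 float* __restrict__ scores,          // [kv_len]
                                 float scale, int kv_len, int head_dim) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= kv_len) return;
  float acc = 0.0f;
  for (int d = 0; d < head_dim; ++d) {
    // On-the-fly dequantization: fp8 -> float -> fp16, then fp16 math.
    __half kd = __float2half(float(k[i * head_dim + d]) * scale);
    acc += __half2float(__hmul(q[d], kd));
  }
  scores[i] = acc;
}

int main() {
  const int kv_len = 4, head_dim = 8;
  __half hq[head_dim];
  __nv_fp8_e4m3 hk[kv_len * head_dim];
  for (int d = 0; d < head_dim; ++d) hq[d] = __float2half(0.5f);
  for (int i = 0; i < kv_len * head_dim; ++i) hk[i] = __nv_fp8_e4m3(0.25f);

  __half* dq; __nv_fp8_e4m3* dk; float* ds;
  cudaMalloc(&dq, sizeof(hq));
  cudaMalloc(&dk, sizeof(hk));
  cudaMalloc(&ds, kv_len * sizeof(float));
  cudaMemcpy(dq, hq, sizeof(hq), cudaMemcpyHostToDevice);
  cudaMemcpy(dk, hk, sizeof(hk), cudaMemcpyHostToDevice);

  qk_scores_fp8_kv<<<1, 32>>>(dq, dk, ds, /*scale=*/1.0f, kv_len, head_dim);
  float hs[kv_len];
  cudaMemcpy(hs, ds, sizeof(hs), cudaMemcpyDeviceToHost);
  for (int i = 0; i < kv_len; ++i) printf("score[%d] = %f\n", i, hs[i]);  // 0.5*0.25*8 = 1.0
  cudaFree(dq); cudaFree(dk); cudaFree(ds);
  return 0;
}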

---------

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
yzh119 and comaniac authored Aug 6, 2024
1 parent b781513 commit 906c2f5
Showing 24 changed files with 1,670 additions and 822 deletions.
59 changes: 48 additions & 11 deletions CMakeLists.txt
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
 if(FLASHINFER_ENABLE_FP8)
 list(APPEND DECODE_DTYPES "e4m3" "e5m2")
 list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
+list(APPEND PREFILL_FP8_DTYPES "e4m3" "e5m2")
 endif(FLASHINFER_ENABLE_FP8)
 
 if(FLASHINFER_ENABLE_BF16)
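
For context on the two fp8 flavors appended above: e4m3 spends more bits on mantissa (finer precision, max finite value 448), while e5m2 spends more on exponent (coarser precision, max finite value 57344). A small host-side demo of the round-trip quantization error, an illustration rather than part of this PR, assuming CUDA 11.8+ for <cuda_fp8.h>:

// Round-trip a few floats through each fp8 format to see the tradeoff.
// Compile with: nvcc -o fp8_demo fp8_demo.cu
#include <cuda_fp8.h>
#include <cstdio>

int main() {
  const float samples[] = {0.1234f, 1.5f, 300.0f, 1000.0f};
  for (float x : samples) {
    float via_e4m3 = float(__nv_fp8_e4m3(x));  // saturates to +/-448
    float via_e5m2 = float(__nv_fp8_e5m2(x));  // saturates to +/-57344
    printf("x = %9.4f   e4m3 -> %9.4f   e5m2 -> %9.4f\n", x, via_e4m3, via_e5m2);
  }
  return 0;
}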
@@ -194,7 +195,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
 foreach(mask_mode IN LISTS MASK_MODES)
 foreach(dtype IN LISTS PREFILL_DTYPES)
-set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu)
+set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
 add_custom_command(
 OUTPUT ${generated_kernel_src}
 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -204,6 +205,18 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 )
 list(APPEND single_prefill_kernels_src ${generated_kernel_src})
 endforeach(dtype)
+
+foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
+add_custom_command(
+OUTPUT ${generated_kernel_src}
+COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
+DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py
+COMMENT "Generating additional source file ${generated_kernel_src}"
+VERBATIM
+)
+list(APPEND single_prefill_kernels_src ${generated_kernel_src})
+endforeach(dtype_kv)
 endforeach(mask_mode)
 endforeach(allow_fp16_qk_reduction)
 endforeach(pos_encoding_mode)
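
Each generated filename above encodes one full template configuration (head_dim, logits hook, positional encoding mode, fp16 qk reduction, mask mode, and now independent dtype_q / dtype_kv / dtype_out), and each generated .cu pins exactly one combination so the heavy template code compiles in parallel across translation units. A purely hypothetical sketch of what such a single-instantiation unit looks like — the template name, parameters, and signature below are stand-ins, not FlashInfer's actual API:

// Hypothetical mock of a generated instantiation unit; everything here is a
// stand-in for illustration, not FlashInfer's real template or signature.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Stand-in for the real prefill kernel template, whose body would live in a
// shared header included by every generated file.
template <int HeadDim, bool Causal, typename DTypeQ, typename DTypeKV,
          typename DTypeOut>
__global__ void single_prefill_mock(const DTypeQ* q, const DTypeKV* k,
                                    const DTypeKV* v, DTypeOut* o) {
  // ... kernel body elided; the point is the explicit instantiation below.
}

// The one line each generated .cu exists for: pin a single combination,
// here head_dim=128, causal mask, q/out in fp16, kv-cache in fp8 e4m3.
template __global__ void single_prefill_mock<128, true, half, __nv_fp8_e4m3, half>(
    const half*, const __nv_fp8_e4m3*, const __nv_fp8_e4m3*, half*);

The batch paged and batch ragged hunks below follow the same pattern, adding an fp8 dtype_kv loop next to the existing uniform-dtype loop.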
@@ -216,9 +229,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
 foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
 foreach(mask_mode IN LISTS MASK_MODES)
-foreach(dtype IN LISTS PREFILL_DTYPES)
-foreach(idtype IN LISTS IDTYPES)
-set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
+foreach(idtype IN LISTS IDTYPES)
+foreach(dtype IN LISTS PREFILL_DTYPES)
+set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
 add_custom_command(
 OUTPUT ${generated_kernel_src}
 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -227,8 +240,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 VERBATIM
 )
 list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
-endforeach(idtype)
-endforeach(dtype)
+endforeach(dtype)
+
+foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
+add_custom_command(
+OUTPUT ${generated_kernel_src}
+COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
+DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
+COMMENT "Generating additional source file ${generated_kernel_src}"
+VERBATIM
+)
+list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
+endforeach(dtype_kv)
+endforeach(idtype)
 endforeach(mask_mode)
 endforeach(allow_fp16_qk_reduction)
 endforeach(pos_encoding_mode)
@@ -241,9 +266,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
 foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
 foreach(mask_mode IN LISTS MASK_MODES)
-foreach(dtype IN LISTS PREFILL_DTYPES)
-foreach(idtype IN LISTS IDTYPES)
-set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
+foreach(idtype IN LISTS IDTYPES)
+foreach(dtype IN LISTS PREFILL_DTYPES)
+set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
 add_custom_command(
 OUTPUT ${generated_kernel_src}
 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
@@ -252,8 +277,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 VERBATIM
 )
 list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
-endforeach(idtype)
-endforeach(dtype)
+endforeach(dtype)
+
+foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
+add_custom_command(
+OUTPUT ${generated_kernel_src}
+COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
+DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py
+COMMENT "Generating additional source file ${generated_kernel_src}"
+VERBATIM
+)
+list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
+endforeach(dtype_kv)
+endforeach(idtype)
 endforeach(mask_mode)
 endforeach(allow_fp16_qk_reduction)
 endforeach(pos_encoding_mode)
