From daa556697fed849810745f0aae0015d8e4460050 Mon Sep 17 00:00:00 2001
From: Roy
Date: Fri, 9 Aug 2024 01:36:37 +0800
Subject: [PATCH] bugfix: fix dispatch fp16 type when enable fp8 (#430)

Fix https://github.com/flashinfer-ai/flashinfer/issues/402#issuecomment-2254175747
---
 python/csrc/pytorch_extension_utils.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h
index 72afcd28..a702c8ee 100644
--- a/python/csrc/pytorch_extension_utils.h
+++ b/python/csrc/pytorch_extension_utils.h
@@ -146,6 +146,10 @@ using namespace flashinfer;
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                   \
     switch (pytorch_dtype) {                                        \
+      case at::ScalarType::Half: {                                  \
+        using c_type = nv_half;                                     \
+        return __VA_ARGS__();                                       \
+      }                                                             \
       case at::ScalarType::Float8_e4m3fn: {                         \
         using c_type = __nv_fp8_e4m3;                               \
         return __VA_ARGS__();                                       \
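
For context, below is a minimal sketch of how a caller typically invokes this dispatch macro. The function RunOnTensor is hypothetical, and the macro's default branch (returning false for unsupported dtypes) is assumed from the rest of the header, which this hunk does not show.

// Hypothetical caller, not part of this patch. The macro expands to an
// immediately-invoked lambda that switches on the runtime dtype and binds
// the compile-time alias `c_type` before calling the trailing lambda.
#include <cuda_fp16.h>  // nv_half
#include <cuda_fp8.h>   // __nv_fp8_e4m3
#include <torch/extension.h>

bool RunOnTensor(at::Tensor x) {
  // Before this patch, a torch.float16 tensor fell through to the (assumed)
  // default branch and dispatch returned false; the new at::ScalarType::Half
  // case binds c_type = nv_half so fp16 dispatches correctly alongside fp8.
  return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] {
    c_type* data = static_cast<c_type*>(x.data_ptr());
    (void)data;  // e.g. launch a CUDA kernel templated on c_type here
    return true;
  });
}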