
bugfix: fix dispatch fp16 type when enable fp8 (#430)
esmeetu authored Aug 8, 2024
1 parent d52f2da commit daa5566
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/csrc/pytorch_extension_utils.h
@@ -146,6 +146,10 @@ using namespace flashinfer;
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                   \
     switch (pytorch_dtype) {                                        \
+      case at::ScalarType::Half: {                                  \
+        using c_type = nv_half;                                     \
+        return __VA_ARGS__();                                       \
+      }                                                             \
       case at::ScalarType::Float8_e4m3fn: {                         \
         using c_type = __nv_fp8_e4m3;                               \
         return __VA_ARGS__();                                       \
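For context, DISPATCH_PYTORCH_DTYPE_TO_CTYPE expands the trailing callable inside a switch over the PyTorch scalar type, aliasing c_type to the matching CUDA element type in each case; the commit adds the at::ScalarType::Half case so float16 inputs still dispatch to nv_half when FP8 support is enabled. Below is a minimal, self-contained sketch of the same pattern; the ScalarType enum, the fake_half/fake_fp8_e4m3 stand-in types, and the DISPATCH_DTYPE_TO_CTYPE name are placeholders chosen so the example compiles without PyTorch or CUDA headers, not the project's real API.

// Simplified stand-in for the dispatch pattern in pytorch_extension_utils.h.
// ScalarType, fake_half, and fake_fp8_e4m3 are placeholders so this compiles
// without PyTorch or CUDA; the real macro switches on at::ScalarType and
// aliases c_type to nv_half / __nv_fp8_e4m3.
#include <cstdint>
#include <cstdio>

enum class ScalarType { Half, Float8_e4m3fn };
struct fake_half { uint16_t bits; };     // stand-in for nv_half (2 bytes)
struct fake_fp8_e4m3 { uint8_t bits; };  // stand-in for __nv_fp8_e4m3 (1 byte)

#define DISPATCH_DTYPE_TO_CTYPE(dtype, c_type, ...) \
  [&]() -> bool {                                   \
    switch (dtype) {                                \
      case ScalarType::Half: {                      \
        using c_type = fake_half;                   \
        return __VA_ARGS__();                       \
      }                                             \
      case ScalarType::Float8_e4m3fn: {             \
        using c_type = fake_fp8_e4m3;               \
        return __VA_ARGS__();                       \
      }                                             \
      default:                                      \
        return false;                               \
    }                                               \
  }()

int main() {
  ScalarType dtype = ScalarType::Half;
  // The trailing lambda is __VA_ARGS__; inside it, c_type is the element type
  // selected by the matching case above (fake_half for Half).
  bool dispatched = DISPATCH_DTYPE_TO_CTYPE(dtype, c_type, [&] {
    std::printf("element size = %zu bytes\n", sizeof(c_type));
    return true;
  });
  return dispatched ? 0 : 1;
}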
