diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 32c244f7..fdd33801 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -72,7 +72,13 @@ } #define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \ - if (max_frags_z >= 4) { \ + if (max_frags_z >= 8) { \ + constexpr size_t NUM_FRAGS_Z = 8; \ + __VA_ARGS__ \ + } else if (max_frags_z >= 6) { \ + constexpr size_t NUM_FRAGS_Z = 6; \ + __VA_ARGS__ \ + } else if (max_frags_z >= 4) { \ constexpr size_t NUM_FRAGS_Z = 4; \ __VA_ARGS__ \ } else if (max_frags_z >= 2) { \