Skip to content

Commit 60aabba

Browse files
dsikka, Faraz9877, ilmarkov, rahul-tuli, robertgshaw2-redhat
authored and committed
[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (vllm-project#10995)
Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com> Co-authored-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Rahul Tuli <rahul@neuralmagic.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
1 parent 791e1d6 commit 60aabba

30 files changed

+2365
-117
lines changed

CMakeLists.txt

+16-10
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
206206
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
207207

208208
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
209-
set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
209+
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
210210

211211
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
212212
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
223223
FetchContent_Declare(
224224
cutlass
225225
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
226-
GIT_TAG v3.5.1
226+
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
227227
GIT_PROGRESS TRUE
228228

229229
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
230230
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
231231
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
232-
GIT_SHALLOW TRUE
232+
GIT_SHALLOW FALSE
233233
)
234234
endif()
235235
FetchContent_MakeAvailable(cutlass)
@@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
241241
"csrc/quantization/awq/gemm_kernels.cu"
242242
"csrc/custom_all_reduce.cu"
243243
"csrc/permute_cols.cu"
244-
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
244+
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
245+
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
246+
"csrc/sparse/cutlass/sparse_compressor_entry.cu"
247+
"csrc/cutlass_extensions/common.cpp")
245248

246249
set_gencode_flags_for_srcs(
247250
SRCS "${VLLM_EXT_SRC}"
@@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
271274
endif()
272275

273276
#
274-
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
277+
# The cutlass_scaled_mm, cutlass_scaled_sparse_mm, and cutlass_compressor kernels
278+
# for Hopper (c3x, i.e. CUTLASS 3.x) require
275279
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
276280
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
277281
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
278-
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
282+
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
283+
"csrc/sparse/cutlass/sparse_compressor_c3x.cu"
284+
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
279285
set_gencode_flags_for_srcs(
280286
SRCS "${SRCS}"
281287
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
284290
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
285291
else()
286292
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
287-
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
293+
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
288294
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
289-
"later if you intend on running FP8 quantized models on "
295+
"later if you intend on running FP8 sparse or quantized models on "
290296
"Hopper.")
291297
else()
292-
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
298+
message(STATUS "Not building cutlass_c3x as no compatible archs found "
293299
"in CUDA target architectures")
294300
endif()
295301

@@ -404,7 +410,7 @@ define_gpu_extension_target(
404410
SOURCES ${VLLM_EXT_SRC}
405411
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
406412
ARCHITECTURES ${VLLM_GPU_ARCHES}
407-
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
413+
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
408414
USE_SABI 3
409415
WITH_SOABI)
410416

0 commit comments

Comments
 (0)