From fbf5588abc773fcfd63e2bb2647a9a1ae67a4cd8 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 19 May 2023 12:59:37 +0200 Subject: [PATCH 1/8] xor hack --- ggml-cuda.cu | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 35d2e457cbf84..1a64ff6b1a8d5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -207,8 +207,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, const int y_offset = qr == 1 ? 1 : qk/2; - __shared__ float tmp[block_size]; // separate sum for each thread - tmp[tid] = 0; + + float tmp = 0; // partial sum for thread in warp for (int i = 0; i < ncols/block_size; i += 2) { const int col = i*block_size + 2*tid; @@ -221,20 +221,30 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, dequantize_kernel(vx, ib, iqs, v0, v1); // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; + tmp += v0 * y[iybs + iqs + 0]; + tmp += v1 * y[iybs + iqs + y_offset]; } // sum up partial sums and write back result __syncthreads(); +#ifdef GGML_USE_HIPBLAS + __shared__ float tmpa[block_size]; + tmpa[tid] = tmp; for (int s=block_size/2; s>0; s>>=1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + tmpa[tid] += tmpa[tid + s]; } __syncthreads(); } + tmp = tmpa[0]; // now full sum +#else + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } +#endif + if (tid == 0) { - dst[row] = tmp[0]; + dst[row] = tmp; } } From 1a787101ccf0177f3dbd1c418ee87b20554c0a1a Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 19 May 2023 17:24:05 +0200 Subject: [PATCH 2/8] block y dim --- CMakeLists.txt | 2 ++ Makefile | 7 ++++++- ggml-cuda.cu | 42 ++++++++++++++++++++++-------------------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39db2e3fc5c23..6c43c5ec5e827 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,7 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework" option(LLAMA_BLAS "llama: use BLAS" OFF) option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_DMMV_BLOCK_Y "llama: y block size for dmmv CUDA kernels" 1) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) @@ -184,6 +185,7 @@ if (LLAMA_CUBLAS) set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) + add_compile_definitions(CUDA_DMMV_BLOCK_Y=${LLAMA_DMMV_BLOCK_Y}) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) diff --git a/Makefile b/Makefile index 08e2503146018..3e46b2b5cf4f7 100644 --- a/Makefile +++ b/Makefile @@ -133,9 +133,14 @@ ifdef LLAMA_CUBLAS OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native +ifdef LLAMA_DMMV_BLOCK_Y + NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=$(LLAMA_DMMV_BLOCK_Y) +else + NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=1 +endif # LLAMA_DMMV_BLOCK_Y ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ -endif +endif # LLAMA_CUBLAS ifdef LLAMA_CLBLAST CFLAGS += -DGGML_USE_CLBLAST CXXFLAGS += -DGGML_USE_CLBLAST diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1a64ff6b1a8d5..c9366705dd93e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -85,7 +85,7 @@ 
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+#define CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -202,7 +202,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
 
 template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    const int row = blockIdx.x;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const int y_offset = qr == 1 ? 1 : qk/2;
@@ -279,33 +279,35 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(CUDA_DMMV_BLOCK_X, CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -314,9 +316,9 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 1, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, 1, 1, convert_f16>
+        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {

From 82cf01f8979379c7b5027f892ced08f95f289bc3 Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Fri, 19 May 2023 20:26:49 +0200
Subject: [PATCH 3/8] loop unrolling
---
 CMakeLists.txt |  6 ++--
 Makefile       | 11 ++++---
 ggml-cuda.cu   | 84 ++++++++++++++++++++++++++++++++++++--------------
 3 files changed, 72 insertions(+), 29 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6c43c5ec5e827..b9936268e95cc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,7 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
 option(LLAMA_BLAS "llama: use BLAS" OFF)
 option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
 option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
-option(LLAMA_DMMV_BLOCK_Y "llama: y block size for dmmv CUDA kernels" 1)
+option(LLAMA_CUDA_BY "llama: y block size for dmmv CUDA kernels" 1)
+option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@@ -185,7 +186,8 @@ if (LLAMA_CUBLAS)
     set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
 
     add_compile_definitions(GGML_USE_CUBLAS)
-    add_compile_definitions(CUDA_DMMV_BLOCK_Y=${LLAMA_DMMV_BLOCK_Y})
+    add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
+    add_compile_definitions(GGML_CUDA_UNROLL=${LLAMA_CUDA_UNROLL})
 
     if (LLAMA_STATIC)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/Makefile b/Makefile
index 3e46b2b5cf4f7..7008689646fa8 100644
--- a/Makefile
+++ b/Makefile
@@ -133,11 +133,14 @@ ifdef LLAMA_CUBLAS
	OBJS      += ggml-cuda.o
	NVCC      = nvcc
	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
-ifdef LLAMA_DMMV_BLOCK_Y
-	NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=$(LLAMA_DMMV_BLOCK_Y)
+ifdef LLAMA_CUDA_BY
+	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
 else
-	NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=1
-endif # LLAMA_DMMV_BLOCK_Y
+	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
+endif # LLAMA_CUDA_BY
+ifdef LLAMA_CUDA_UNROLL
+	NVCCFLAGS += -DGGML_CUDA_UNROLL=$(LLAMA_CUDA_UNROLL)
+endif # LLAMA_CUDA_UNROLL
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c9366705dd93e..9fb5f50feaaa5 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -85,7 +85,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
+#define GGML_CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -200,8 +200,8 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
@@ -210,6 +210,9 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
 
     float tmp = 0; // partial sum for thread in warp
 
+#ifdef GGML_CUDA_UNROLL
+#pragma unroll
+#endif
     for (int i = 0; i < ncols/block_size; i += 2) {
         const int col = i*block_size + 2*tid;
         const int ib = (row*ncols + col)/qk; // block index
@@ -238,6 +241,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
 }
     tmp = tmpa[0]; // now full sum
 #else
+#pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
@@ -278,36 +282,72 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
+                                        const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
+
+    // Use a switch statement for ncols so the compiler can unroll all loops:
+    switch (ncols) {
+        case 4096:
+            dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 5120:
+            dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 6656:
+            dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 8192:
+            dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 11008:
+            dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 13824:
+            dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 17920:
+            dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        case 22016:
+            dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
+                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
+            break;
+        default:
+            fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(CUDA_DMMV_BLOCK_X, CUDA_DMMV_BLOCK_Y, 1);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows/CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+    dequantize_mul_mat_vec_cuda<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, stream);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
+    dequantize_mul_mat_vec_cuda<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, stream);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
+    dequantize_mul_mat_vec_cuda<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, stream);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
+    dequantize_mul_mat_vec_cuda<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, stream);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
+    dequantize_mul_mat_vec_cuda<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, stream);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -316,9 +356,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, 1, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
+    dequantize_mul_mat_vec_cuda<1, 1, convert_f16>(vx, y, dst, ncols, nrows, stream);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {

From 17dc4c52d384d3239f0f2d099ff9ac93ef26cfdf Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Sat, 20 May 2023 14:31:11 +0200
Subject: [PATCH 4/8] Fixed cmake LLAMA_CUDA_BY option
---
 CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b9936268e95cc..0dff0005d3638 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,7 +68,7 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
 option(LLAMA_BLAS "llama: use BLAS" OFF)
 option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
 option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
-option(LLAMA_CUDA_BY "llama: y block size for dmmv CUDA kernels" 1)
+set(LLAMA_CUDA_BY "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
 option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 

From 5d0cf9928b86e869cd4997586e27f3fb45fb43d5 Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Sat, 20 May 2023 14:36:26 +0200
Subject: [PATCH 5/8] Removed hipblas compatibility code
---
 ggml-cuda.cu | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 9fb5f50feaaa5..44e6445ecb50a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -230,22 +230,10 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
 
     // sum up partial sums and write back result
     __syncthreads();
-#ifdef GGML_USE_HIPBLAS
-    __shared__ float tmpa[block_size];
-    tmpa[tid] = tmp;
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmpa[tid] += tmpa[tid + s];
-        }
-        __syncthreads();
-    }
-    tmp = tmpa[0]; // now full sum
-#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
-#endif
 
     if (tid == 0) {
         dst[row] = tmp;

From e199938a3a6e0c9515d6cd63b3161926483e84ed Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Sat, 20 May 2023 19:37:11 +0200
Subject: [PATCH 6/8] Define GGML_CUDA_DMMV_BLOCK_Y if not defined
---
 ggml-cuda.cu | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 44e6445ecb50a..afab6704ffec4 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -85,7 +85,11 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define GGML_CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
+// dmmv = dequantize_mul_mat_vec
+#define GGML_CUDA_DMMV_BLOCK_X 32
+#ifndef GGML_CUDA_DMMV_BLOCK_Y
+#define GGML_CUDA_DMMV_BLOCK_Y 1 // can be set by compiler option LLAMA_CUDA_BY
+#endif
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
From 98bfee013b18189a3335d6cf392fb69bbcd4e4ee Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Sun, 21 May 2023 12:01:14 +0200
Subject: [PATCH 7/8] Fewer iters, more ops per iter
---
 CMakeLists.txt |   4 +-
 Makefile       |   8 +--
 ggml-cuda.cu   | 132 +++++++++++++++++++++++--------------------------
 3 files changed, 68 insertions(+), 76 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0dff0005d3638..6906a849a9412 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,8 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
 option(LLAMA_BLAS "llama: use BLAS" OFF)
 option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
 option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
+set(LLAMA_CUDA_BX "32" CACHE STRING "llama: x block size for dmmv CUDA kernels")
 set(LLAMA_CUDA_BY "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
-option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@@ -186,8 +186,8 @@ if (LLAMA_CUBLAS)
     set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
 
     add_compile_definitions(GGML_USE_CUBLAS)
+    add_compile_definitions(GGML_CUDA_DMMV_BLOCK_X=${LLAMA_CUDA_BX})
     add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
-    add_compile_definitions(GGML_CUDA_UNROLL=${LLAMA_CUDA_UNROLL})
 
     if (LLAMA_STATIC)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/Makefile b/Makefile
index 7008689646fa8..56d67977b109d 100644
--- a/Makefile
+++ b/Makefile
@@ -133,14 +133,16 @@ ifdef LLAMA_CUBLAS
	OBJS      += ggml-cuda.o
	NVCC      = nvcc
	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
+ifdef LLAMA_CUDA_BX
+	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=$(LLAMA_CUDA_BX)
+else
+	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=32
+endif # LLAMA_CUDA_BX
 ifdef LLAMA_CUDA_BY
	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
 else
	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
 endif # LLAMA_CUDA_BY
-ifdef LLAMA_CUDA_UNROLL
-	NVCCFLAGS += -DGGML_CUDA_UNROLL=$(LLAMA_CUDA_UNROLL)
-endif # LLAMA_CUDA_UNROLL
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index afab6704ffec4..6b9176697d24a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -83,10 +83,16 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
+
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
 // dmmv = dequantize_mul_mat_vec
-#define GGML_CUDA_DMMV_BLOCK_X 32
+#ifndef GGML_CUDA_DMMV_BLOCK_X
+#define GGML_CUDA_DMMV_BLOCK_X 32 // can be set by compiler option LLAMA_CUDA_BX
+#endif
 #ifndef GGML_CUDA_DMMV_BLOCK_Y
 #define GGML_CUDA_DMMV_BLOCK_Y 1 // can be set by compiler option LLAMA_CUDA_BY
 #endif
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -204,32 +210,40 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
+    const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     float tmp = 0; // partial sum for thread in warp
 
-#ifdef GGML_CUDA_UNROLL
-#pragma unroll
-#endif
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
-
-        // matrix multiplication
-        tmp += v0 * y[iybs + iqs + 0];
-        tmp += v1 * y[iybs + iqs + y_offset];
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
 
     // sum up partial sums and write back result
@@ -274,72 +288,44 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
-                                        const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
     GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
-
-    // Use a switch statement for ncols so the compiler can unroll all loops:
-    switch (ncols) {
-        case 4096:
-            dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 5120:
-            dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 6656:
-            dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 8192:
-            dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 11008:
-            dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 13824:
-            dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 17920:
-            dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 22016:
-            dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        default:
-            fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
-            GGML_ASSERT(false);
-            break;
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -348,7 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<1, 1, convert_f16>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {

From d45df1b1f4d49318a65ace6a3cfd181b4436bbe9 Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Tue, 23 May 2023 08:55:01 +0200
Subject: [PATCH 8/8] Renamed DMMV X/Y compilation options
---
 CMakeLists.txt | 58 +++++++++++++++++++++++++-------------------------
 Makefile       | 16 +++++++-------
 ggml-cuda.cu   | 58 +++++++++++++++++++++++++-------------------------
 3 files changed, 66 insertions(+), 66 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6906a849a9412..31c5bd91d196c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,44 +37,44 @@ endif()
 
 #
 # general
-option(LLAMA_STATIC              "llama: static link libraries"          OFF)
-option(LLAMA_NATIVE              "llama: enable -march=native flag"      OFF)
-option(LLAMA_LTO                 "llama: enable link time optimization"  OFF)
+option(LLAMA_STATIC                     "llama: static link libraries"                          OFF)
+option(LLAMA_NATIVE                     "llama: enable -march=native flag"                      OFF)
+option(LLAMA_LTO                        "llama: enable link time optimization"                  OFF)
 
 # debug
-option(LLAMA_ALL_WARNINGS           "llama: enable all compiler warnings"                   ON)
-option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
-option(LLAMA_GPROF                  "llama: enable gprof"                                   OFF)
+option(LLAMA_ALL_WARNINGS               "llama: enable all compiler warnings"                   ON)
+option(LLAMA_ALL_WARNINGS_3RD_PARTY     "llama: enable all compiler warnings in 3rd party libs" OFF)
+option(LLAMA_GPROF                      "llama: enable gprof"                                   OFF)
 
 # sanitizers
-option(LLAMA_SANITIZE_THREAD        "llama: enable thread sanitizer"    OFF)
-option(LLAMA_SANITIZE_ADDRESS       "llama: enable address sanitizer"   OFF)
-option(LLAMA_SANITIZE_UNDEFINED     "llama: enable undefined sanitizer" OFF)
+option(LLAMA_SANITIZE_THREAD            "llama: enable thread sanitizer"                        OFF)
+option(LLAMA_SANITIZE_ADDRESS           "llama: enable address sanitizer"                       OFF)
+option(LLAMA_SANITIZE_UNDEFINED         "llama: enable undefined sanitizer"                     OFF)
 
 # instruction set specific
-option(LLAMA_AVX                    "llama: enable AVX"                 ON)
-option(LLAMA_AVX2                   "llama: enable AVX2"                ON)
-option(LLAMA_AVX512                 "llama: enable AVX512"              OFF)
-option(LLAMA_AVX512_VBMI            "llama: enable AVX512-VBMI"         OFF)
-option(LLAMA_AVX512_VNNI            "llama: enable AVX512-VNNI"         OFF)
-option(LLAMA_FMA                    "llama: enable FMA"                 ON)
+option(LLAMA_AVX                        "llama: enable AVX"                                     ON)
+option(LLAMA_AVX2                       "llama: enable AVX2"                                    ON)
+option(LLAMA_AVX512                     "llama: enable AVX512"                                  OFF)
+option(LLAMA_AVX512_VBMI                "llama: enable AVX512-VBMI"                             OFF)
+option(LLAMA_AVX512_VNNI                "llama: enable AVX512-VNNI"                             OFF)
+option(LLAMA_FMA                        "llama: enable FMA"                                     ON)
 # in MSVC F16C is implied with AVX2/AVX512
 if (NOT MSVC)
-    option(LLAMA_F16C               "llama: enable F16C"                ON)
+    option(LLAMA_F16C                   "llama: enable F16C"                                    ON)
 endif()
 
 # 3rd party libs
-option(LLAMA_ACCELERATE             "llama: enable Accelerate framework" ON)
-option(LLAMA_BLAS                   "llama: use BLAS"                    OFF)
-option(LLAMA_BLAS_VENDOR            "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
-option(LLAMA_CUBLAS                 "llama: use cuBLAS"                  OFF)
-set(LLAMA_CUDA_BX "32" CACHE STRING "llama: x block size for dmmv CUDA kernels")
-set(LLAMA_CUDA_BY "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
-option(LLAMA_CLBLAST                "llama: use CLBlast"                 OFF)
-
-option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_SERVER           "llama: build server example" OFF)
+option(LLAMA_ACCELERATE                 "llama: enable Accelerate framework"                    ON)
+option(LLAMA_BLAS                       "llama: use BLAS"                                       OFF)
+option(LLAMA_BLAS_VENDOR                "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
+option(LLAMA_CUBLAS                     "llama: use cuBLAS"                                     OFF)
+set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
+set(LLAMA_CUDA_DMMV_Y  "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
+option(LLAMA_CLBLAST                    "llama: use CLBlast"                                    OFF)
+
+option(LLAMA_BUILD_TESTS                "llama: build tests"                    ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_EXAMPLES             "llama: build examples"                 ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_SERVER               "llama: build server example"                           OFF)
 
 #
 # Build info header
@@ -186,8 +186,8 @@ if (LLAMA_CUBLAS)
     set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
 
     add_compile_definitions(GGML_USE_CUBLAS)
-    add_compile_definitions(GGML_CUDA_DMMV_BLOCK_X=${LLAMA_CUDA_BX})
-    add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
+    add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+    add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
 
     if (LLAMA_STATIC)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/Makefile b/Makefile
index 56d67977b109d..804307b531703 100644
--- a/Makefile
+++ b/Makefile
@@ -133,16 +133,16 @@ ifdef LLAMA_CUBLAS
	OBJS      += ggml-cuda.o
	NVCC      = nvcc
	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
-ifdef LLAMA_CUDA_BX
-	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=$(LLAMA_CUDA_BX)
+ifdef LLAMA_CUDA_DMMV_X
+	NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
 else
-	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=32
-endif # LLAMA_CUDA_BX
+	NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
+endif # LLAMA_CUDA_DMMV_X
-ifdef LLAMA_CUDA_BY
-	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
+ifdef LLAMA_CUDA_DMMV_Y
+	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
 else
-	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
-endif # LLAMA_CUDA_BY
+	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
+endif # LLAMA_CUDA_DMMV_Y
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 6b9176697d24a..98170a3ae17de 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -90,11 +90,11 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
-#ifndef GGML_CUDA_DMMV_BLOCK_X
-#define GGML_CUDA_DMMV_BLOCK_X 32 // can be set by compiler option LLAMA_CUDA_BX
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
 #endif
-#ifndef GGML_CUDA_DMMV_BLOCK_Y
-#define GGML_CUDA_DMMV_BLOCK_Y 1 // can be set by compiler option LLAMA_CUDA_BY
+#ifndef GGML_CUDA_DMMV_Y
+#define GGML_CUDA_DMMV_Y 1
 #endif
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
@@ -217,7 +217,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
     const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
@@ -289,43 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -334,11 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
     dequantize_mul_mat_vec<1, 1, convert_f16>
-        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
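
Note on the reduction pattern: the warp-level butterfly reduction introduced in [PATCH 1/8] and made unconditional in [PATCH 5/8] can be illustrated in isolation. The sketch below is hypothetical and not part of the series -- the file, kernel, and variable names are invented for the example, and it assumes a full 32-thread warp (participation mask 0xffffffff) and CUDA 9 or newer for __shfl_xor_sync:

// warp_sum.cu - minimal, illustrative sketch (not from the patches above).
// Each __shfl_xor_sync exchanges registers between lanes whose IDs differ
// in one bit of the xor mask, so after log2(32) = 5 steps every lane holds
// the sum of all 32 partial values.
#include <cstdio>

__global__ void warp_sum(const float * x, float * out) {
    float tmp = x[threadIdx.x]; // one partial value per lane

#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        *out = tmp; // lane 0 writes the warp-wide sum
    }
}

int main() {
    float hx[32], hout = 0.0f;
    float * dx; float * dout;
    for (int i = 0; i < 32; ++i) {
        hx[i] = 1.0f; // expected sum: 32
    }
    cudaMalloc(&dx, 32*sizeof(float));
    cudaMalloc(&dout, sizeof(float));
    cudaMemcpy(dx, hx, 32*sizeof(float), cudaMemcpyHostToDevice);
    warp_sum<<<1, 32>>>(dx, dout);
    cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.1f\n", hout); // prints 32.0
    cudaFree(dx);
    cudaFree(dout);
    return 0;
}

Unlike the shared-memory tmpa loop removed in [PATCH 5/8], the shuffle form keeps every partial sum in registers, so it needs no __shared__ storage and no __syncthreads() between reduction steps.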