Skip to content

Commit 960ee21

Browse files
Fewer iters, more ops per iter
1 parent b00c58c commit 960ee21

File tree

3 files changed

+65
-76
lines changed

3 files changed

+65
-76
lines changed

CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ endif()
6767
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
6868
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
6969
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
70+
set(LLAMA_CUDA_BX "32" CACHE STRING "llama: x block size for dmmv CUDA kernels")
7071
set(LLAMA_CUDA_BY "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
71-
option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF)
7272
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
7373

7474
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@@ -192,8 +192,8 @@ if (LLAMA_CUBLAS)
192192
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
193193

194194
add_compile_definitions(GGML_USE_CUBLAS)
195+
add_compile_definitions(GGML_CUDA_DMMV_BLOCK_X=${LLAMA_CUDA_BX})
195196
add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
196-
add_compile_definitions(GGML_CUDA_UNROLL=${LLAMA_CUDA_UNROLL})
197197

198198
if (LLAMA_STATIC)
199199
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)

Makefile

+5-3
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,16 @@ ifdef LLAMA_CUBLAS
129129
OBJS += ggml-cuda.o
130130
NVCC = nvcc
131131
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
132+
ifdef LLAMA_CUDA_BX
133+
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=$(LLAMA_CUDA_BX)
134+
else
135+
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=32
136+
endif # LLAMA_CUDA_BX
132137
ifdef LLAMA_CUDA_BY
133138
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
134139
else
135140
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
136141
endif # LLAMA_CUDA_BY
137-
ifdef LLAMA_CUDA_UNROLL
138-
NVCCFLAGS += -DGGML_CUDA_UNROLL=$(LLAMA_CUDA_UNROLL)
139-
endif # LLAMA_CUDA_UNROLL
140142
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
141143
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
142144
endif # LLAMA_CUBLAS

ggml-cuda.cu

+58-71
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,12 @@ typedef struct {
8383
} block_q8_0;
8484
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
8585

86+
#define WARP_SIZE 32
8687
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
8788
// dmmv = dequantize_mul_mat_vec
88-
#define GGML_CUDA_DMMV_BLOCK_X 32
89+
#ifndef GGML_CUDA_DMMV_BLOCK_X
90+
#define GGML_CUDA_DMMV_BLOCK_X 32 // can be set by compiler option LLAMA_CUDA_BX
91+
#endif
8992
#ifndef GGML_CUDA_DMMV_BLOCK_Y
9093
#define GGML_CUDA_DMMV_BLOCK_Y 1 // can be set by compiler option LLAMA_CUDA_BY
9194
#endif
@@ -194,32 +197,40 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
194197
dequantize_kernel(vx, ib, iqs, v0, v1);
195198
}
196199

197-
template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
198-
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
200+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
201+
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
202+
// qk = quantized weights per x block
203+
// qr = number of quantized weights per data value in x block
199204
const int row = blockIdx.x*blockDim.y + threadIdx.y;
200205
const int tid = threadIdx.x;
201206

207+
const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
208+
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread per i iter
202209
const int y_offset = qr == 1 ? 1 : qk/2;
203210

204-
205211
float tmp = 0; // partial sum for thread in warp
206212

207-
#ifdef GGML_CUDA_UNROLL
208-
#pragma unroll
209-
#endif
210-
for (int i = 0; i < ncols/block_size; i += 2) {
211-
const int col = i*block_size + 2*tid;
212-
const int ib = (row*ncols + col)/qk; // block index
213-
const int iqs = (col%qk)/qr; // quant index
213+
for (int i = 0; i < ncols; i += iter_stride) {
214+
const int col = i + vals_per_iter*tid;
215+
const int ib = (row*ncols + col)/qk; // x block index
216+
const int iqs = (col%qk)/qr; // x quant index
214217
const int iybs = col - col%qk; // y block start index
215218

216-
// dequantize
217-
float v0, v1;
218-
dequantize_kernel(vx, ib, iqs, v0, v1);
219-
220-
// matrix multiplication
221-
tmp += v0 * y[iybs + iqs + 0];
222-
tmp += v1 * y[iybs + iqs + y_offset];
219+
// processing >2 values per i iter is faster for fast GPUs
220+
#pragma unroll
221+
for (int j = 0; j < vals_per_iter; j += 2) {
222+
// process 2 vals per j iter
223+
224+
// dequantize
225+
float v0, v1;
226+
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
227+
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
228+
229+
// matrix multiplication
230+
tmp += v0 * y[iybs + iqs + j/qr + 0];
231+
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
232+
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
233+
}
223234
}
224235

225236
// sum up partial sums and write back result
@@ -259,72 +270,44 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
259270
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
260271
}
261272

262-
template<dequantize_kernel_t dequantize_kernel, int qk, int qr>
263-
static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
264-
const int ncols, const int nrows, cudaStream_t stream) {
273+
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
265274
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
266275
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
267-
const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
268-
269-
// Use a switch statement for ncols so the compiler can unroll all loops:
270-
switch (ncols) {
271-
case 4096:
272-
dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
273-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
274-
break;
275-
case 5120:
276-
dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
277-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
278-
break;
279-
case 6656:
280-
dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
281-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
282-
break;
283-
case 8192:
284-
dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
285-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
286-
break;
287-
case 11008:
288-
dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
289-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
290-
break;
291-
case 13824:
292-
dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
293-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
294-
break;
295-
case 17920:
296-
dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
297-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
298-
break;
299-
case 22016:
300-
dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
301-
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
302-
break;
303-
default:
304-
fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
305-
GGML_ASSERT(false);
306-
break;
307-
}
308-
}
309-
310-
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
311-
dequantize_mul_mat_vec_cuda<dequantize_q4_0, QK4_0, QR4_0>(vx, y, dst, ncols, nrows, stream);
276+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
277+
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
278+
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
312279
}
313280

314281
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
315-
dequantize_mul_mat_vec_cuda<dequantize_q4_1, QK4_1, QR4_1>(vx, y, dst, ncols, nrows, stream);
282+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
283+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
284+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
285+
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
286+
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
316287
}
317288

318289
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
319-
dequantize_mul_mat_vec_cuda<dequantize_q5_0, QK5_0, QR5_0>(vx, y, dst, ncols, nrows, stream);
290+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
291+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
292+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
293+
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
294+
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
320295
}
321296

322297
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
323-
dequantize_mul_mat_vec_cuda<dequantize_q5_1, QK5_1, QR5_1>(vx, y, dst, ncols, nrows, stream);
298+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
299+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
300+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
301+
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
302+
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
324303
}
325304

326305
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
327-
dequantize_mul_mat_vec_cuda<dequantize_q8_0, QK8_0, QR8_0>(vx, y, dst, ncols, nrows, stream);
306+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
307+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
308+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
309+
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
310+
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
328311
}
329312

330313
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -333,7 +316,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
333316
}
334317

335318
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
336-
dequantize_mul_mat_vec_cuda<convert_f16, 1, 1>(vx, y, dst, ncols, nrows, stream);
319+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
320+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
321+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
322+
dequantize_mul_mat_vec<1, 1, convert_f16>
323+
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
337324
}
338325

339326
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {

0 commit comments

Comments
 (0)