@@ -83,9 +83,12 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
+#define WARP_SIZE 32
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 // dmmv = dequantize_mul_mat_vec
-#define GGML_CUDA_DMMV_BLOCK_X 32
+#ifndef GGML_CUDA_DMMV_BLOCK_X
+#define GGML_CUDA_DMMV_BLOCK_X 32 // can be set by compiler option LLAMA_CUDA_BX
+#endif
 #ifndef GGML_CUDA_DMMV_BLOCK_Y
 #define GGML_CUDA_DMMV_BLOCK_Y 1 // can be set by compiler option LLAMA_CUDA_BY
 #endif
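
Editor's note, not part of the patch: vals_per_iter in the rewritten kernel below is 2*GGML_CUDA_DMMV_BLOCK_X / WARP_SIZE and the inner loop consumes two values per step, so GGML_CUDA_DMMV_BLOCK_X seemingly needs to stay a multiple of WARP_SIZE for every column to be covered exactly once. A hypothetical compile-time guard against bad overrides, using only the two macros defined above:

    // hypothetical guard, not in the patch: reject block sizes the kernel cannot handle
    #if GGML_CUDA_DMMV_BLOCK_X % WARP_SIZE != 0
    #error "GGML_CUDA_DMMV_BLOCK_X must be a multiple of WARP_SIZE"
    #endif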
@@ -194,32 +197,40 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
+    const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread per i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
 
-
     float tmp = 0; // partial sum for thread in warp
 
-#ifdef GGML_CUDA_UNROLL
-#pragma unroll
-#endif
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
-
-        // matrix multiplication
-        tmp += v0 * y[iybs + iqs + 0];
-        tmp += v1 * y[iybs + iqs + y_offset];
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because there are 2 weights per data val
+
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
 
     // sum up partial sums and write back result
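
Editor's illustration, not from the patch: a host-side sketch that replays the kernel's per-thread index arithmetic for the Q4_0 case (qk = 32 quantized weights per x block, qr = 2 weights per byte) with the default GGML_CUDA_DMMV_BLOCK_X = 32 and WARP_SIZE = 32, for one example thread of the warp:

    #include <cstdio>

    int main() {
        const int qk = 32, qr = 2;                // Q4_0 layout (QK4_0, QR4_0)
        const int ncols = 128;                    // toy row width
        const int iter_stride = 2*32;             // 2*GGML_CUDA_DMMV_BLOCK_X
        const int vals_per_iter = iter_stride/32; // iter_stride / WARP_SIZE = 2
        const int y_offset = qk/2;                // 16, since qr != 1
        const int row = 0, tid = 5;               // thread 5 of the warp

        for (int i = 0; i < ncols; i += iter_stride) {
            const int col  = i + vals_per_iter*tid;
            const int ib   = (row*ncols + col)/qk; // x block index
            const int iqs  = (col%qk)/qr;          // x quant index
            const int iybs = col - col%qk;         // y block start index
            for (int j = 0; j < vals_per_iter; j += 2) {
                printf("i=%3d j=%d: ib=%d iqs=%d -> y[%d] and y[%d]\n",
                       i, j, ib, iqs + j/qr,
                       iybs + iqs + j/qr, iybs + iqs + j/qr + y_offset);
            }
        }
        return 0;
    }

This prints ib=0, iqs=5 touching y[5] and y[21], then ib=2, iqs=5 touching y[69] and y[85]: per outer iteration each thread dequantizes one byte (two nibbles) of one x block, and the two nibbles multiply y elements y_offset = 16 apart.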
@@ -259,72 +270,44 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-template <dequantize_kernel_t dequantize_kernel, int qk, int qr>
-static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
-                                        const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
     GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
-
-    // Use a switch statement for ncols so the compiler can unroll all loops:
-    switch (ncols) {
-        case 4096:
-            dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 5120:
-            dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 6656:
-            dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 8192:
-            dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 11008:
-            dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 13824:
-            dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 17920:
-            dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 22016:
-            dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        default:
-            fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
-            GGML_ASSERT(false);
-            break;
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q4_0, QK4_0, QR4_0>(vx, y, dst, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q4_1, QK4_1, QR4_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q5_0, QK5_0, QR5_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q5_1, QK5_1, QR5_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q8_0, QK8_0, QR8_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
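
For reference, every launcher above now uses the same geometry: nrows/GGML_CUDA_DMMV_BLOCK_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_BLOCK_Y threads, i.e. one warp per matrix row. A quick host-side sanity check of that arithmetic (editor's sketch, not from the patch):

    #include <cassert>
    #include <cstdio>

    int main() {
        const int warp_size = 32;   // WARP_SIZE
        const int block_y   = 1;    // GGML_CUDA_DMMV_BLOCK_Y default
        const int nrows     = 4096; // e.g. a 4096-row LLaMA-7B weight matrix

        assert(nrows % block_y == 0); // mirrors the GGML_ASSERT in each launcher
        const int num_blocks        = nrows/block_y;
        const int threads_per_block = warp_size*block_y;
        printf("%d blocks x %d threads = %d warps for %d rows\n",
               num_blocks, threads_per_block, num_blocks*block_y, nrows);
        return 0;
    }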
@@ -333,7 +316,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<convert_f16, 1, 1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {