@@ -83,7 +83,7 @@ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block siz
83
83
static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
84
84
const block_q4_0 * x = (const block_q4_0 *) vx;
85
85
86
- const int i = blockIdx .x ;
86
+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
87
87
88
88
const float d = x[i].d ;
89
89
@@ -182,7 +182,7 @@ static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
182
182
static __global__ void dequantize_block_q5_1 (const void * vx, float * y) {
183
183
const block_q5_1 * x = (const block_q5_1 *) vx;
184
184
185
- const int i = blockIdx .x ;
185
+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
186
186
187
187
const float d = x[i].d ;
188
188
const float m = x[i].m ;
@@ -227,7 +227,8 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
227
227
228
228
// Launches dequantize_block_q4_0 on `stream` to expand the q4_0 data at `vx`
// into floats at `y`.
//
//   vx     - quantized input, handed to the kernel unchanged
//   y      - float output, handed to the kernel unchanged
//   k      - number of values in the row; k / QK4_0 gives the q4_0 block count
//   stream - CUDA stream the kernel is enqueued on
//
// Launch layout: 1D grid of (k / QK4_0) / 256 thread blocks, 256 threads each,
// no dynamic shared memory. The kernel derives its q4_0 block index from
// blockIdx.x * blockDim.x + threadIdx.x, i.e. one q4_0 block per thread.
//
// Precondition: the q4_0 block count must be a multiple of 256. The kernel
// takes no element count, so a partially filled thread block could not be
// masked off; the assertion below enforces the requirement instead.
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;

    GGML_ASSERT(nb % 256 == 0);

    const dim3 block_dims(256, 1, 1);
    const dim3 grid_dims(nb / 256, 1, 1);

    dequantize_block_q4_0<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
232
233
233
234
static void dequantize_row_q4_1_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
@@ -247,7 +248,8 @@ static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStre
247
248
248
249
// Launches dequantize_block_q5_1 on `stream` to expand the q5_1 data at `vx`
// into floats at `y`.
//
//   vx     - quantized input, handed to the kernel unchanged
//   y      - float output, handed to the kernel unchanged
//   k      - number of values in the row; k / QK5_1 gives the q5_1 block count
//   stream - CUDA stream the kernel is enqueued on
//
// Launch layout: 1D grid of (k / QK5_1) / 256 thread blocks, 256 threads each,
// no dynamic shared memory. The kernel derives its q5_1 block index from
// blockIdx.x * blockDim.x + threadIdx.x, i.e. one q5_1 block per thread.
//
// Precondition: the q5_1 block count must be a multiple of 256. The kernel
// takes no element count, so a partially filled thread block could not be
// masked off; the assertion below enforces the requirement instead.
static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK5_1;

    GGML_ASSERT(nb % 256 == 0);

    const dim3 block_dims(256, 1, 1);
    const dim3 grid_dims(nb / 256, 1, 1);

    dequantize_block_q5_1<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
252
254
253
255
static void dequantize_row_q8_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
0 commit comments