Commit d2a8eac

Isotr0py authored and MengqingCao committed
[Bugfix][Kernel] Add IQ1_M quantization implementation to GGUF kernel (vllm-project#8357)
1 parent 39fd215 commit d2a8eac

8 files changed: +548, -162 lines

csrc/quantization/gguf/dequantize.cuh (+46, -9)

@@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
     const block_iq1_s * x = (const block_iq1_s *) vx;

-    const int tid = threadIdx.x;
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+    const float d = __half2float(x[i].d) * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = __float2half(d * (q[j] + delta));
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_iq1_m * x = (const block_iq1_m *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const int i8 = 4*ib+il;
-    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
-    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
-    const float d = __half2float(x[i].d) * (2*(h & 7) + 1);
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]);
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = __float2half(d * (q[j] + delta));
+    }
 }

 template<typename dst_t>
@@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
 }

+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
@@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
             return dequantize_row_iq2_s_cuda;
         case 23:
             return dequantize_row_iq4_xs_cuda;
+        case 29:
+            return dequantize_row_iq1_m_cuda;
         default:
             return nullptr;
     }
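
A note on the scale handling in dequantize_block_iq1_m above: unlike IQ1_S, which keeps its fp16 block scale directly in x[i].d, IQ1_M spreads a single fp16 super-block scale across the top 4 bits of the four 16-bit words of x[i].scales, while the low 12 bits of each word carry the 3-bit sub-scales selected via ib16. The following minimal host-side sketch (not part of the commit; the names and the example value are illustrative) packs a known fp16 bit pattern the same way and recovers it with the kernel's unpacking expression:

    // Round-trip check of the IQ1_M super-scale bit layout (illustrative only).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint16_t scale_bits = 0x3C00;  // e.g. the fp16 bit pattern of 1.0

        // Pack: nibble k of the super-scale goes into the top 4 bits of scale
        // word k; the low 12 bits (sub-scales in a real block) are left at 0.
        uint16_t sc[4] = {0, 0, 0, 0};
        for (int k = 0; k < 4; ++k) {
            sc[k] |= uint16_t(((scale_bits >> (4 * k)) & 0xF) << 12);
        }

        // Unpack with the same expression the kernel uses.
        const uint16_t u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) |
                             ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        printf("recovered 0x%04x, expected 0x%04x\n", (unsigned)u16, (unsigned)scale_bits);
        return 0;
    }

With this example value only sc[2] and sc[3] end up carrying non-zero top nibbles (0xC000 and 0x3000), and the expression reassembles 0x3C00 as expected.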

csrc/quantization/gguf/ggml-common.h (+277, -131)

Large diffs are not rendered by default.
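
The collapsed ggml-common.h diff is where the IQ1_M block layout, iq1m_scale_t, QI1_M, IQ1M_DELTA and the shared iq1s_grid_gpu table come from. Since the file is not rendered here, the sketch below shows roughly what the upstream llama.cpp definitions used by the kernels above look like; treat the exact field sizes and comments as an assumption rather than a quote of this diff:

    // Assumed shape of the upstream llama.cpp definitions (not quoted from this diff).
    // QK_K is the super-block size of 256 weights.
    typedef union {
        half     f16;  // reassembled per-block super-scale
        uint16_t u16;  // raw bits, spread over the top nibbles of block_iq1_m::scales
    } iq1m_scale_t;

    typedef struct {
        uint8_t qs[QK_K/8];       // low 8 bits of each group's grid index
        uint8_t qh[QK_K/16];      // high 3 bits of the grid indices plus the delta sign bits
        uint8_t scales[QK_K/32];  // 3-bit sub-scales, with the fp16 super-scale in the top nibbles
    } block_iq1_m;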

csrc/quantization/gguf/gguf_kernel.cu (+5)

@@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
                                   (void*)quant_X.data_ptr(),
                                   (half*)Y.data_ptr(), col, row, stream);
       break;
+    case 29:
+      mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
+                                  (void*)quant_X.data_ptr(),
+                                  (half*)Y.data_ptr(), col, row, stream);
+      break;
   }
   return Y;
 }
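
The raw `case 29:` label (like the existing `case 23:` path) follows ggml's numeric quantization type ids rather than a named enum on the vLLM side. As far as I know these map onto the upstream ggml_type values sketched below; the enum itself is not part of this diff, so treat the mapping as an assumption:

    // Assumed slice of upstream ggml's type ids (illustrative, not from this diff).
    enum {
        GGML_TYPE_IQ4_XS = 23,  // handled by the existing "case 23" branch
        GGML_TYPE_IQ1_M  = 29   // handled by the new "case 29" branch
    };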

csrc/quantization/gguf/mmvq.cuh (+8)

@@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

+static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
csrc/quantization/gguf/vecdotq.cuh (+81, -20)

@@ -1,5 +1,18 @@
 // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
 // and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
+static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+    int x32 = x16[2*i32 + 0] << 0;
+    x32    |= x16[2*i32 + 1] << 16;
+
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
+    return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
+
 static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
     const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
     int x32 = 0;
@@ -1658,28 +1671,76 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(

 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

-    const int ib32 = iqs;
-    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
-    const uint8_t h1 = bq1->scales[2*ib32+0];
-    const uint8_t h2 = bq1->scales[2*ib32+1];
-    const int * q8 = (const int *)bq8_1[ib32].qs;
-    const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
-    const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
-    const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
-    const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
-    for (int j = 0; j < 2; ++j) {
-        sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
-        sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
-        sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
-        sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
-    }
-    const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds);
-    return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
-                sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
-#endif
+    const int qs_packed = get_int_b2(bq1->qs, iqs);
+    const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+    const int qh = bq1->qh[iqs];
+
+    int sumi = 0;
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];
+
+        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+        const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+        sumi = __dp4a(grid0, u0, sumi);
+        sumi = __dp4a(grid1, u1, sumi);
+    }
+
+    const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
+    const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+    const float2 ds = __half22float2(bq8_1[iqs].ds);
+    return d1q * (ds.x*sumi + ds.y*delta);
+}
+
+static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+    const int qs_packed = get_int_b4(bq1->qs, iqs);
+    const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+    int sumi[2] = {0};
+    float sumf[2] = {0.0f};
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));
+
+        const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];
+
+        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+        const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+        sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]);
+        sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]);
+
+        const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
+        int sumy = 0;
+        sumy = __dp4a(u0, 0x01010101, sumy);
+        sumy = __dp4a(u1, 0x01010101, sumy);
+        sumf[l0/4] += delta*sumy;
+    }
+
+    const uint16_t * sc = (const uint16_t *) bq1->scales;
+
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
+    const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);
+
+    const int tmp = sc[iqs/2] >> (6*(iqs%2));
+    const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
+    const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
+    return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
 }

 static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
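
One detail worth calling out in both new dot products: the per-group delta is computed branchlessly. The expression `-1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000)` evaluates to `-1 + IQ1S_DELTA` when the sign bit is clear and `-1 - IQ1S_DELTA` when it is set, i.e. the same value the ternary in dequantize_block_iq1_s produces (and likewise for the `0x08` bit and IQ1M_DELTA in the IQ1_M path). A small host-side sketch, assuming the upstream value IQ1S_DELTA = 0.125f, checks the equivalence:

    // Host-side check of the branchless delta (illustrative; IQ1S_DELTA value assumed).
    #include <cstdio>

    static const float IQ1S_DELTA = 0.125f;

    int main() {
        for (int sign = 0; sign <= 1; ++sign) {
            const int qh = sign ? 0x8000 : 0x0000;  // only the sign bit matters here
            const float branchy    = (qh & 0x8000) ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
            const float branchless = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
            printf("sign=%d  ternary=%.3f  branchless=%.3f\n", sign, branchy, branchless);
        }
        return 0;
    }

Both forms print -0.875 for a clear sign bit and -1.125 for a set one, so the kernel avoids a branch without changing the result.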

requirements-common.txt (+1, -1)

@@ -24,7 +24,7 @@ filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
 partial-json-parser # used for parsing partial JSON outputs
 pyzmq
 msgspec
-gguf == 0.9.1
+gguf == 0.10.0
 importlib_metadata
 mistral_common >= 1.4.0
 pyyaml

tests/kernels/test_gguf.py (+126, new file)

@@ -0,0 +1,126 @@
+from pathlib import Path
+from typing import List
+
+import pytest
+import torch
+from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
+from huggingface_hub import snapshot_download
+
+import vllm._custom_ops as ops
+
+GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
+
+
+def get_gguf_sample_tensors(
+        hidden_size: int,
+        quant_type: GGMLQuantizationType) -> List[ReaderTensor]:
+    sample_dir = GGUF_SAMPLE
+    filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
+    sample_file = Path(sample_dir) / filename
+    return GGUFReader(sample_file).tensors
+
+
+DTYPES = [torch.half]
+# Hidden_size for testing, must match the sample file in HF repo,
+# we have `hidden_size = 256, 1024` for test in HF repo currently.
+HIDDEN_SIZES = [256, 1024]
+NUM_TOKENS = [7, 83, 128, 2048]  # Arbitrary values for testing
+SEEDS = [0]
+QUANT_TYPES = [
+    # i-matrix
+    GGMLQuantizationType.IQ1_M,
+    GGMLQuantizationType.IQ1_S,
+    GGMLQuantizationType.IQ2_S,
+    GGMLQuantizationType.IQ2_XS,
+    GGMLQuantizationType.IQ3_S,
+    GGMLQuantizationType.IQ3_XXS,
+    GGMLQuantizationType.IQ4_NL,
+    GGMLQuantizationType.IQ4_XS,
+    # k-quants
+    GGMLQuantizationType.Q2_K,
+    GGMLQuantizationType.Q3_K,
+    GGMLQuantizationType.Q4_K,
+    GGMLQuantizationType.Q5_K,
+    GGMLQuantizationType.Q6_K,
+    # standard quantization
+    GGMLQuantizationType.Q4_0,
+    GGMLQuantizationType.Q5_0,
+    GGMLQuantizationType.Q8_0,
+]
+
+
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("quant_type", QUANT_TYPES)
+@torch.inference_mode()
+def test_dequantize(hidden_size: int, dtype: torch.dtype,
+                    quant_type: GGMLQuantizationType):
+    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
+    for tensor in tensors:
+        shape_str = tensor.name.split("_")[-1]
+        shape = map(int, shape_str.split("x"))
+
+        ref_output = torch.tensor(dequantize(tensor.data, quant_type),
+                                  device="cuda").to(dtype)
+        output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
+                                     quant_type, *list(shape)).to(dtype)
+
+        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
+
+
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("quant_type", QUANT_TYPES)
+@torch.inference_mode()
+def test_mmvq(hidden_size: int, dtype: torch.dtype,
+              quant_type: GGMLQuantizationType):
+    torch.cuda.manual_seed_all(0)
+
+    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
+    x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
+    for tensor in tensors:
+        weight = torch.tensor(dequantize(tensor.data, quant_type),
+                              device="cuda").to(dtype)
+        ref_output = x @ weight.T
+
+        qweight = torch.tensor(tensor.data, device="cuda")
+        output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
+                                         qweight.shape[0]).to(dtype)
+
+        torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize(
+    "quant_type",
+    [
+        # k-quants
+        GGMLQuantizationType.Q2_K,
+        GGMLQuantizationType.Q3_K,
+        GGMLQuantizationType.Q4_K,
+        GGMLQuantizationType.Q5_K,
+        GGMLQuantizationType.Q6_K,
+        # standard quants
+        GGMLQuantizationType.Q4_0,
+        GGMLQuantizationType.Q5_0,
+        GGMLQuantizationType.Q8_0,
+    ])
+@torch.inference_mode()
+def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
+             quant_type: GGMLQuantizationType):
+    torch.cuda.manual_seed_all(0)
+
+    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
+    x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
+    for tensor in tensors:
+        weight = torch.tensor(dequantize(tensor.data, quant_type),
+                              device="cuda").to(dtype)
+        ref_output = x @ weight.T
+
+        qweight = torch.tensor(tensor.data, device="cuda")
+        output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
+                                     qweight.shape[0]).to(dtype)
+
+        torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)

vllm/model_executor/layers/quantization/gguf.py (+4, -1)

@@ -55,7 +55,10 @@ def get_scaled_act_names(self) -> List[str]:
 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:
     # use dequantize mulmat for IQmatrix, mmq for k-quants
-    if qweight_type >= 16:
+    if x.shape[0] == 1:
+        # enable mmvq in contiguous batching
+        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    elif qweight_type >= 16:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
