|
| 1 | +from pathlib import Path |
| 2 | +from typing import List |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | +from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize |
| 7 | +from huggingface_hub import snapshot_download |
| 8 | + |
| 9 | +import vllm._custom_ops as ops |
| 10 | + |
| 11 | +GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") |
| 12 | + |
| 13 | + |
| 14 | +def get_gguf_sample_tensors( |
| 15 | + hidden_size: int, |
| 16 | + quant_type: GGMLQuantizationType) -> List[ReaderTensor]: |
| 17 | + sample_dir = GGUF_SAMPLE |
| 18 | + filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" |
| 19 | + sample_file = Path(sample_dir) / filename |
| 20 | + return GGUFReader(sample_file).tensors |
| 21 | + |
| 22 | + |
| 23 | +DTYPES = [torch.half] |
| 24 | +# Hidden_size for testing, must match the sample file in HF repo, |
| 25 | +# we have `hidden_size = 256, 1024` for test in HF repo currently. |
| 26 | +HIDDEN_SIZES = [256, 1024] |
| 27 | +NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing |
| 28 | +SEEDS = [0] |
| 29 | +QUANT_TYPES = [ |
| 30 | + # i-matrix |
| 31 | + GGMLQuantizationType.IQ1_M, |
| 32 | + GGMLQuantizationType.IQ1_S, |
| 33 | + GGMLQuantizationType.IQ2_S, |
| 34 | + GGMLQuantizationType.IQ2_XS, |
| 35 | + GGMLQuantizationType.IQ3_S, |
| 36 | + GGMLQuantizationType.IQ3_XXS, |
| 37 | + GGMLQuantizationType.IQ4_NL, |
| 38 | + GGMLQuantizationType.IQ4_XS, |
| 39 | + # k-quants |
| 40 | + GGMLQuantizationType.Q2_K, |
| 41 | + GGMLQuantizationType.Q3_K, |
| 42 | + GGMLQuantizationType.Q4_K, |
| 43 | + GGMLQuantizationType.Q5_K, |
| 44 | + GGMLQuantizationType.Q6_K, |
| 45 | + # standard quantization |
| 46 | + GGMLQuantizationType.Q4_0, |
| 47 | + GGMLQuantizationType.Q5_0, |
| 48 | + GGMLQuantizationType.Q8_0, |
| 49 | +] |
| 50 | + |
| 51 | + |
| 52 | +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) |
| 53 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 54 | +@pytest.mark.parametrize("quant_type", QUANT_TYPES) |
| 55 | +@torch.inference_mode() |
| 56 | +def test_dequantize(hidden_size: int, dtype: torch.dtype, |
| 57 | + quant_type: GGMLQuantizationType): |
| 58 | + tensors = get_gguf_sample_tensors(hidden_size, quant_type) |
| 59 | + for tensor in tensors: |
| 60 | + shape_str = tensor.name.split("_")[-1] |
| 61 | + shape = map(int, shape_str.split("x")) |
| 62 | + |
| 63 | + ref_output = torch.tensor(dequantize(tensor.data, quant_type), |
| 64 | + device="cuda").to(dtype) |
| 65 | + output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), |
| 66 | + quant_type, *list(shape)).to(dtype) |
| 67 | + |
| 68 | + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) |
| 69 | + |
| 70 | + |
| 71 | +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) |
| 72 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 73 | +@pytest.mark.parametrize("quant_type", QUANT_TYPES) |
| 74 | +@torch.inference_mode() |
| 75 | +def test_mmvq(hidden_size: int, dtype: torch.dtype, |
| 76 | + quant_type: GGMLQuantizationType): |
| 77 | + torch.cuda.manual_seed_all(0) |
| 78 | + |
| 79 | + tensors = get_gguf_sample_tensors(hidden_size, quant_type) |
| 80 | + x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") |
| 81 | + for tensor in tensors: |
| 82 | + weight = torch.tensor(dequantize(tensor.data, quant_type), |
| 83 | + device="cuda").to(dtype) |
| 84 | + ref_output = x @ weight.T |
| 85 | + |
| 86 | + qweight = torch.tensor(tensor.data, device="cuda") |
| 87 | + output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, |
| 88 | + qweight.shape[0]).to(dtype) |
| 89 | + |
| 90 | + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) |
| 91 | + |
| 92 | + |
| 93 | +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) |
| 94 | +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) |
| 95 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 96 | +@pytest.mark.parametrize( |
| 97 | + "quant_type", |
| 98 | + [ |
| 99 | + # k-quants |
| 100 | + GGMLQuantizationType.Q2_K, |
| 101 | + GGMLQuantizationType.Q3_K, |
| 102 | + GGMLQuantizationType.Q4_K, |
| 103 | + GGMLQuantizationType.Q5_K, |
| 104 | + GGMLQuantizationType.Q6_K, |
| 105 | + # standard quants |
| 106 | + GGMLQuantizationType.Q4_0, |
| 107 | + GGMLQuantizationType.Q5_0, |
| 108 | + GGMLQuantizationType.Q8_0, |
| 109 | + ]) |
| 110 | +@torch.inference_mode() |
| 111 | +def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, |
| 112 | + quant_type: GGMLQuantizationType): |
| 113 | + torch.cuda.manual_seed_all(0) |
| 114 | + |
| 115 | + tensors = get_gguf_sample_tensors(hidden_size, quant_type) |
| 116 | + x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") |
| 117 | + for tensor in tensors: |
| 118 | + weight = torch.tensor(dequantize(tensor.data, quant_type), |
| 119 | + device="cuda").to(dtype) |
| 120 | + ref_output = x @ weight.T |
| 121 | + |
| 122 | + qweight = torch.tensor(tensor.data, device="cuda") |
| 123 | + output = ops.ggml_mul_mat_a8(qweight, x, quant_type, |
| 124 | + qweight.shape[0]).to(dtype) |
| 125 | + |
| 126 | + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) |
0 commit comments