
Commit 9bdb8fc

vrdn-23 authored and LucasWilkinson committed
[Kernel] Add Exllama as a backend for compressed-tensors (vllm-project#9395)
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
1 parent e9cc271 · commit 9bdb8fc

File tree

7 files changed: +173 -16 lines


vllm/envs.py (+9)
@@ -66,6 +66,7 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_DISABLED_KERNELS: List[str] = []
 
 
 def get_default_cache_root():
@@ -430,6 +431,14 @@ def get_default_config_root():
     "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
     lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
                           ) == "1",
+
+    # List of quantization kernels that should be disabled, used for testing
+    # and performance comparisons. Currently only affects MPLinearKernel
+    # selection
+    # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
+    "VLLM_DISABLED_KERNELS":
+    lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
+        "VLLM_DISABLED_KERNELS"].split(","),
 }
 
 # end-env-vars-definition
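
Usage note (not part of the diff): the new variable takes a comma-separated list of kernel class names, e.g. VLLM_DISABLED_KERNELS=MacheteLinearKernel,MarlinLinearKernel. A minimal sketch of the parsing behaviour the lambda above implements:

import os

# Mirror of the lambda added above: unset -> [], otherwise split on commas.
os.environ["VLLM_DISABLED_KERNELS"] = "MacheteLinearKernel,MarlinLinearKernel"
disabled = ([] if "VLLM_DISABLED_KERNELS" not in os.environ else
            os.environ["VLLM_DISABLED_KERNELS"].split(","))
print(disabled)  # ['MacheteLinearKernel', 'MarlinLinearKernel']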

vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py (+4)
@@ -42,6 +42,10 @@ def __init__(self,
         self.config = c
         self.w_q_name = w_q_param_name
         self.w_s_name = w_s_param_name
+        if c.zero_points:
+            assert w_zp_param_name is not None
+        if c.has_g_idx:
+            assert w_gidx_param_name is not None
         self.w_zp_name = w_zp_param_name
         self.w_gidx_name = w_gidx_param_name
 

vllm/model_executor/layers/quantization/kernels/__init__.py (+5 -3)
@@ -1,6 +1,8 @@
-import os
 from typing import List, Optional, Type
 
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.kernels.exllama import (
+    ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.machete import (
     MacheteLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.marlin import (
@@ -13,6 +15,7 @@
 _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
     MacheteLinearKernel,
     MarlinLinearKernel,
+    ExllamaLinearKernel,
 ]
 
 
@@ -45,8 +48,7 @@ def choose_mp_linear_kernel(
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
-        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
-            .split(","):
+        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
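
Illustration (not part of the diff): choose_mp_linear_kernel walks _POSSIBLE_KERNELS in order and skips any kernel whose class name appears in VLLM_DISABLED_KERNELS, so disabling the earlier entries is how you force the Exllama backend. A self-contained sketch with hypothetical stand-in classes:

from typing import List, Type

# Hypothetical stand-ins for the real kernel classes, used only to show the
# selection order; the actual classes live in vllm.
class MacheteLinearKernel: ...
class MarlinLinearKernel: ...
class ExllamaLinearKernel: ...

_POSSIBLE_KERNELS: List[Type] = [
    MacheteLinearKernel,
    MarlinLinearKernel,
    ExllamaLinearKernel,
]

def first_enabled_kernel(disabled: List[str]) -> Type:
    # Same skip logic as the loop above, minus the capability and
    # can_implement checks that the real choose_mp_linear_kernel also applies.
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in disabled:
            continue
        return kernel
    raise ValueError("All kernels disabled")

print(first_enabled_kernel(["MacheteLinearKernel", "MarlinLinearKernel"]).__name__)
# ExllamaLinearKernel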
vllm/model_executor/layers/quantization/kernels/exllama.py (new file, +140)

@@ -0,0 +1,140 @@
+from typing import Optional, Tuple
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    pack_quantized_values_into_int32)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+
+class ExllamaLinearKernel(MPLinearKernel):
+    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+    # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
+    # currently untested so not added to the list
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if c.has_g_idx and\
+            c.partition_weight_shape[0] != c.full_weight_shape[0]:
+            return False, "Act reordering currently not supported by Exllama, "\
+                          "when the input features are partitioned across "\
+                          "devices"
+
+        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
+            return False, "Output features must be a multiple of the pack " \
+                          "factor (32 / num_bits) so that we can correctly " \
+                          "pack the zero points"
+
+        if c.act_type != torch.float16:
+            return False, "Exllama only supports float16 activations"
+
+        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
+            return False, f"Quant type ({c.weight_type}) not supported by "\
+                          "Exllama, supported types are: "\
+                          f"{cls.SUPPORTED_QUANT_TYPES}"
+
+        if c.full_weight_shape[0] % c.group_size != 0:
+            return False, f"Group size ({c.group_size}) does not evenly divide"\
+                          " the number of input features "\
+                          f"({c.full_weight_shape[0]})"
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        c = self.config
+
+        # For Exllama, we need to set a zero-point tensor if there is not one
+        if not c.zero_points:
+            self.w_zp_name = "qzeros"
+            device = getattr(layer, self.w_q_name).device
+            groups = c.partition_weight_shape[0] // c.group_size
+            out_features = c.partition_weight_shape[1]
+
+            if c.weight_type.has_bias():
+                # if the type has a bias we have to create a zeros tensor that
+                # contains the bias values repeated for each group (-1 due to
+                # a bug in the original GPTQ checkpoint format leading to
+                # exllama kernel adding 1 to the zero points during inference)
+                # Documentation of the bug can be found here:
+                # https://garden.danieldk.eu/GPTQ-Checkpoint-Format
+                zeros = torch.full((groups, out_features),
+                                   c.weight_type.bias - 1,
+                                   dtype=torch.int32,
+                                   device=device)
+            else:
+                raise NotImplementedError(
+                    "A 0 zero-point is not supported by Exllama due to "
+                    "a bug in the original GPTQ checkpoint format leading to "
+                    "exllama kernel adding 1 to the zero points during "
+                    "inference")
+            zeros = pack_quantized_values_into_int32(zeros,
+                                                     c.weight_type,
+                                                     packed_dim=1)
+            setattr(layer, self.w_zp_name,
+                    torch.nn.Parameter(zeros, requires_grad=False))
+
+        if c.has_g_idx:
+
+            def transform_w_g_idx(x):
+                # Exllama wants the permutation array instead of the group
+                # indices
+                return torch.argsort(x).to(torch.int)
+
+            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
+        else:
+            self.w_gidx_name = "g_idx"
+            empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
+                                                         dtype=torch.int,
+                                                         device=device),
+                                             requires_grad=False)
+            setattr(layer, self.w_gidx_name, empty_g_idx)
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            assert self.w_gidx_name is not None
+            g_idx = getattr(layer, self.w_gidx_name)
+
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x_cont = x.data.contiguous()
+            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
+            return x_cont
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x.to(dtype=c.act_type)
+
+        # Repack weights and scales for Exllama
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        c = self.config
+
+        x_2d = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+
+        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
+
+        assert w_zp is not None, "Zero points are required by Exllama"
+        assert w_g_idx is not None, "Group index is required by Exllama"
+        output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
+                               c.weight_type.size_bits)
+
+        if bias is not None:
+            output.add_(bias)
+        return output.reshape(out_shape)
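
Worked example (not part of the diff): for a uint4b8 weight the bias is 8, so the synthetic qzeros tensor above is filled with 7 before packing, compensating for the +1 the exllama kernel adds at inference time. A rough sketch of the packing, assuming 32 // num_bits values per int32 along the output dimension (the exact bit layout belongs to pack_quantized_values_into_int32 and may differ in ordering):

import torch

# uint4b8: 4-bit values stored with a bias of 8; zero point written as 8 - 1.
groups, out_features, bias, num_bits = 2, 8, 8, 4
zeros = torch.full((groups, out_features), bias - 1, dtype=torch.int32)

pack_factor = 32 // num_bits  # 8 four-bit values per int32
packed = torch.zeros((groups, out_features // pack_factor), dtype=torch.int32)
for i in range(pack_factor):
    # OR each strided slice of values into its 4-bit slot of the int32.
    packed |= zeros[:, i::pack_factor] << (num_bits * i)

print(packed)  # tensor([[2004318071], [2004318071]]) == 0x77777777 per group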

vllm/model_executor/layers/quantization/kernels/machete.py (+7 -7)
@@ -8,7 +8,7 @@
     MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
     query_machete_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    pack_weights_into_int32, unpack_weights_into_int32)
+    pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
@@ -71,13 +71,13 @@ def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
-               x_unpacked = unpack_weights_into_int32(x.data,
-                                                      c.weight_type,
-                                                      packed_dim=0)
+               x_unpacked = unpack_quantized_values_into_int32(x.data,
+                                                               c.weight_type,
+                                                               packed_dim=0)
                x_perm = x_unpacked[perm, :]
-               x.data = pack_weights_into_int32(x_perm,
-                                                c.weight_type,
-                                                packed_dim=0)
+               x.data = pack_quantized_values_into_int32(x_perm,
+                                                         c.weight_type,
+                                                         packed_dim=0)
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x

vllm/model_executor/layers/quantization/utils/quant_utils.py (+6 -6)
@@ -20,9 +20,9 @@
 }
 
 
-def pack_weights_into_int32(w_q: torch.Tensor,
-                            wtype: ScalarType,
-                            packed_dim: int = 0):
+def pack_quantized_values_into_int32(w_q: torch.Tensor,
+                                     wtype: ScalarType,
+                                     packed_dim: int = 0):
     # move dim to pack to the end
     perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
     inv_perm = tuple(perm.index(i) for i in range(len(perm)))
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
     return res.permute(inv_perm)
 
 
-def unpack_weights_into_int32(w_q: torch.Tensor,
-                              wtype: ScalarType,
-                              packed_dim: int = 0):
+def unpack_quantized_values_into_int32(w_q: torch.Tensor,
+                                       wtype: ScalarType,
+                                       packed_dim: int = 0):
     # move dim to pack to the end
     perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
     inv_perm = tuple(perm.index(i) for i in range(len(perm)))
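
Side note (not part of the diff): both helpers rely on the same trick visible in the context lines above: permute the dimension to pack to the end, operate there, and undo it with the inverse permutation. A small standalone check of that permutation logic:

import torch

# Move packed_dim to the end, then restore the original layout with inv_perm.
w_q = torch.randint(0, 16, (6, 4), dtype=torch.int32)
packed_dim = 0
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))

moved = w_q.permute(perm)           # shape (4, 6): packed dim is now last
restored = moved.permute(inv_perm)  # shape (6, 4): original layout again
assert torch.equal(restored, w_q)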

vllm/scalar_type.py (+2)
@@ -27,6 +27,8 @@ class scalar_types:
     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
 
     # "gptq" types
+    uint2b2 = ScalarType.uint(2, 2)
+    uint3b4 = ScalarType.uint(3, 4)
     uint4b8 = ScalarType.uint(4, 8)
     uint8b128 = ScalarType.uint(8, 128)
 
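
Background (not part of the diff): in this naming scheme uintNbB is an N-bit unsigned storage type with bias B, so the logical value is the stored value minus B; uint2b2 and uint3b4 are the 2- and 3-bit analogues of the existing GPTQ types. A small sketch of the ranges this implies, assuming that convention:

# Logical range of a biased unsigned type: stored 0 .. 2**bits - 1, minus bias.
def logical_range(size_bits: int, bias: int):
    return -bias, (1 << size_bits) - 1 - bias

for name, bits, bias in [("uint2b2", 2, 2), ("uint3b4", 3, 4),
                         ("uint4b8", 4, 8), ("uint8b128", 8, 128)]:
    print(name, logical_range(bits, bias))
# uint2b2 (-2, 1)   uint3b4 (-4, 3)   uint4b8 (-8, 7)   uint8b128 (-128, 127)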
