From 695e85e70e9c12b173c57a789f7b6dd03d022206 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Oct 2024 00:34:04 +0000 Subject: [PATCH 1/6] initial commit --- vllm/envs.py | 9 ++ .../quantization/kernels/MPLinearKernel.py | 4 + .../layers/quantization/kernels/__init__.py | 7 +- .../layers/quantization/kernels/exllama.py | 146 ++++++++++++++++++ .../layers/quantization/kernels/machete.py | 14 +- .../layers/quantization/utils/quant_utils.py | 12 +- vllm/scalar_type.py | 2 + 7 files changed, 179 insertions(+), 15 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/exllama.py diff --git a/vllm/envs.py b/vllm/envs.py index 8b541e5b78c01..57bd62c4e9e8b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -66,6 +66,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 + VLLM_DISABLED_KERNELS: List[str] = [] def get_default_cache_root(): @@ -430,6 +431,14 @@ def get_default_config_root(): "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1": lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0" ) == "1", + + # List of quantization kernels that should be disabled, used for testing + # and performance comparisons. Currently only affects MPLinearKernel + # selection + # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) + "VLLM_DISABLED_KERNELS": + lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ + "VLLM_DISABLED_KERNELS"].split(","), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py index fe50c4930d043..b04612a9b00d9 100644 --- a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -42,6 +42,10 @@ def __init__(self, self.config = c self.w_q_name = w_q_param_name self.w_s_name = w_s_param_name + if c.zero_points: + assert w_zp_param_name is not None + if c.has_g_idx: + assert w_gidx_param_name is not None self.w_zp_name = w_zp_param_name self.w_gidx_name = w_gidx_param_name diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py index 47591c2aa644e..35b378e0954e0 100644 --- a/vllm/model_executor/layers/quantization/kernels/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -1,10 +1,13 @@ import os from typing import List, Optional, Type +import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.machete import ( MacheteLinearKernel) from vllm.model_executor.layers.quantization.kernels.marlin import ( MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.exllama import ( + ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform @@ -13,6 +16,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, MarlinLinearKernel, + ExllamaLinearKernel, ] @@ -45,8 +49,7 @@ def choose_mp_linear_kernel( failure_reasons = [] for kernel in _POSSIBLE_KERNELS: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py 
b/vllm/model_executor/layers/quantization/kernels/exllama.py new file mode 100644 index 0000000000000..f3c7fcd8a268e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -0,0 +1,146 @@ +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class ExllamaLinearKernel(MPLinearKernel): + SUPPORTED_QUANT_TYPES = [ + scalar_types.uint2b2, + scalar_types.uint3b4, + scalar_types.uint4b8, + scalar_types.uint8b128 + ] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Exllama, "\ + "when the input features are partitioned across "\ + "devices" + + if c.act_type != torch.float16: + return False, "Exllama only supports float16 activations" + + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Exllama, supported types are: "\ + f"{cls.SUPPORTED_QUANT_TYPES}" + + if c.full_weight_shape[0] % c.group_size != 0: + return False, f"Group size ({c.group_size}) does not evenly divide"\ + " the number of input features "\ + f"({c.full_weight_shape[0]})" + + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # For Exllama, we need to set a zero-point tensor if there is not one + if not c.zero_points: + self.w_zp_name = "qzeros" + device = getattr(layer, self.w_q_name).device + groups = c.full_weight_shape[0] // c.group_size + out_features = c.partition_weight_shape[1] + + if c.weight_type.has_bias(): + # if the type has a bias we have to create a zeros tensor that + # contains the bias values repeated for each group (-1 due to + # a bug in the original GPTQ checkpoint format leading to + # exllama kernel adding 1 to the zero points during inference) + # Documentation of the bug can be found here: + # https://garden.danieldk.eu/GPTQ-Checkpoint-Format + zeros = torch.full( + (groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device) + else: + raise NotImplementedError( + "A 0 zero-point is not supported by Exllama due to " + "a bug in the original GPTQ checkpoint format leading to " + "exllama kernel adding 1 to the zero points during " + "inference") + print("zeros", zeros.shape) + zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1) + print("zeros_packed", zeros.shape) + setattr(layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)) + + if c.has_g_idx: + def transform_w_g_idx(x): + # Exllama wants the permutation array instead of the group + # incdices + return torch.argsort(x).to(torch.int) + 
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) + else: + self.w_gidx_name = "g_idx" + empty_g_idx = torch.nn.Parameter( + torch.empty((0, ), + dtype=torch.int, + device=device), + requires_grad=False) + setattr(layer, self.w_gidx_name, empty_g_idx) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + assert self.w_gidx_name is not None + g_idx = getattr(layer, self.w_gidx_name) + + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x_cont = x.data.contiguous() + ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) + return x_cont + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x.to(dtype=c.act_type) + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + + #print(w_q.shape, w_s.shape, w_zp.shape, w_g_idx.shape) + + assert w_zp is not None, "Zero points are not supported by Exllama" + assert w_g_idx is not None, "Group index is required by Exllama" + output = ops.gptq_gemm( + x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits) + + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py index fa39cb511528e..e5696d08f30f5 100644 --- a/vllm/model_executor/layers/quantization/kernels/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -8,7 +8,7 @@ MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, query_machete_supported_quant_types) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_weights_into_int32, unpack_weights_into_int32) + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) @@ -71,13 +71,13 @@ def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) if c.has_g_idx: - x_unpacked = unpack_weights_into_int32(x.data, - c.weight_type, - packed_dim=0) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=0) x_perm = x_unpacked[perm, :] - x.data = pack_weights_into_int32(x_perm, - c.weight_type, - packed_dim=0) + x.data = pack_quantized_values_into_int32(x_perm, + c.weight_type, + packed_dim=0) x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), self.config.weight_type) return x diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 833d00073564e..c217f5ca620a1 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,9 +20,9 @@ } -def pack_weights_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def pack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != 
packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) @@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor, return res.permute(inv_perm) -def unpack_weights_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def unpack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index eb491dd1554a8..373151a5311e5 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -27,6 +27,8 @@ class scalar_types: float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) uint4b8 = ScalarType.uint(4, 8) uint8b128 = ScalarType.uint(8, 128) From d711bac197a7b73a0d85b53598efdeb6d54e0f75 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Oct 2024 00:42:15 +0000 Subject: [PATCH 2/6] format --- vllm/envs.py | 4 +- .../layers/quantization/kernels/__init__.py | 5 +- .../layers/quantization/kernels/exllama.py | 62 +++++++++---------- 3 files changed, 33 insertions(+), 38 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 57bd62c4e9e8b..45a9999610f6a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -431,10 +431,10 @@ def get_default_config_root(): "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1": lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0" ) == "1", - + # List of quantization kernels that should be disabled, used for testing # and performance comparisons. Currently only affects MPLinearKernel - # selection + # selection # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py index 35b378e0954e0..94a3dc2584d6b 100644 --- a/vllm/model_executor/layers/quantization/kernels/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -1,13 +1,12 @@ -import os from typing import List, Optional, Type import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.exllama import ( + ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.machete import ( MacheteLinearKernel) from vllm.model_executor.layers.quantization.kernels.marlin import ( MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.exllama import ( - ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py index f3c7fcd8a268e..3672971d8302c 100644 --- a/vllm/model_executor/layers/quantization/kernels/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -1,28 +1,21 @@ -from functools import partial from typing import Optional, Tuple import torch from vllm import _custom_ops as ops -from vllm.scalar_type import scalar_types -from vllm.model_executor.layers.quantization.utils.machete_utils import ( - MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, - query_machete_supported_quant_types) from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_quantized_values_into_int32, unpack_quantized_values_into_int32) + pack_quantized_values_into_int32) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) +from vllm.scalar_type import scalar_types from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class ExllamaLinearKernel(MPLinearKernel): - SUPPORTED_QUANT_TYPES = [ - scalar_types.uint2b2, - scalar_types.uint3b4, - scalar_types.uint4b8, - scalar_types.uint8b128 - ] + SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] + # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but + # currently untested so not added to the list @classmethod def get_min_capability(cls) -> int: @@ -36,7 +29,7 @@ def can_implement(cls, return False, "Act reordering currently not supported by Exllama, "\ "when the input features are partitioned across "\ "devices" - + if c.act_type != torch.float16: return False, "Exllama only supports float16 activations" @@ -64,19 +57,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module): device = getattr(layer, self.w_q_name).device groups = c.full_weight_shape[0] // c.group_size out_features = c.partition_weight_shape[1] - + if c.weight_type.has_bias(): # if the type has a bias we have to create a zeros tensor that # contains the bias values repeated for each group (-1 due to - # a bug in the original GPTQ checkpoint format leading to + # a bug in the original GPTQ checkpoint format leading to # exllama kernel adding 1 to the zero points during inference) # Documentation of the bug can be found here: # https://garden.danieldk.eu/GPTQ-Checkpoint-Format - zeros = torch.full( - (groups, out_features), - c.weight_type.bias - 1, - dtype=torch.int32, - device=device) + zeros = torch.full((groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device) else: raise NotImplementedError( "A 0 zero-point is not supported by Exllama due to " @@ -84,30 +76,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module): "exllama kernel adding 1 to the zero points during " "inference") print("zeros", zeros.shape) - zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1) + zeros = pack_quantized_values_into_int32(zeros, + c.weight_type, + packed_dim=1) print("zeros_packed", zeros.shape) - setattr(layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)) + setattr(layer, self.w_zp_name, + torch.nn.Parameter(zeros, requires_grad=False)) if c.has_g_idx: + def transform_w_g_idx(x): # Exllama wants the permutation array instead of the group # incdices return torch.argsort(x).to(torch.int) + self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) else: self.w_gidx_name = "g_idx" - empty_g_idx = torch.nn.Parameter( - torch.empty((0, ), - dtype=torch.int, - device=device), - requires_grad=False) + empty_g_idx = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int, + device=device), + requires_grad=False) setattr(layer, self.w_gidx_name, empty_g_idx) def transform_w_q(x): assert isinstance(x, BasevLLMParameter) assert self.w_gidx_name is not None g_idx = getattr(layer, self.w_gidx_name) - + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) x_cont = x.data.contiguous() ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) @@ -128,18 +124,18 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: c = self.config - + x_2d = x.reshape(-1, x.shape[-1]) 
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) - + #print(w_q.shape, w_s.shape, w_zp.shape, w_g_idx.shape) - + assert w_zp is not None, "Zero points are not supported by Exllama" assert w_g_idx is not None, "Group index is required by Exllama" - output = ops.gptq_gemm( - x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits) + output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, + c.weight_type.size_bits) if bias is not None: output.add_(bias) From da86e2c414814727f4713d07857a98b063d8578f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Oct 2024 04:26:11 +0000 Subject: [PATCH 3/6] fix tp > 1 --- vllm/model_executor/layers/quantization/kernels/exllama.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py index 3672971d8302c..a40426c9af646 100644 --- a/vllm/model_executor/layers/quantization/kernels/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -55,7 +55,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module): if not c.zero_points: self.w_zp_name = "qzeros" device = getattr(layer, self.w_q_name).device - groups = c.full_weight_shape[0] // c.group_size + groups = c.partition_weight_shape[0] // c.group_size out_features = c.partition_weight_shape[1] if c.weight_type.has_bias(): @@ -75,16 +75,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module): "a bug in the original GPTQ checkpoint format leading to " "exllama kernel adding 1 to the zero points during " "inference") - print("zeros", zeros.shape) zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1) - print("zeros_packed", zeros.shape) setattr(layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)) if c.has_g_idx: - def transform_w_g_idx(x): # Exllama wants the permutation array instead of the group # incdices From 6fb902d901d7699f780d368629b9f71a19bbb218 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Oct 2024 05:16:30 +0000 Subject: [PATCH 4/6] update --- vllm/model_executor/layers/quantization/kernels/exllama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py index a40426c9af646..87ed4afe0e247 100644 --- a/vllm/model_executor/layers/quantization/kernels/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -82,6 +82,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module): torch.nn.Parameter(zeros, requires_grad=False)) if c.has_g_idx: + def transform_w_g_idx(x): # Exllama wants the permutation array instead of the group # incdices From 0efaf4bb45ccdead988b759758661e2df67c77a2 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Oct 2024 14:11:35 +0000 Subject: [PATCH 5/6] remove assumes comment --- vllm/model_executor/layers/quantization/kernels/exllama.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py index 87ed4afe0e247..4f492a9d1c31b 100644 --- a/vllm/model_executor/layers/quantization/kernels/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -45,9 +45,6 @@ def can_implement(cls, return True, None - # note assumes that - # `weight_packed` is: {input_dim = 0, output_dim = 1, 
packed_dim = 0} - # `weight_scale` is: {input_dim = 0, output_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config From 2814d18c833fe4323e788dd1f5774076ce85acb3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Oct 2024 17:20:40 +0000 Subject: [PATCH 6/6] review comments --- .../layers/quantization/kernels/exllama.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py index 4f492a9d1c31b..1d85d62ec83ee 100644 --- a/vllm/model_executor/layers/quantization/kernels/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -30,6 +30,11 @@ def can_implement(cls, "when the input features are partitioned across "\ "devices" + if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: + return False, "Output features must be a multiple of the pack " \ + "factor (32 / num_bits) so that we can correctly " \ + "pack the zero points" + if c.act_type != torch.float16: return False, "Exllama only supports float16 activations" @@ -82,7 +87,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module): def transform_w_g_idx(x): # Exllama wants the permutation array instead of the group - # incdices + # indices return torch.argsort(x).to(torch.int) self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) @@ -125,9 +130,7 @@ def apply_weights(self, w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) - #print(w_q.shape, w_s.shape, w_zp.shape, w_g_idx.shape) - - assert w_zp is not None, "Zero points are not supported by Exllama" + assert w_zp is not None, "Zero points are required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama" output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits)
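
Note (not part of the patch series): the qzeros construction in ExllamaLinearKernel.process_weights_after_loading above fills the tensor with `weight_type.bias - 1` and then packs it, because the exllama kernel adds 1 to every zero point at inference time (the GPTQ checkpoint-format quirk referenced in the diff). The sketch below reproduces only that shape and bit-budget arithmetic with plain PyTorch for the uint4b8 case; `pack_cols_into_int32` is a hypothetical stand-in, and its bit ordering is illustrative only, since the kernel relies on the layout produced by vLLM's own pack_quantized_values_into_int32.

import torch

def pack_cols_into_int32(vals: torch.Tensor, num_bits: int) -> torch.Tensor:
    # Pack 32 // num_bits unsigned values into each int32 along the last dim.
    # Hypothetical helper: the bit ordering is illustrative, not the layout
    # the exllama kernel expects.
    pack_factor = 32 // num_bits
    assert vals.shape[-1] % pack_factor == 0, \
        "output features must be a multiple of the pack factor (32 / num_bits)"
    vals = vals.to(torch.int32)
    packed = torch.zeros(*vals.shape[:-1], vals.shape[-1] // pack_factor,
                         dtype=torch.int32)
    for i in range(pack_factor):
        packed |= vals[..., i::pack_factor] << (num_bits * i)
    return packed

# uint4b8 stores 4-bit values with a bias of 8, so a symmetric checkpoint has
# an implicit zero point of 8; the patch writes bias - 1 = 7 into qzeros
# because the exllama kernel adds 1 back during inference.
groups, out_features, num_bits, bias = 4, 64, 4, 8
zeros = torch.full((groups, out_features), bias - 1, dtype=torch.int32)
qzeros = pack_cols_into_int32(zeros, num_bits)
print(zeros.shape, qzeros.shape)  # torch.Size([4, 64]) torch.Size([4, 8])

This also illustrates why patch 6 adds the can_implement check that the partitioned output features are a multiple of the pack factor: otherwise the zero points cannot be packed into whole int32 words.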