from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


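# Mixed-precision linear kernel that runs GPTQ-style quantized weights
# through the exllama CUDA kernels exposed via vllm._custom_ops
# (gptq_shuffle / gptq_gemm).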
class ExllamaLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
    # In theory `scalar_types.uint2b2` and `scalar_types.uint3b4` are also
    # supported, but they are currently untested, so they are not listed.

    @classmethod
    def get_min_capability(cls) -> int:
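        # vLLM encodes compute capability as major * 10 + minor,
        # so 60 corresponds to 6.0 (Pascal).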
        return 60

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx and \
            c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by Exllama "\
                          "when the input features are partitioned across "\
                          "devices"

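        # e.g. a 4-bit weight type packs 32 // 4 = 8 values per int32, so the
        # partitioned output dim must be divisible by 8.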
        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
            return False, "Output features must be a multiple of the pack " \
                          "factor (32 / num_bits) so that we can correctly " \
                          "pack the zero points"

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Exllama, supported types are: "\
                          f"{cls.SUPPORTED_QUANT_TYPES}"

        if c.full_weight_shape[0] % c.group_size != 0:
            return False, f"Group size ({c.group_size}) does not evenly divide"\
                          " the number of input features "\
                          f"({c.full_weight_shape[0]})"

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        device = getattr(layer, self.w_q_name).device

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # If the type has a bias we have to create a zeros tensor that
                # contains the bias value repeated for each group. The -1 is
                # needed because of a bug in the original GPTQ checkpoint
                # format that leads to the exllama kernel adding 1 to the zero
                # points during inference. Documentation of the bug:
                # https://garden.danieldk.eu/GPTQ-Checkpoint-Format
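                # e.g. for uint4b8 the bias is 8, so the stored zero point
                # is 7 and the kernel's +1 restores the true value of 8.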
                zeros = torch.full((groups, out_features),
                                   c.weight_type.bias - 1,
                                   dtype=torch.int32,
                                   device=device)
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "exllama kernel adding 1 to the zero points during "
                    "inference")
            zeros = pack_quantized_values_into_int32(zeros,
                                                     c.weight_type,
                                                     packed_dim=1)
            setattr(layer, self.w_zp_name,
                    torch.nn.Parameter(zeros, requires_grad=False))

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
        else:
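            # The exllama kernel still expects a g_idx tensor when activation
            # reordering is disabled, so register an empty placeholder.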
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
                                                         dtype=torch.int,
                                                         device=device),
                                             requires_grad=False)
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
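            # gptq_shuffle permutes the packed weights in place according to
            # g_idx, into the layout the exllama kernel expects.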
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Shuffle the quantized weights and convert the scales for Exllama
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config

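        # Flatten any leading batch/sequence dims; the GEMM runs on a 2D
        # [tokens, in_features] view and the output is reshaped back below.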
        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
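        # The boolean positional argument is `use_exllama`; True selects the
        # exllama GEMM path inside gptq_gemm.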
        output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
                               c.weight_type.size_bits)

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)