diff --git a/docs/zh_cn/tutorials/quant/post_training_quantization.md b/docs/zh_cn/tutorials/quant/post_training_quantization.md index 09391de13..e9e92576e 100644 --- a/docs/zh_cn/tutorials/quant/post_training_quantization.md +++ b/docs/zh_cn/tutorials/quant/post_training_quantization.md @@ -15,12 +15,13 @@ ### 1. 量化配置相关概念以及接口: -`Observer`:用于统计OP输入或输出,并计算出量化相关的统计量,比如scale、zero_point等。每个离线量化算法对应一个Observer,现已有的Observer包含: +`Observer`:用于统计OP输入或输出,并计算出量化相关的统计量,比如scale、zero_point等。每个离线量化算法对应一个Observer,Observer可以使用属性quant_bits调整量化的数据类型,quant_bits = 8代表INT8量化,quant_bits = (4,3)代表FP8量化,现已有的Observer包含: - `AVGObserver`:收集目标Tensor的平均值作为量化scale - `MSEObserver`:收集最大绝对值并通过最小化MSE误差,收集量化scale - `EMDObserver`:收集最大绝对值并通过最小化EMD误差,收集量化scale - `HistObserver`:将张量值收集到直方图中,并根据百分比计算量化scale - `KLObserver`:以最小化浮点值分布与量化浮点值分布之间的 Kullback-Leibler散度计算量化scale +- `AbsmaxObserver`:根据目标权重的Tensor维度,收集最大绝对值作为量化scale - `AbsMaxChannelWiseWeightObserver`:根据目标权重的通道维度,收集最大绝对值作为量化scale - `MSEChannelWiseWeightObserver`:根据目标权重的通道维度,收集最大绝对值并通过最小化MSE误差,收集量化scale @@ -44,7 +45,7 @@ | convert | `model`:需要被转化的量化模型
`inplace`:inplace=True时,该模型会被inplace的量化;inplace=False时,不改变原模型,并且会return一个量化的模型 | 将模型转化成onnx形式,进行此步骤之后才能对量化模型进行验证、导出成静态图等 -## 使用示例 +## INT8量化使用示例 ```python import paddle import paddleslim @@ -91,3 +92,52 @@ for step, data in enumerate(dataloader): # convert to quant model that can evaluate and export model = ptq.convert(model, inplace=True) ``` + + +## FP8量化使用示例 +```python +import paddle +import paddleslim +from paddle.vision.models import mobilenet_v1 +from paddle.quantization import QuantConfig +from paddle.quantization import PTQ +from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver, MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver, AbsmaxObserver + +# create the model +model = mobilenet_v1() + +# define QuantConfig +q_config = QuantConfig(activation=None, weight=None) + +# define act_quanter and weight_quanter +act_quanter = AbsmaxObserver(quant_bits=(4,3)) +weight_quanter = AbsMaxChannelWiseWeightObserver(quant_bits=(4,3)) + +# map ColumnParallelLinear to QuantizedColumnParallelLinear +q_config.add_qat_layer_mapping(ColumnParallelLinear, + QuantizedColumnParallelLinear) +# map RowParallelLinear to QuantizedRowParallelLinear +q_config.add_qat_layer_mapping(RowParallelLinear, + QuantizedRowParallelLinear) +# for each layer if type in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear] +# make them quantizable +q_config.add_type_config( + [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear], + activation=act_quanter, + weight=weight_quanter, + ) + + +ptq = PTQ(q_config) +model = ptq.quantize(model, inplace=True) + +# ptq sample +ptq_step = 100 +for step, data in enumerate(dataloader): + pred = model(data) + if step == ptq_step: + break + +# convert to quant model that can evaluate and export +model = ptq.convert(model, inplace=True) +``` \ No newline at end of file diff --git a/paddleslim/quant/observers/base_hist.py b/paddleslim/quant/observers/base_hist.py index 1a4755071..36b9e196a 100644 --- 
a/paddleslim/quant/observers/base_hist.py +++ b/paddleslim/quant/observers/base_hist.py @@ -68,7 +68,10 @@ def forward(self, inputs): self._zero_point = None self._min = None self._max = None - + # Cast inputs to 'float32' for numpy compatibility in the _init_hists function, + # avoiding issues with types like bf16. + dtype = inputs.dtype + inputs = inputs.cast('float32') if self._hist_min is None or self._hist_max is None: self._hist_min, self._hist_max = self._min_max(inputs) self._hist = self._init_hists(inputs) @@ -82,7 +85,7 @@ def forward(self, inputs): self._upsample_bin_count, ) self._hist_min, self._hist_max = new_min, new_max self._hist = new_hist - return inputs + return inputs.cast(dtype) def _update_min_max_and_hist(self, tensor, origin_min, origin_max, origin_hist, bins_count, upsample_bins_count): diff --git a/paddleslim/quant/observers/emd.py b/paddleslim/quant/observers/emd.py index 8dea968e7..66cf06072 100644 --- a/paddleslim/quant/observers/emd.py +++ b/paddleslim/quant/observers/emd.py @@ -67,7 +67,6 @@ def cal_min_max(self, inputs): while s <= 1.0: scale = s * abs_max_value s += 0.02 - bins = 2**(self._quant_bits - 1) - 1 quant_var = paddle.clip( paddle.round(inputs / scale * self.qmax), -self.qmax - 1, self.qmax) diff --git a/paddleslim/quant/observers/kl.py b/paddleslim/quant/observers/kl.py index b9653ff21..35ac01200 100644 --- a/paddleslim/quant/observers/kl.py +++ b/paddleslim/quant/observers/kl.py @@ -127,6 +127,8 @@ def cal_kl_threshold(hist, bin_width, bits): assert hist.ndim == 1 hist_bins = hist.shape[0] starting_iter = int((hist_bins - 1) * 0.5) + if isinstance(bits, tuple): + bits = bits[0] + bits[1] quant_range = 2**(bits - 1) - 1 P_sum = np.sum(np.array(hist).ravel()) diff --git a/paddleslim/quant/observers/uniform.py b/paddleslim/quant/observers/uniform.py index d874fa687..bc3a5ef31 100644 --- a/paddleslim/quant/observers/uniform.py +++ b/paddleslim/quant/observers/uniform.py @@ -26,7 +26,7 @@ class UniformObserver(BaseObserver): 
an integer value ensuring that zero is quantized without error. Args: - quant_bits (int): The number of bits for quantization. + quant_bits (int or tuple): The number of bits for quantization. sign (bool): Whether the quantized integer includes a sign. symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric. In symmetric quantization, the range of floating point values is relaxed to be symmetric @@ -56,12 +56,24 @@ def __init__( def qmin_qmax(self): """ Calculate the range of the quantized integer based on the specified quant_bits, sign, and symmetric properties.""" - if self._sign: - self._qmin = -2**(self.bit_length() - 1) - self._qmax = 2**(self.bit_length() - 1) - 1 - else: - self._qmin = 0 - self._qmax = 2**self.bit_length() + if isinstance(self._quant_bits, tuple): + if (self._quant_bits[0]==4 and self._quant_bits[1]==3 and len(self._quant_bits)==2): + self._qmin = -448.0 + self._qmax = 448.0 + elif (self._quant_bits[0]==5 and self._quant_bits[1]==2 and len(self._quant_bits)==2): + self._qmin = -57344.0 + self._qmax = 57344.0 + else: + raise NotImplementedError( + "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format." + ) + else: + if self._sign: + self._qmin = -2**(self.bit_length() - 1) + self._qmax = 2**(self.bit_length() - 1) - 1 + else: + self._qmin = 0 + self._qmax = 2**self.bit_length() return self._qmin, self._qmax @abc.abstractmethod