
2024-05-30 Add FP8 PTQ #1877


Open: wants to merge 8 commits into base `develop`.
54 changes: 52 additions & 2 deletions docs/zh_cn/tutorials/quant/post_training_quantization.md
@@ -15,12 +15,13 @@

### 1. Quantization configuration concepts and interfaces:

`Observer`: collects statistics on an OP's inputs or outputs and computes quantization statistics such as scale and zero_point. Each post-training quantization algorithm corresponds to one Observer. The available Observers include:
`Observer`: collects statistics on an OP's inputs or outputs and computes quantization statistics such as scale and zero_point. Each post-training quantization algorithm corresponds to one Observer. An Observer's quant_bits attribute selects the quantized data type: quant_bits = 8 means INT8 quantization, while quant_bits = (4, 3) means FP8 quantization. The available Observers are listed below (a configuration sketch follows the list):
- `AVGObserver`: collects the average value of the target Tensor as the quantization scale
- `MSEObserver`: collects the maximum absolute value and derives the quantization scale by minimizing the MSE error
- `EMDObserver`: collects the maximum absolute value and derives the quantization scale by minimizing the EMD error
- `HistObserver`: collects tensor values into a histogram and computes the quantization scale from a percentile
- `KLObserver`: computes the quantization scale by minimizing the Kullback-Leibler divergence between the distribution of floating-point values and the distribution of quantized values
- `AbsmaxObserver`: collects the maximum absolute value over the whole target Tensor as the quantization scale
- `AbsMaxChannelWiseWeightObserver`: collects the per-channel maximum absolute value of the target weight as the quantization scale
- `MSEChannelWiseWeightObserver`: collects the per-channel maximum absolute value of the target weight and derives the quantization scale by minimizing the MSE error
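
The choice of observer and of `quant_bits` is independent. A minimal configuration sketch (observer names from the list above; the constructor keyword is the documented `quant_bits` attribute):

```python
from paddleslim.quant.observers import MSEObserver, AbsMaxChannelWiseWeightObserver

# INT8: quant_bits is a plain bit width
int8_act = MSEObserver(quant_bits=8)
int8_weight = AbsMaxChannelWiseWeightObserver(quant_bits=8)

# FP8 (float8_e4m3): quant_bits is an (exponent_bits, mantissa_bits) tuple
fp8_act = MSEObserver(quant_bits=(4, 3))
fp8_weight = AbsMaxChannelWiseWeightObserver(quant_bits=(4, 3))
```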

@@ -44,7 +45,7 @@
| convert | `model`: the quantized model to be converted <br> `inplace`: when inplace=True, the model is quantized in place; when inplace=False, the original model is left unchanged and a quantized model is returned | Converts the model into ONNX form; only after this step can the quantized model be evaluated and exported as a static graph (see the export sketch below)
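
As a minimal sketch of the step after `convert`, assuming a 224x224 image model; the input shape and save path below are illustrative assumptions:

```python
import paddle

# after ptq.convert(...) the model can be exported as a static graph
model.eval()
paddle.jit.save(
    model,
    path='./quant_model',  # illustrative output path
    input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')],
)
```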


## Usage example
## INT8 quantization usage example
```python
import paddle
import paddleslim
# ...
for step, data in enumerate(dataloader):
    # ...

# convert to quant model that can evaluate and export
model = ptq.convert(model, inplace=True)
```


## FP8 quantization usage example
```python
import paddle
import paddleslim
from paddle.vision.models import mobilenet_v1
from paddle.quantization import QuantConfig
from paddle.quantization import PTQ
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver, AbsmaxObserver, MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver
# import path assumed for the quantized parallel-linear layers
from paddleslim.quant.layers import QuantizedColumnParallelLinear, QuantizedRowParallelLinear

# create the model
model = mobilenet_v1()

# define QuantConfig
q_config = QuantConfig(activation=None, weight=None)

# define act_quanter and weight_quanter; quant_bits=(4, 3) selects float8_e4m3
act_quanter = AbsmaxObserver(quant_bits=(4, 3))
weight_quanter = AbsMaxChannelWiseWeightObserver(quant_bits=(4, 3))

# map ColumnParallelLinear to QuantizedColumnParallelLinear
# (these mappings only matter for tensor-parallel models)
q_config.add_qat_layer_mapping(ColumnParallelLinear,
                               QuantizedColumnParallelLinear)
# map RowParallelLinear to QuantizedRowParallelLinear
q_config.add_qat_layer_mapping(RowParallelLinear,
                               QuantizedRowParallelLinear)
# make every layer whose type is in [paddle.nn.Linear, ColumnParallelLinear,
# RowParallelLinear] quantizable
q_config.add_type_config(
    [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear],
    activation=act_quanter,
    weight=weight_quanter,
)


ptq = PTQ(q_config)
model = ptq.quantize(model, inplace=True)

# ptq sample: run calibration batches through the model
# (dataloader is a user-provided calibration DataLoader)
ptq_step = 100
for step, data in enumerate(dataloader):
    pred = model(data)
    if step == ptq_step:
        break

# convert to quant model that can evaluate and export
model = ptq.convert(model, inplace=True)
```
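
After `convert`, the FP8-calibrated model can be sanity-checked like any other model. A brief sketch, assuming the same user-provided `dataloader`:

```python
import paddle

# run a few batches through the converted model and inspect the outputs
model.eval()
with paddle.no_grad():
    for step, data in enumerate(dataloader):
        logits = model(data)
        print(step, logits.shape)
        if step == 2:
            break
```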
7 changes: 5 additions & 2 deletions paddleslim/quant/observers/base_hist.py
@@ -68,7 +68,10 @@ def forward(self, inputs):
        self._zero_point = None
        self._min = None
        self._max = None
        # Cast inputs to 'float32' for numpy compatibility in the _init_hists
        # function, avoiding issues with types like bf16.
        dtype = inputs.dtype
        inputs = inputs.cast('float32')
        if self._hist_min is None or self._hist_max is None:
            self._hist_min, self._hist_max = self._min_max(inputs)
            self._hist = self._init_hists(inputs)
@@ -82,7 +85,7 @@ def forward(self, inputs):
                self._upsample_bin_count, )
            self._hist_min, self._hist_max = new_min, new_max
            self._hist = new_hist
        return inputs
        return inputs.cast(dtype)

    def _update_min_max_and_hist(self, tensor, origin_min, origin_max,
                                 origin_hist, bins_count, upsample_bins_count):
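The motivation for the cast above can be reproduced in isolation: numpy, which backs the histogram bookkeeping, has no bfloat16 dtype. A toy illustration (not library code):

```python
import numpy as np
import paddle

x = paddle.rand([1024]).cast('bfloat16')   # e.g. activations observed in bf16
x32 = x.cast('float32')                     # numpy has no bfloat16 dtype
hist, edges = np.histogram(x32.numpy(), bins=256)
print(hist.sum())  # 1024
```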
1 change: 0 additions & 1 deletion paddleslim/quant/observers/emd.py
@@ -67,7 +67,6 @@ def cal_min_max(self, inputs):
        while s <= 1.0:
            scale = s * abs_max_value
            s += 0.02
            bins = 2**(self._quant_bits - 1) - 1
            quant_var = paddle.clip(
                paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
                self.qmax)
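The loop above sweeps candidate scales `s * abs_max_value` in steps of 0.02 and keeps the best one. An illustrative stand-in (not the library routine, and using MSE in place of the EMD-style metric):

```python
import paddle

def search_scale(x, qmax=127.0):
    """Sweep candidate scales and keep the one with the lowest error."""
    abs_max_value = float(paddle.max(paddle.abs(x)))
    best_scale, best_loss = abs_max_value, float('inf')
    s = 0.3
    while s <= 1.0:
        scale = s * abs_max_value
        s += 0.02
        quant_var = paddle.clip(
            paddle.round(x / scale * qmax), -qmax - 1, qmax)
        dequant_var = quant_var * scale / qmax
        loss = float(((x - dequant_var) ** 2).mean())
        if loss < best_loss:
            best_scale, best_loss = scale, loss
    return best_scale

print(search_scale(paddle.randn([1024])))
```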
2 changes: 2 additions & 0 deletions paddleslim/quant/observers/kl.py
@@ -127,6 +127,8 @@ def cal_kl_threshold(hist, bin_width, bits):
    assert hist.ndim == 1
    hist_bins = hist.shape[0]
    starting_iter = int((hist_bins - 1) * 0.5)
    if isinstance(bits, tuple):
        # a float8 (exponent_bits, mantissa_bits) tuple is folded into its
        # total value-bit count, e.g. (4, 3) -> 7
        bits = bits[0] + bits[1]
    quant_range = 2**(bits - 1) - 1

    P_sum = np.sum(np.array(hist).ravel())
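The effect of the folding on the KL search range, as a small self-contained sketch:

```python
def quant_range_for(bits):
    # mirrors cal_kl_threshold: tuples are folded into a total bit count
    if isinstance(bits, tuple):
        bits = bits[0] + bits[1]  # e.g. (4, 3) -> 7
    return 2**(bits - 1) - 1

print(quant_range_for(8))       # 127 (INT8)
print(quant_range_for((4, 3)))  # 63  (float8_e4m3)
```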
26 changes: 19 additions & 7 deletions paddleslim/quant/observers/uniform.py
@@ -26,7 +26,7 @@ class UniformObserver(BaseObserver):
an integer value ensuring that zero is quantized without error.

Args:
quant_bits (int): The number of bits for quantization.
quant_bits (int or tuple): The number of bits for quantization; an (exponent_bits, mantissa_bits) tuple such as (4, 3) selects a float8 format.
sign (bool): Whether the quantized integer includes a sign.
symmetric (bool): Whether it is symmetric quantization.
In symmetric quantization, the range of floating point values is relaxed to be symmetric
@@ -56,12 +56,24 @@ def __init__(
    def qmin_qmax(self):
        """ Calculate the range of the quantized integer based on the specified
        quant_bits, sign, and symmetric properties."""
        if self._sign:
            self._qmin = -2**(self.bit_length() - 1)
            self._qmax = 2**(self.bit_length() - 1) - 1
        else:
            self._qmin = 0
            self._qmax = 2**self.bit_length()
        if isinstance(self._quant_bits, tuple):
            if self._quant_bits == (4, 3):
                # float8_e4m3: largest finite value is 448
                self._qmin = -448.0
                self._qmax = 448.0
            elif self._quant_bits == (5, 2):
                # float8_e5m2: largest finite value is 57344
                self._qmin = -57344.0
                self._qmax = 57344.0
            else:
                raise NotImplementedError(
                    "Currently, only float8_e4m3 and float8_e5m2 formats are "
                    "supported. Please set quant_bits to (4, 3) or (5, 2) for "
                    "the corresponding format.")
        else:
            if self._sign:
                self._qmin = -2**(self.bit_length() - 1)
                self._qmax = 2**(self.bit_length() - 1) - 1
            else:
                self._qmin = 0
                self._qmax = 2**self.bit_length()
        return self._qmin, self._qmax

@abc.abstractmethod
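The ranges produced by `qmin_qmax` for each supported setting, as a standalone sketch (not the library API):

```python
def qmin_qmax(quant_bits, sign=True):
    if isinstance(quant_bits, tuple):
        if quant_bits == (4, 3):
            return -448.0, 448.0          # float8_e4m3
        if quant_bits == (5, 2):
            return -57344.0, 57344.0      # float8_e5m2
        raise NotImplementedError("only (4, 3) and (5, 2) are supported")
    if sign:
        return -2**(quant_bits - 1), 2**(quant_bits - 1) - 1
    return 0, 2**quant_bits

print(qmin_qmax(8))       # (-128, 127)
print(qmin_qmax((4, 3)))  # (-448.0, 448.0)
```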