
[fix] Fix the auto-compression examples for NLP models #1839


Open · wants to merge 11 commits into base: develop
67 changes: 49 additions & 18 deletions example/auto_compression/nlp/README.md
@@ -56,16 +56,16 @@

#### 3.1 Prepare the Environment
- python >= 3.6
- PaddlePaddle >= 2.4 (can be downloaded and installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim >= 2.4
- PaddleNLP >= 2.3
- PaddlePaddle ==2.5 (can be downloaded and installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim ==2.5
- PaddleNLP ==2.6

Install paddlepaddle:
```shell
# CPU
pip install paddlepaddle==2.4.1
pip install paddlepaddle==2.5.0
# GPU, using Ubuntu and CUDA 11.2 as an example
python -m pip install paddlepaddle-gpu==2.4.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
python -m pip install paddlepaddle-gpu==2.5.0.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```

Install paddleslim:
@@ -95,7 +95,6 @@ pip install paddlenlp

| Model | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | CLUEWSC2020 | CSL |
|:------:|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|
| PP-MiniLM | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/afqmc.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/tnews.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/iflytek.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/cmnli.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/ocnli.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/cluewsc.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/csl.tar) |
| ERNIE 3.0-Medium | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/AFQMC.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/TNEWS.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/IFLYTEK.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CMNLI.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/OCNLI.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CLUEWSC2020.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CSL.tar) |
| UIE-base | [reimbursement work orders](https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base.tar) |

Take the model links from the table above and download the inference model files with the following command:

@@ -119,11 +118,6 @@
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path='./configs/pp-minilm/auto/afqmc.yaml' --save_dir='./save_afqmc_pruned/'
```

Auto-compression of the UIE family of models is launched with the run_uie.py script, which uses the ```paddleslim.auto_compression.AutoCompression``` interface to compress the model automatically. Configure the training parameters in the config file, passing in the task name, model type, dataset name, and compression parameters; once the configuration is complete, the model can be trained with distillation and quantization (a minimal sketch of this API call is shown after the command below).
```shell
export CUDA_VISIBLE_DEVICES=0
python run_uie.py --config_path='./configs/uie/uie_base.yaml' --save_dir='./save_uie_qat/'
```
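
As a minimal sketch (not the actual run_uie.py, whose argument names may differ across PaddleSlim versions), the call into ```paddleslim.auto_compression.AutoCompression``` typically looks like the following. The dataloader builder and some keyword names are illustrative assumptions:

```python
# Illustrative sketch of driving AutoCompression from a YAML config.
# The dataloader builder and some keyword names are assumptions, not copied from run_uie.py.
from paddleslim.common import load_config          # config helper used by the ACT example scripts
from paddleslim.auto_compression import AutoCompression

def build_train_dataloader():
    # Placeholder: in the real scripts this wraps a PaddleNLP dataset (CLUE or the UIE data).
    raise NotImplementedError

all_config = load_config('./configs/uie/uie_base.yaml')
ac = AutoCompression(
    model_dir=all_config['Global']['model_dir'],    # directory holding the FP32 inference model
    model_filename='inference.pdmodel',
    params_filename='inference.pdiparams',
    save_dir='./save_uie_qat/',                     # the compressed model is written here
    config=all_config,                              # distillation / quantization strategies from the YAML
    train_dataloader=build_train_dataloader())
ac.compress()                                       # runs the distillation-quantization training
```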

To only evaluate the accuracy of the model, or of the compressed model, change the model directory ```model_dir``` in the config file to the folder where the compressed model was saved, e.g. ```./save_afqmc_pruned```, and add ```--eval True``` to the command when launching the ```run.py``` script:
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path='./configs/pp-minilm/auto/afqmc.yaml' --eval True
```

@@ -212,12 +206,29 @@ QuantPost:

## 5. Inference Deployment


Quantized models can be accelerated with TensorRT on GPU and with MKLDNN on CPU.

The following fields are used to configure the inference parameters (a rough sketch of how they map onto the Paddle Inference config follows the table):

| Parameter | Description |
|:------:|:------:|
| model_path | Directory containing the inference model; it must hold the files model.pdmodel and model.pdiparams |
| model_filename | Name of the model file; defaults to inference.pdmodel |
| params_filename | Name of the parameters file; defaults to inference.pdiparams |
| task_name | Name of the task to run; defaults to afqmc |
| dataset | Dataset used by the model; defaults to clue |
| device | Device used for inference; defaults to gpu; either cpu or gpu |
| batch_size | Batch size used at inference time; defaults to 32 |
| max_seq_len | Maximum length of the input sequence after tokenization; defaults to 128. Longer sequences are truncated; shorter ones are padded |
| perf_warmup_steps | Number of warm-up steps before performance measurement; defaults to 20 |
| use_trt | Flag controlling whether to run inference with TensorRT |
| precision | Inference precision; defaults to fp32; fp16 and int8 are also supported |
| use_mkldnn | Flag controlling whether to run inference with MKLDNN |
| cpu_threads | Number of CPU threads; defaults to 1 |
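
As a rough illustration only (see ```paddle_inference_eval.py``` below for the authoritative code), these flags typically map onto the Paddle Inference config as follows. Values such as the TensorRT workspace size and ```min_subgraph_size``` are arbitrary choices for this sketch:

```python
# Sketch: mapping the parameters above onto paddle.inference.Config.
# workspace_size / min_subgraph_size values are illustrative, not taken from the script.
import os
import paddle.inference as paddle_infer

def build_config(args):
    config = paddle_infer.Config(
        os.path.join(args.model_path, args.model_filename),
        os.path.join(args.model_path, args.params_filename))
    if args.device == "gpu":
        config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool, device id 0
        if args.use_trt:
            precision = {
                "fp32": paddle_infer.PrecisionType.Float32,
                "fp16": paddle_infer.PrecisionType.Half,
                "int8": paddle_infer.PrecisionType.Int8,
            }[args.precision]
            config.enable_tensorrt_engine(
                workspace_size=1 << 30,
                max_batch_size=args.batch_size,
                min_subgraph_size=5,
                precision_mode=precision,
                use_static=False,
                use_calib_mode=False)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(args.cpu_threads)
        if args.use_mkldnn:
            config.enable_mkldnn()
    return config

# predictor = paddle_infer.create_predictor(build_config(args))
```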

- TensorRT inference:

Environment setup: to use the TensorRT inference engine, install a Paddle build compiled with ```WITH_TRT=ON```; it can be downloaded from the [Python inference library](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python).
#### 5.1 TensorRT inference

First, download the quantized model:
@@ -227,10 +238,30 @@
```shell
tar -xf save_ppminilm_afqmc_new_calib.tar
```

```shell
python paddle_inference_eval.py \
--model_path=save_ernie3_afqmc_new_cablib \
--model_path=save_ppminilm_afqmc_new_calib \
--model_filename=inference.pdmodel \
--params_filename=inference.pdiparams \
--task_name='afqmc' \
--use_trt \
--precision=int8
```

- ERNIE 3.0-Medium:
```shell
python paddle_inference_eval.py \
--model_path=TNEWS \
--model_filename=infer.pdmodel \
--params_filename=infer.pdiparams \
--task_name='afqmc' \
--task_name='tnews' \
--use_trt \
--precision=fp32
```
```shell
python paddle_inference_eval.py \
--model_path=save_tnews_pruned \
--model_filename=infer.pdmodel \
--params_filename=infer.pdiparams \
--task_name='tnews' \
--use_trt \
--precision=int8
```
- MKLDNN inference:

```shell
python paddle_inference_eval.py \
--model_path=save_ernie3_afqmc_new_cablib \
--model_filename=infer.pdmodel \
--params_filename=infer.pdiparams \
--model_path=save_ppminilm_afqmc_new_calib \
--model_filename=inference.pdmodel \
--params_filename=inference.pdiparams \
--task_name='afqmc' \
--device=cpu \
--use_mkldnn=True \
23 changes: 14 additions & 9 deletions example/auto_compression/nlp/configs/ernie3.0/tnews.yaml
@@ -6,12 +6,17 @@ Global:
dataset: clue
batch_size: 16
max_seq_length: 128
TrainConfig:
epochs: 6
eval_iter: 1110
learning_rate: 2.0e-5
optimizer_builder:
optimizer:
type: AdamW
weight_decay: 0.01
origin_metric: 0.5700

# Pruning
Prune:
prune_algo: transformer_pruner
pruned_ratio: 0.25

# Post-training (offline) quantization
QuantPost:
activation_bits: 8
quantize_op_types:
- depthwise_conv2d
- conv2d
weight_bits: 8

20 changes: 7 additions & 13 deletions example/auto_compression/nlp/configs/pp-minilm/auto/afqmc.yaml
@@ -6,17 +6,11 @@ Global:
dataset: clue
batch_size: 16
max_seq_length: 128
TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:

# Post-training (offline) quantization
QuantPost:
TrainConfig:
epochs: 6
eval_iter: 1070
learning_rate: 2.0e-5
optimizer_builder:
optimizer:
type: AdamW
weight_decay: 0.01
origin_metric: 0.7403
activation_bits: 8
quantize_op_types:
- conv2d
- depthwise_conv2d
weight_bits: 8
21 changes: 14 additions & 7 deletions example/auto_compression/nlp/paddle_inference_eval.py
@@ -91,7 +91,8 @@ def parse_args():
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
help=
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--perf_warmup_steps",
@@ -107,7 +108,8 @@
type=str,
default="fp32",
choices=["fp32", "fp16", "int8"],
help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
help=
"The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
)
parser.add_argument(
"--use_mkldnn",
@@ -156,8 +158,7 @@ def _convert_example(example,
}
elif "target" in example: # wsc
text, query, pronoun, query_idx, pronoun_idx = (
example["text"],
example["target"]["span1_text"],
example["text"], example["target"]["span1_text"],
example["target"]["span2_text"],
example["target"]["span1_index"],
example["target"]["span2_index"], )
@@ -209,6 +210,12 @@ def create_predictor(cls, args):
config = paddle.inference.Config(
os.path.join(args.model_path, args.model_filename),
os.path.join(args.model_path, args.params_filename))
# config.switch_ir_debug(True)
# Applies to the ERNIE 3.0-Medium model
# config.exp_disable_tensorrt_ops(["elementwise_add"])
# config.exp_disable_tensorrt_ops(["fused_embedding_eltwise_layernorm"])
# config.exp_disable_tensorrt_ops(["tmp_3"])

if args.device == "gpu":
# set GPU configs accordingly
config.enable_use_gpu(100, 0)
@@ -239,8 +246,8 @@ def create_predictor(cls, args):
dynamic_shape_file = os.path.join(args.model_path,
"dynamic_shape.txt")
if os.path.exists(dynamic_shape_file):
config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
True)
config.enable_tuned_tensorrt_dynamic_shape(
dynamic_shape_file, True)
print("trt set dynamic shape done!")
else:
config.collect_shape_range_info(dynamic_shape_file)
@@ -365,4 +372,4 @@ def main():

if __name__ == "__main__":
paddle.set_device("cpu")
main()
main()
58 changes: 36 additions & 22 deletions paddleslim/quant/advanced/auto_clip.py
@@ -21,14 +21,15 @@
from .metrics import mse_loss
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
RowParallelLinear, )
__all__ = ['AutoClip']


class AutoClip(nn.Layer):
"""
AutoClip from AWQ[https://arxiv.org/abs/2306.00978]
"""

def __init__(
self,
model,
Expand All @@ -39,8 +40,7 @@ def __init__(
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
group_size=128,
):
group_size=128, ):
super(AutoClip, self).__init__()
self.model = model
self.weight_bits = weight_bits
Expand All @@ -59,15 +59,17 @@ def __init__(
def _apply_hook(self):
self._forward_hook_list = []
for _, sub_layer in self.model.named_sublayers():
if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
if type(sub_layer) in [
ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear
]:
forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)

def _forward_pre_hook(self, layer, input):
self._sample_scale(input, layer.full_name())
return input

def _sample_scale(self, input, name):
input = input[0] if type(input) == tuple else input
input.stop_gradient = True
Expand All @@ -80,7 +82,6 @@ def _sample_scale(self, input, name):
else:
self.sampled_inputs[name] = input


def auto_clip(self, group_size=128, oc_batch_size=256):
"""
search clip scale for each layer and update the layer's weight
@@ -89,7 +90,7 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
name = sub_layer.full_name()
if name not in self.sampled_inputs or 'out_linear' in sub_name:
continue

weight = sub_layer.weight.cast('float16')
weight_t = paddle.transpose(weight, perm=[1, 0])
x = self.sampled_inputs[name].cast('float16')
@@ -98,33 +99,41 @@
x = x.reshape([1, x.shape[0], -1, group_size])
x = x[:, 0::x.shape[1] // self.n_sample_token]
weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size])
oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128 # prevent OOM
oc_batch_size = oc_batch_size if weight_t.shape[
0] % oc_batch_size == 0 else 128 # prevent OOM
assert weight_t.shape[0] % oc_batch_size == 0

w_all = weight_t
best_max_val_all = []

for i_b in range(weight_t.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]

org_max_val = w.abs().max(axis=-1, keepdim=True) # co, 1, n_group, 1
org_max_val = w.abs().max(
axis=-1, keepdim=True) # co, 1, n_group, 1
best_max_val = org_max_val.clone()
min_errs = paddle.ones_like(org_max_val, dtype='float16') * 1e9
org_out = (x * w).sum(axis=-1) # co, n_token, n_group
for i_s in range(int(self.max_shrink * self.n_grid)):
max_val = org_max_val * (1 - i_s / self.n_grid)
max_val_tmp = max_val
cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w)
cur_w = paddle.where(cur_w < - max_val_tmp, - max_val_tmp, cur_w)
cur_w = paddle.where(cur_w < -max_val_tmp, -max_val_tmp,
cur_w)
org_w_shape = cur_w.shape
cur_w_r = cur_w.reshape([-1, self.group_size]).transpose([1, 0])
quant_dequant_weight = fake_quant(cur_w_r, method='abs_max_channel_wise', weight_bits=4)
quant_dequant_weight = quant_dequant_weight.transpose([1, 0]).reshape(org_w_shape)
cur_w_r = cur_w.reshape([-1,
self.group_size]).transpose([1, 0])
quant_dequant_weight = fake_quant(
cur_w_r, method='abs_max_channel_wise', weight_bits=4)
quant_dequant_weight = quant_dequant_weight.transpose(
[1, 0]).reshape(org_w_shape)
cur_out = (x * quant_dequant_weight).sum(axis=-1)
# co, 1, n_group, 1
tmp = (cur_out - org_out).detach().clone()
err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape)
print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item()))
err = paddle.pow(tmp,
2).mean(axis=1).reshape(min_errs.shape)
print('block {} search s {} err {}'.format(
i_b, i_s, err.mean().item()))
del cur_w, cur_out, quant_dequant_weight, tmp, cur_w_r
paddle.device.cuda.empty_cache()

@@ -143,16 +152,21 @@ def auto_clip(self, group_size=128, oc_batch_size=256):
if 'w_0' in param.name:
param_tmp = param.transpose(perm=[1, 0]).cast('float16')
tmp_shape = param_tmp.shape
param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1])
best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp)
param_tmp = paddle.where(param_tmp < - best_max_val, - best_max_val, param_tmp)
param_tmp = param_tmp.reshape(
[best_max_val.shape[0], best_max_val.shape[1], -1])
best_max_val = paddle.tile(
best_max_val, repeat_times=(1, 1, param_tmp.shape[-1]))
param_tmp = paddle.where(param_tmp > best_max_val,
best_max_val, param_tmp)
param_tmp = paddle.where(param_tmp < -best_max_val,
-best_max_val, param_tmp)
param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype)
param_tmp = param_tmp.transpose(perm=[1, 0])
paddle.assign(param_tmp, output=param)
del param_tmp
paddle.device.cuda.empty_cache()
break

del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all
del best_max_val, weight_t, x, weight, self.sampled_inputs[
name], w_all, best_max_val_all
paddle.device.cuda.empty_cache()
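
For reference, a hypothetical usage sketch of this class follows. It assumes the input-caching hooks are registered when the object is constructed (via ```_apply_hook```) and that a few calibration batches are forwarded through the model before ```auto_clip()``` is called; the model, dataloader, and import path are placeholders inferred from the module layout:

```python
# Hypothetical usage of AutoClip (the AWQ-style weight-clipping search above).
# `model`, `calib_loader`, and the exact import path are assumptions for illustration.
import paddle
from paddleslim.quant.advanced.auto_clip import AutoClip

model = ...         # a paddle.nn.Layer with Linear / ColumnParallelLinear / RowParallelLinear sublayers
calib_loader = ...  # a small calibration dataloader

clipper = AutoClip(
    model,
    weight_bits=4,        # 4-bit weight quantization, matching the fake_quant call above
    n_grid=20,            # number of shrink ratios tried per weight group
    max_shrink=0.5,       # search clipping thresholds down to 50% of the original max
    n_sample_token=512,   # tokens sampled from the cached activations
    group_size=128)       # per-group quantization granularity

# Forward a few calibration batches so the pre-forward hooks can cache layer inputs.
model.eval()
with paddle.no_grad():
    for batch in calib_loader:
        model(**batch)

# Search the best clipping threshold per output-channel block and overwrite the weights in place.
clipper.auto_clip(group_size=128, oc_batch_size=256)
```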
2 changes: 1 addition & 1 deletion paddleslim/quant/advanced/gptq.py
@@ -189,4 +189,4 @@ def fasterquant(self,

self.quantized = True
del H, Q, Hinv, W, Losses
paddle.device.cuda.empty_cache()
paddle.device.cuda.empty_cache()
2 changes: 0 additions & 2 deletions paddleslim/quant/advanced/piecewise_search.py
@@ -71,7 +71,6 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
origin_out = paddle.matmul(act, weight)
w_abs_max = weight.abs().max(axis=-1, keepdim=True)
rw_abs_max = w_abs_max.reshape(act_abs_max.shape)

smooth_scale_out = None
global_loss = float('inf')
best_scale = None
@@ -184,5 +183,4 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
print('Find Better K-Piece {}'.format(k_piece))
if not self.search_piece:
break

return best_scale