From c931da942361c528740ad4186dc33cbdd9e82f9c Mon Sep 17 00:00:00 2001
From: shibing624
Date: Wed, 31 Jul 2024 22:51:17 +0800
Subject: [PATCH] update onnx usage.

---
 README.md | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 68 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 0c21517..0e9111e 100644
--- a/README.md
+++ b/README.md
@@ -409,11 +409,77 @@ if __name__ == '__main__':
 
 #### ONNX Inference Acceleration
 
-A trained model can be exported to ONNX format to speed up inference, or to deploy it in other environments such as C++; see the example [examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py)
+A trained model can be exported to ONNX format to speed up inference, or to deploy it in other environments such as C++.
 
 - In a GPU environment, exporting the model to ONNX and running inference with it gives a speedup of more than 10x; requires the `onnxruntime-gpu` package: `pip install onnxruntime-gpu`
-- In a CPU environment, exporting the model to ONNX and running inference with it gives a speedup of more than 5x; requires the `onnxruntime` package: `pip install onnxruntime`
+- In a CPU environment, exporting the model to ONNX and running inference with it gives a speedup of more than 6x; requires the `onnxruntime` package: `pip install onnxruntime`
+
+Example: [examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py)
+
+```python
+import os
+import shutil
+import sys
+import time
+
+import torch
+
+sys.path.append('..')
+from pytextclassifier import BertClassifier
+
+# Train a binary BERT classifier on a small toy dataset
+m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,
+                   model_type='bert', model_name='bert-base-chinese', num_epochs=1)
+data = [
+    ('education', '名师指导托福语法技巧:名词的复数形式'),
+    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
+    ('education', '公务员考试越来越吃香,这是怎么回事?'),
+    ('education', '公务员考试越来越吃香,这是怎么回事1?'),
+    ('education', '公务员考试越来越吃香,这是怎么回事2?'),
+    ('education', '公务员考试越来越吃香,这是怎么回事3?'),
+    ('education', '公务员考试越来越吃香,这是怎么回事4?'),
+    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
+    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
+    ('sports', '米兰客场8战不败国米10年连胜1'),
+    ('sports', '米兰客场8战不败国米10年连胜2'),
+    ('sports', '米兰客场8战不败国米10年连胜3'),
+    ('sports', '米兰客场8战不败国米10年连胜4'),
+    ('sports', '米兰客场8战不败国米10年连胜5'),
+]
+m.train(data * 10)
+m.load_model()
+
+# 300 samples, so the timing difference is easy to see
+samples = ['名师指导托福语法技巧',
+           '米兰客场8战不败',
+           '恒生AH溢指收平 A股对H股折价1.95%'] * 100
+
+# Time prediction with the standard PyTorch model
+start_time = time.time()
+predict_label_bert, predict_proba_bert = m.predict(samples)
+print(f'predict_label_bert size: {len(predict_label_bert)}')
+end_time = time.time()
+elapsed_time_bert = end_time - start_time
+print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')
+
+# Export to ONNX, then reload the ONNX model for faster prediction
+# (10x+ speedup on GPU, 6x+ on CPU)
+save_onnx_dir = 'models/bert-chinese-v1/onnx'
+m.model.convert_to_onnx(save_onnx_dir)
+# Copy label_vocab.json to save_onnx_dir so predicted ids can be mapped back to labels
+if os.path.exists(m.label_vocab_path):
+    shutil.copy(m.label_vocab_path, save_onnx_dir)
+
+# Manually delete the PyTorch model and clear the CUDA cache to free memory
+del m
+torch.cuda.empty_cache()
+
+m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,
+                   args={"onnx": True})
+m.load_model()
+
+# Time prediction with the ONNX model
+start_time = time.time()
+predict_label_bert, predict_proba_bert = m.predict(samples)
+print(f'predict_label_bert size: {len(predict_label_bert)}')
+end_time = time.time()
+elapsed_time_onnx = end_time - start_time
+print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
+```
 
 ## Evaluation
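
For deployment outside pytextclassifier (e.g. as a first step toward the C++ scenario the README mentions), the exported model can also be driven directly from an `onnxruntime` session. The sketch below is a minimal, unverified example: the `model.onnx` filename, the presence of HuggingFace tokenizer files in the export directory, and the graph's input names are assumptions about what `convert_to_onnx` writes, so the code reads the input names from the session instead of hard-coding them.

```python
# Minimal sketch: run the exported classifier with bare onnxruntime.
# Assumptions (check against your export): the ONNX dir contains
# HuggingFace tokenizer files and a 'model.onnx' file; adjust paths if not.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

onnx_dir = 'models/bert-chinese-v1/onnx'
tokenizer = AutoTokenizer.from_pretrained(onnx_dir)
session = ort.InferenceSession(f'{onnx_dir}/model.onnx')

texts = ['名师指导托福语法技巧', '米兰客场8战不败']
enc = tokenizer(texts, padding=True, truncation=True, max_length=128,
                return_tensors='np')

# Feed only the inputs the exported graph actually declares
input_names = {i.name for i in session.get_inputs()}
feed = {k: v.astype(np.int64) for k, v in enc.items() if k in input_names}

logits = session.run(None, feed)[0]
pred_ids = logits.argmax(axis=-1)
print(pred_ids)  # map ids to label names via the copied label_vocab.json
```

Filtering the tokenizer output against `session.get_inputs()` keeps the sketch working whether or not the exported graph takes `token_type_ids`; everything else (file names, tokenizer layout, label-vocab schema) should be verified against what `convert_to_onnx` actually produces.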