update onnx usage.
shibing624 committed Jul 31, 2024
1 parent 9654a79 commit c931da9
Showing 1 changed file with 68 additions and 2 deletions: README.md

#### ONNX Inference Acceleration

Trained models can be exported to ONNX format for faster inference, or for deployment and invocation in other environments such as C++.

- In a GPU environment, exporting the model to ONNX and running inference with it yields more than a 10x speedup; install the `onnxruntime-gpu` package: `pip install onnxruntime-gpu`
- In a CPU environment, exporting the model to ONNX and running inference with it yields more than a 6x speedup; install the `onnxruntime` package: `pip install onnxruntime`

Example: [examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py)

```python
import os
import shutil
import sys
import time

import torch

sys.path.append('..')
from pytextclassifier import BertClassifier

# Train a small demo classifier so there is a trained model to export to ONNX
m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,
                   model_type='bert', model_name='bert-base-chinese', num_epochs=1)
data = [
('education', '名师指导托福语法技巧:名词的复数形式'),
('education', '中国高考成绩海外认可 是“狼来了”吗?'),
('education', '公务员考虑越来越吃香,这是怎么回事?'),
('education', '公务员考虑越来越吃香,这是怎么回事1?'),
('education', '公务员考虑越来越吃香,这是怎么回事2?'),
('education', '公务员考虑越来越吃香,这是怎么回事3?'),
('education', '公务员考虑越来越吃香,这是怎么回事4?'),
('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
('sports', '米兰客场8战不败国米10年连胜1'),
('sports', '米兰客场8战不败国米10年连胜2'),
('sports', '米兰客场8战不败国米10年连胜3'),
('sports', '米兰客场8战不败国米10年连胜4'),
('sports', '米兰客场8战不败国米10年连胜5'),
]
m.train(data * 10)
m.load_model()

# 300 short texts used for the timing comparison
samples = ['名师指导托福语法技巧',
           '米兰客场8战不败',
           '恒生AH溢指收平 A股对H股折价1.95%'] * 100

start_time = time.time()
predict_label_bert, predict_proba_bert = m.predict(samples)
print(f'predict_label_bert size: {len(predict_label_bert)}')
end_time = time.time()
elapsed_time_bert = end_time - start_time
print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')

# convert to onnx, and load onnx model to predict, speed up 10x
save_onnx_dir = 'models/bert-chinese-v1/onnx'
m.model.convert_to_onnx(save_onnx_dir)
# copy label_vocab.json to save_onnx_dir
if os.path.exists(m.label_vocab_path):
    shutil.copy(m.label_vocab_path, save_onnx_dir)

# Manually delete the model and clear CUDA cache
del m
torch.cuda.empty_cache()

m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,
                   args={"onnx": True})
m.load_model()
start_time = time.time()
predict_label_bert, predict_proba_bert = m.predict(samples)
print(f'predict_label_bert size: {len(predict_label_bert)}')
end_time = time.time()
elapsed_time_onnx = end_time - start_time
print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
```
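
The exported directory can also be consumed outside of pytextclassifier, for example from a custom serving process, by loading the graph with `onnxruntime` directly. The sketch below is a minimal illustration under assumptions: the exported file name (`onnx_model.onnx`), the presence of tokenizer files in the export directory, and the graph's input names are not guaranteed by the library, so inspect the contents of `save_onnx_dir` first.

```python
# Minimal sketch: run the exported graph with onnxruntime directly (assumptions noted above).
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

onnx_dir = 'models/bert-chinese-v1/onnx'
tokenizer = AutoTokenizer.from_pretrained(onnx_dir)  # assumes tokenizer files were saved with the export
session = ort.InferenceSession(f'{onnx_dir}/onnx_model.onnx')  # file name is an assumption

texts = ['名师指导托福语法技巧', '米兰客场8战不败']
encoded = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors='np')

# Only feed the inputs the graph actually declares; input names vary between exports.
input_names = {i.name for i in session.get_inputs()}
feed = {k: v.astype(np.int64) for k, v in encoded.items() if k in input_names}

logits = session.run(None, feed)[0]
print(logits.argmax(axis=-1))  # map predicted ids back to labels via label_vocab.json
```

With `onnxruntime-gpu` installed, the same graph can run on GPU by passing `providers=['CUDAExecutionProvider']` to `InferenceSession`.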

## Evaluation
