Skip to content

Commit

Permalink
update onnx demo
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Jul 31, 2024
1 parent 72605da commit 6a9f67f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
25 changes: 24 additions & 1 deletion examples/bert_classification_zh_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
@author:XuMing(xuming624@qq.com)
@description:
"""
import shutil
import sys

sys.path.append('..')
Expand All @@ -17,9 +18,17 @@
('education', '名师指导托福语法技巧:名词的复数形式'),
('education', '中国高考成绩海外认可 是“狼来了”吗?'),
('education', '公务员考虑越来越吃香,这是怎么回事?'),
('education', '公务员考虑越来越吃香,这是怎么回事1?'),
('education', '公务员考虑越来越吃香,这是怎么回事2?'),
('education', '公务员考虑越来越吃香,这是怎么回事3?'),
('education', '公务员考虑越来越吃香,这是怎么回事4?'),
('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
('sports', '米兰客场8战不败国米10年连胜'),
('sports', '米兰客场8战不败国米10年连胜1'),
('sports', '米兰客场8战不败国米10年连胜2'),
('sports', '米兰客场8战不败国米10年连胜3'),
('sports', '米兰客场8战不败国米10年连胜4'),
('sports', '米兰客场8战不败国米10年连胜5'),
]
m.train(data)
print(m)
Expand Down Expand Up @@ -50,3 +59,17 @@
'美EB-5项目“15日快速移民”将推迟',
'恒生AH溢指收平 A股对H股折价1.95%'])
print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')

# Export the trained model to ONNX, then reload the exported model and predict
# with it (the original note claims this gives roughly a 10x speed-up).
save_onnx_dir = 'models/onnx'
m.model.convert_to_onnx(save_onnx_dir)

# Copy label_vocab.json next to the exported model
# (presumably the reloaded classifier looks for it in its output_dir -- confirm).
label_vocab_path = 'models/bert-chinese/label_vocab.json'
shutil.copy(label_vocab_path, save_onnx_dir)

# Rebuild the classifier on top of the ONNX directory with the "onnx" arg enabled.
onnx_options = {"onnx": True}
m = BertClassifier(
    output_dir=save_onnx_dir,
    num_classes=10,
    model_type='bert',
    model_name=save_onnx_dir,
    args=onnx_options,
)
m.load_model()

# Same three sample headlines as the non-ONNX prediction above.
onnx_samples = [
    '顺义北京苏活88平米起精装房在售',
    '美EB-5项目“15日快速移民”将推迟',
    '恒生AH溢指收平 A股对H股折价1.95%',
]
predict_label, predict_proba = m.predict(onnx_samples)
print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
6 changes: 5 additions & 1 deletion pytextclassifier/bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Absolute path of the directory that contains this module.
pwd_path = os.path.abspath(os.path.dirname(__file__))
# Device preference order: 'cuda' if available, then 'mps' (guarded with hasattr
# in case torch.backends does not expose an `mps` attribute), otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else (
'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu')
# CUDA availability, evaluated once at import time.
use_cuda = torch.cuda.is_available()
# Same availability check; the constructor uses this value when its
# `use_cuda` argument is left as None.
default_use_cuda = torch.cuda.is_available()
# Set TOKENIZERS_PARALLELISM to "false" for this process
# (presumably to turn off tokenizer-level parallelism and its warnings -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"


Expand All @@ -37,6 +37,7 @@ def __init__(
max_seq_length=128,
multi_label=False,
labels_sep=',',
use_cuda=None,
args=None,
):

Expand All @@ -51,6 +52,7 @@ def __init__(
@param max_seq_length: max seq length, trim longer sentence.
@param multi_label: bool, multi label or single label
@param labels_sep: label separator, default is ','
@param use_cuda: bool, use cuda or not
@param args: dict, train args
"""
default_args = {
Expand All @@ -65,6 +67,8 @@ def __init__(
if args and isinstance(args, dict):
train_args.update_from_dict(args)
train_args.update_from_dict(default_args)
if use_cuda is None:
use_cuda = default_use_cuda

self.model = BertClassificationModel(
model_type=model_type,
Expand Down

0 comments on commit 6a9f67f

Please sign in to comment.