From 6a9f67fa87f19ed6f2a5909b53175c34201998b0 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Wed, 31 Jul 2024 21:31:45 +0800 Subject: [PATCH] update onnx demo --- examples/bert_classification_zh_demo.py | 25 ++++++++++++++++++++++++- pytextclassifier/bert_classifier.py | 6 +++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/examples/bert_classification_zh_demo.py b/examples/bert_classification_zh_demo.py index fa336a2..d245558 100644 --- a/examples/bert_classification_zh_demo.py +++ b/examples/bert_classification_zh_demo.py @@ -3,6 +3,7 @@ @author:XuMing(xuming624@qq.com) @description: """ +import shutil import sys sys.path.append('..') @@ -17,9 +18,17 @@ ('education', '名师指导托福语法技巧:名词的复数形式'), ('education', '中国高考成绩海外认可 是“狼来了”吗?'), ('education', '公务员考虑越来越吃香,这是怎么回事?'), + ('education', '公务员考虑越来越吃香,这是怎么回事1?'), + ('education', '公务员考虑越来越吃香,这是怎么回事2?'), + ('education', '公务员考虑越来越吃香,这是怎么回事3?'), + ('education', '公务员考虑越来越吃香,这是怎么回事4?'), ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'), ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'), - ('sports', '米兰客场8战不败国米10年连胜'), + ('sports', '米兰客场8战不败国米10年连胜1'), + ('sports', '米兰客场8战不败国米10年连胜2'), + ('sports', '米兰客场8战不败国米10年连胜3'), + ('sports', '米兰客场8战不败国米10年连胜4'), + ('sports', '米兰客场8战不败国米10年连胜5'), ] m.train(data) print(m) @@ -50,3 +59,17 @@ '美EB-5项目“15日快速移民”将推迟', '恒生AH溢指收平 A股对H股折价1.95%']) print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') + + # convert to onnx, and load onnx model to predict, speed up 10x + save_onnx_dir = 'models/onnx' + m.model.convert_to_onnx(save_onnx_dir) + # copy label_vocab.json to save_onnx_dir + shutil.copy('models/bert-chinese/label_vocab.json', save_onnx_dir) + m = BertClassifier(output_dir=save_onnx_dir, num_classes=10, model_type='bert', model_name=save_onnx_dir, + args={"onnx": True}) + m.load_model() + predict_label, predict_proba = m.predict( + ['顺义北京苏活88平米起精装房在售', + '美EB-5项目“15日快速移民”将推迟', + '恒生AH溢指收平 A股对H股折价1.95%']) + print(f'predict_label: {predict_label}, predict_proba: {predict_proba}') diff --git a/pytextclassifier/bert_classifier.py b/pytextclassifier/bert_classifier.py index 07d5a8f..bad250a 100644 --- a/pytextclassifier/bert_classifier.py +++ b/pytextclassifier/bert_classifier.py @@ -21,7 +21,7 @@ pwd_path = os.path.abspath(os.path.dirname(__file__)) device = 'cuda' if torch.cuda.is_available() else ( 'mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu') -use_cuda = torch.cuda.is_available() +default_use_cuda = torch.cuda.is_available() os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -37,6 +37,7 @@ def __init__( max_seq_length=128, multi_label=False, labels_sep=',', + use_cuda=None, args=None, ): @@ -51,6 +52,7 @@ def __init__( @param max_seq_length: max seq length, trim longer sentence. @param multi_label: bool, multi label or single label @param labels_sep: label separator, default is ',' + @param use_cuda: bool, use cuda or not @param args: dict, train args """ default_args = { @@ -65,6 +67,8 @@ def __init__( if args and isinstance(args, dict): train_args.update_from_dict(args) train_args.update_from_dict(default_args) + if use_cuda is None: + use_cuda = default_use_cuda self.model = BertClassificationModel( model_type=model_type,