基于 BERT 模型的中文文本分类工具。
python 3.7
torch 1.7
tqdm
sklearn
transformers 4.8.1
从 THUCNews 中随机抽取20万条新闻标题,一共有10个类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐,每类2万条标题数据。数据集按如下划分:
- 训练集:18万条新闻标题,每个类别的标题数为18000
- 验证集:1万条新闻标题,每个类别的标题数为1000
- 测试集:1万条新闻标题,每个类别的标题数为1000
可以按照 data 文件夹中的数据格式来准备自己任务所需的数据,并调整 config.py 中的相关配置参数。
从 huggingface 官网上下载 bert-base-chinese 模型权重、配置文件和词典到 pretrained_bert 文件夹中,下载地址:https://huggingface.co/bert-base-chinese/tree/main
文本分类模型训练:
python main.py --mode train --data_dir ./data --pretrained_bert_dir ./pretrained_bert
训练中间日志如下:
模型在验证集上的效果如下:
文本分类 demo 展示:
python main.py --mode demo --data_dir ./data --pretrained_bert_dir ./pretrained_bert
对 data 文件夹下的 input.txt 中的文本进行分类预测:
python main.py --mode predict --data_dir ./data --pretrained_bert_dir ./pretrained_bert --input_file ./data/input.txt
输出如下结果: