本项目基于paddlepaddle框架复现了Luke预训练模型,主要复现Open Entity和SQuAD1.1数据集的结果。
说明 -此项目为旧版,新版在此处https://github.com/Beacontownfc/paddle_luke_stable
项目参考:
复现的代码未达到论文精度,但运行原论文代码也未达到论文精度(所有超参与原论文代码一致)
网络 | opt | batch_size | 数据集 | F1 | F1(原论文代码) |
---|---|---|---|---|---|
Luke-large | AdamW | 2 | Open Entity | 77.50 | 77.50 |
原论文代码运行Loss曲线及训练日志: 原论文代码Loss曲线 原论文代码训练日志
复现代码运行Loss曲线及训练日志: 复现代码Loss曲线 复现代码训练日志
由于SQuAD1.1数据集比较特殊,不提供测试集,因此对比验证集的结果
在SQuAD1.1数据集上,成功复现了论文精度
网络 | opt | batch_size | 数据集 | F1 | F1(原论文) | EM | EM(原论文) |
---|---|---|---|---|---|---|---|
Luke-large | AdamW | 8 | SQuAD1.1 | 94.95 | 95.0 | 89.76 | 89.8 |
复现代码及训练日志: 复现代码训练日志
首先下载预训练权重,下载地址:
百度网盘
解压至./reading_comprehension
和./open_entity
两个路径下
下载Open Entity数据集
下载地址
把下载好的文件解压,并把解压后的Open Entity目录下的train.json
、test.json
和dev.json
复制至./open_entity/data
,或者可以直接使用./open_entity/data
路径下的open entity数据集
下载SQuAD1.1数据集
下载地址
,下载解压至./reading_comprehension/squad_data/squad
下,同时需要下载由官方提供的维基百科数据集
下载地址
, 下载解压至./reading_comprehension/squad_data
代码结构
├─open_entity
| ├─paddle_luke.pt #预训练权重
| ├─data # 数据集文件夹
| | ├─train.json #open entity 训练集
| | ├─dev.json #open entity 验证集
| | ├─test.json #open entity 测试集
| | ├─merges.txt #tokenizer 文件
| | ├─entity_vocab.tsv #实体词文件
| | ├─vocab.json #tokenizer 文件
| ├─luke_model #LUKE模型文件
| | ├─utils
| | ├─entity_vocab.py
| | ├─interwiki_db.py
| | ├─model.py
| ├─datagenerator.py #数据生成器文件
| ├─main.py #运行训练并测试
| ├─open_entity.py #LUKE下游任务
| ├─trainer.py #训练
| ├─utils.py
| ├─word_tokenizer.py
├─reading_comprehension
| ├─paddle_luke.pt
| ├─luke_model
| | ├─utils
| | ├─model.py
| ├─squad_data
| | ├─squad
| | | ├─train-v1.1.json # SQuAD1.1数据集训练集
| | | ├─dev-v1.1.json # SQuAD1.1数据集验证集
| | ├─squad_change
| | ├─entity_vocab.tsv
| | ├─merges.txt
| | ├─metadata.json
| | ├─vocab.json
| | ├─enwiki_20160305.pkl #LUKE官方提供的维基百科数据集
| | ├─enwiki_20181220_redirects.pkl #LUKE官方提供的维基百科数据集
| | ├─enwiki_20160305_redirects.pkl #LUKE官方提供的维基百科数据集
| ├─src
| ├─utils
| ├─create_squad_data.py
| ├─main.py
| ├─reading_comprehension.py #LUKE下游任务
安装第三方库
pip install -r requirements.txt
python main.py 2>&1 | tee train.log
运行结束后你将看到如下结果:
Results: %s {
"test_f1": 0.775093934514224111,
"test_precision": 0.792535675082327,
"test_recall": 0.75840336134453781
}
python create_squad_data.py
python main.py 2>&1 | tee train.log
运行结束后你将看到如下结果:
{"exact_match": 89.75691579943235, "f1": 94.95702001984502}
说明
1、本项目在Aistudio平台,使用Tesla V100训练
2、本项目基于PaddlePaddle开发。