
Commit b9a4cb1

JunnYu and gongenlei authored
Add MPNet Model (PaddlePaddle#869)
* add mpnet
* update
* update tokenizer and update readme
* update readme & add docs
* rm unused figure
* update
* update
* update copyright

Co-authored-by: gongenlei <[email protected]>
1 parent 211c0f5 commit b9a4cb1

File tree

12 files changed: +2662 −0 lines changed
Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
# MPNet with PaddleNLP

[MPNet: Masked and Permuted Pre-training for Language Understanding - Microsoft Research](https://www.microsoft.com/en-us/research/publication/mpnet-masked-and-permuted-pre-training-for-language-understanding/)

**Abstract:**

BERT adopts masked language modeling (MLM) for pre-training and is one of the most successful pre-training models. Since BERT neglects the dependency among predicted tokens, XLNet introduces permuted language modeling (PLM) for pre-training to address this problem. However, XLNet does not leverage the full position information of a sentence and thus suffers from a position discrepancy between pre-training and fine-tuning. In this paper, we propose MPNet, a novel pre-training method that inherits the advantages of BERT and XLNet and avoids their limitations. MPNet leverages the dependency among predicted tokens through permuted language modeling (vs. MLM in BERT), and takes auxiliary position information as input to let the model see a full sentence, thereby reducing the position discrepancy (vs. PLM in XLNet). We pre-train MPNet on a large-scale dataset (over 160 GB of text corpora) and fine-tune it on a variety of downstream tasks (GLUE, SQuAD, etc.). Experimental results show that, under the same model setting, MPNet outperforms MLM and PLM by a large margin, and achieves better results on these tasks compared with previous state-of-the-art pre-trained methods (e.g., BERT, XLNet, RoBERTa). The original code and pre-trained models are available at https://github.com/microsoft/MPNet.

This project is an open-source implementation of MPNet on Paddle 2.x.
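As a quick sanity check, the pre-trained weights can be loaded directly through PaddleNLP. Below is a minimal sketch, assuming the classes added by this PR are exposed as `MPNetModel` and `MPNetTokenizer` in `paddlenlp.transformers` and follow the usual `from_pretrained` interface:

```python
import paddle
from paddlenlp.transformers import MPNetModel, MPNetTokenizer

# Load the pre-trained backbone and its tokenizer (downloads "mpnet-base" on first use).
tokenizer = MPNetTokenizer.from_pretrained("mpnet-base")
model = MPNetModel.from_pretrained("mpnet-base")
model.eval()

# Encode one sentence and run a forward pass.
encoded = tokenizer("Welcome to use PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])
with paddle.no_grad():
    sequence_output, pooled_output = model(input_ids)
print(sequence_output.shape)  # expected: [1, seq_len, hidden_size]
```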
## Quick Start

### Fine-tuning on downstream tasks

#### 1. GLUE

The QQP dataset is used as the example below; to run other GLUE datasets, refer to the `train.sh` file. (The hyperparameters follow the [README](https://github.com/microsoft/MPNet/blob/master/MPNet/README.glue.md) of the original paper's repository.)

##### (1) Model fine-tuning:
```shell
unset CUDA_VISIBLE_DEVICES
cd glue
python -m paddle.distributed.launch --gpus "0" run_glue.py \
    --model_type mpnet \
    --model_name_or_path mpnet-base \
    --task_name qqp \
    --max_seq_length 128 \
    --batch_size 32 \
    --learning_rate 1e-5 \
    --scheduler_type linear \
    --weight_decay 0.1 \
    --warmup_steps 5666 \
    --max_steps 113272 \
    --logging_steps 500 \
    --save_steps 2000 \
    --seed 42 \
    --output_dir qqp/ \
    --device gpu
```
The arguments are described below:

- `model_type`: the model type; BERT, ELECTRA, ERNIE, CONVBERT and MPNET are currently supported.
- `model_name_or_path`: model name or path; for MPNet, only the `mpnet-base` configuration is currently supported.
- `task_name`: the fine-tuning task; CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE and WNLI are currently supported.
- `max_seq_length`: maximum sentence length; longer inputs are truncated.
- `batch_size`: number of samples **per card** per iteration.
- `learning_rate`: the base learning rate, which is multiplied by the value produced by the learning rate scheduler to obtain the current learning rate (see the sketch after this list).
- `scheduler_type`: scheduler type; `linear` and `cosine` are available.
- `warmup_steps`: number of warmup steps.
- `max_steps`: maximum number of training steps.
- `logging_steps`: logging interval.
- `save_steps`: interval for saving and evaluating the model.
- `output_dir`: directory where the model is saved.
- `device`: device type. Defaults to GPU; can be set to CPU, GPU or XPU. For multi-GPU training, set it to GPU and set the environment variable CUDA_VISIBLE_DEVICES to the ids of the GPUs to use.
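To make the interaction between `learning_rate`, `scheduler_type` and `warmup_steps` concrete, here is a small illustrative helper (not code from `run_glue.py`) that reproduces the shape of the linear schedule configured above:

```python
def linear_lr(step, base_lr=1e-5, warmup_steps=5666, max_steps=113272):
    """Linear warmup followed by linear decay (--scheduler_type linear).

    The scheduler scales the base learning rate (--learning_rate): it ramps up
    from 0 to base_lr over warmup_steps, then decays linearly to 0 at max_steps.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))


print(linear_lr(0))       # 0.0   (start of warmup)
print(linear_lr(5666))    # 1e-05 (peak, equal to --learning_rate)
print(linear_lr(113272))  # 0.0   (end of training)
```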
##### (2) Model prediction:

```bash
cd glue
python run_predict.py --task_name qqp --ckpt_path qqp/best-qqp_ft_model_106000.pdparams
```
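For reference, the prediction step essentially restores the fine-tuned checkpoint and scores sentence pairs. A stripped-down sketch of that flow (the class names `MPNetForSequenceClassification` / `MPNetTokenizer` and the two-class QQP head are assumptions about what `run_predict.py` uses under the hood):

```python
import paddle
import paddle.nn.functional as F
from paddlenlp.transformers import MPNetForSequenceClassification, MPNetTokenizer

# Restore the fine-tuned QQP checkpoint saved during training (assumed layout).
tokenizer = MPNetTokenizer.from_pretrained("mpnet-base")
model = MPNetForSequenceClassification.from_pretrained("mpnet-base", num_classes=2)
model.set_state_dict(paddle.load("qqp/best-qqp_ft_model_106000.pdparams"))
model.eval()

# Score one QQP question pair (0 = not duplicate, 1 = duplicate).
encoded = tokenizer("How do I learn Python?", "What is the best way to learn Python?")
input_ids = paddle.to_tensor([encoded["input_ids"]])
with paddle.no_grad():
    logits = model(input_ids)
print(F.softmax(logits, axis=-1).numpy())
```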
##### (3) Compress the template folder into a zip file and submit it to the [GLUE leaderboard](https://gluebenchmark.com/leaderboard).

###### GLUE dev set results:
| task | cola | sst-2 | mrpc | sts-b | qqp | mnli | qnli | rte | avg |
|------|------|-------|------|-------|-----|------|------|-----|-----|
| **metric** | **mcc** | **acc** | **acc/f1** | **pearson/spearman** | **acc/f1** | **acc(m/mm)** | **acc** | **acc** | |
| Paper | **65.0** | **95.5** | **91.8**/— | 91.1/— | **91.9**/— | **88.5**/— | 93.3 | 85.8 | **87.9** |
| Mine | 64.4 | 95.4 | 90.4/93.1 | **91.6**/91.3 | **91.9**/89.0 | 87.7/88.2 | **93.6** | **86.6** | 87.7 |

###### GLUE test set results comparison:

| task | cola | sst-2 | mrpc | sts-b | qqp | mnli-m | qnli | rte | avg |
|------|------|-------|------|-------|-----|--------|------|-----|-----|
| **metric** | **mcc** | **acc** | **acc/f1** | **pearson/spearman** | **acc/f1** | **acc(m/mm)** | **acc** | **acc** | |
| Paper | **64.0** | **96.0** | 89.1/— | 90.7/— | **89.9**/— | **88.5**/— | 93.1 | 81.0 | **86.5** |
| Mine | 60.5 | 95.9 | **91.6**/88.9 | **90.8**/90.3 | 89.7/72.5 | 87.6/86.6 | **93.3** | **82.4** | **86.5** |
#### 2. SQuAD v1.1

Run fine-tuning on the SQuAD v1.1 dataset with the pre-trained model provided by Paddle:

```bash
unset CUDA_VISIBLE_DEVICES
cd squad
python -m paddle.distributed.launch --gpus "0" run_squad.py \
    --model_type mpnet \
    --model_name_or_path mpnet-base \
    --max_seq_length 512 \
    --batch_size 16 \
    --learning_rate 2e-5 \
    --num_train_epochs 4 \
    --scheduler_type linear \
    --logging_steps 25 \
    --save_steps 25 \
    --warmup_proportion 0.1 \
    --weight_decay 0.1 \
    --output_dir squad1.1/ \
    --device gpu \
    --do_train \
    --seed 42 \
    --do_predict
```
The model is evaluated automatically during training; the best result is shown below:

```python
{
  "exact": 86.84957426679281,
  "f1": 92.82031917884066,
  "total": 10570,
  "HasAns_exact": 86.84957426679281,
  "HasAns_f1": 92.82031917884066,
  "HasAns_total": 10570
}
```
#### 3. SQuAD v2.0

For SQuAD v2.0, launch fine-tuning as follows:

```bash
unset CUDA_VISIBLE_DEVICES
cd squad
python -m paddle.distributed.launch --gpus "0" run_squad.py \
    --model_type mpnet \
    --model_name_or_path mpnet-base \
    --max_seq_length 512 \
    --batch_size 16 \
    --learning_rate 2e-5 \
    --num_train_epochs 4 \
    --scheduler_type linear \
    --logging_steps 200 \
    --save_steps 200 \
    --warmup_proportion 0.1 \
    --weight_decay 0.1 \
    --output_dir squad2/ \
    --device gpu \
    --do_train \
    --seed 42 \
    --do_predict \
    --version_2_with_negative
```
* `version_2_with_negative`: flag for using the SQuAD v2.0 dataset and its evaluation metric.

The model is evaluated automatically during training; the best result is shown below:
```python
{
  "exact": 82.27912069401162,
  "f1": 85.2774124891565,
  "total": 11873,
  "HasAns_exact": 80.34750337381917,
  "HasAns_f1": 86.35268530427743,
  "HasAns_total": 5928,
  "NoAns_exact": 84.20521446593776,
  "NoAns_f1": 84.20521446593776,
  "NoAns_total": 5945,
  "best_exact": 82.86869367472417,
  "best_exact_thresh": -2.450321674346924,
  "best_f1": 85.67634263296013,
  "best_f1_thresh": -2.450321674346924
}
```
# Tips:

- For the SQuAD tasks: according to this [issue](https://github.com/microsoft/MPNet/issues/3), the paper reports `best_exact` and `best_f1` (a simplified sketch of how these are computed follows below).
- For the GLUE tasks: according to this [issue](https://github.com/microsoft/MPNet/issues/7), some tasks use warm-start initialization.
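For context, `best_exact`/`best_f1` are obtained by sweeping a global no-answer threshold over the model's null scores. A simplified sketch of that search (modelled on the official SQuAD v2.0 evaluation logic, not code from this repo, and assuming every predicted span is non-empty):

```python
def best_score_with_threshold(span_scores, na_scores, has_answer):
    """Return the best achievable score and the no-answer threshold that yields it.

    span_scores: qid -> exact-match (or F1) score of the predicted text span
    na_scores:   qid -> the model's no-answer score (higher = more likely unanswerable)
    has_answer:  qid -> whether the question is answerable in the gold data
    """
    # Start from predicting "no answer" for every question.
    best = current = sum(1 for qid in has_answer if not has_answer[qid])
    best_thresh = 0.0
    # Flip questions to "answered" one by one, from most to least confident spans.
    for qid in sorted(na_scores, key=na_scores.get):
        current += span_scores[qid] if has_answer[qid] else -1
        if current > best:
            best, best_thresh = current, na_scores[qid]
    return 100.0 * best / len(span_scores), best_thresh
```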
# Reference

```bibtex
@article{song2020mpnet,
  title={MPNet: Masked and Permuted Pre-training for Language Understanding},
  author={Song, Kaitao and Tan, Xu and Qin, Tao and Lu, Jianfeng and Liu, Tie-Yan},
  journal={arXiv preprint arXiv:2004.09297},
  year={2020}
}
```
Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse

# Mapping from Hugging Face parameter-name fragments to their PaddleNLP counterparts.
huggingface_to_paddle = {
    ".attn.": ".",
    "intermediate.dense": "ffn",
    "output.dense": "ffn_output",
    ".output.LayerNorm.": ".layer_norm.",
    ".LayerNorm.": ".layer_norm.",
    "lm_head.decoder.bias": "lm_head.decoder_bias",
}

# Weights with no separate counterpart in the Paddle model (tied/duplicated parameters).
skip_weights = ["lm_head.decoder.weight", "lm_head.bias"]
# Weights that keep their layout even when 2-D (embeddings, LayerNorm, relative attention bias).
dont_transpose = [
    "_embeddings.weight",
    ".LayerNorm.weight",
    ".layer_norm.weight",
    "relative_attention_bias.weight",
]


def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
                                         paddle_dump_path):
    import torch
    import paddle

    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for k, v in pytorch_state_dict.items():
        transpose = False
        if k in skip_weights:
            continue
        # PyTorch nn.Linear stores kernels as [out, in]; Paddle expects [in, out].
        if k[-7:] == ".weight":
            if not any([w in k for w in dont_transpose]):
                if v.ndim == 2:
                    v = v.transpose(0, 1)
                    transpose = True
        oldk = k
        for huggingface_name, paddle_name in huggingface_to_paddle.items():
            k = k.replace(huggingface_name, paddle_name)

        print(f"Converting: {oldk} => {k} | is_transpose {transpose}")
        paddle_state_dict[k] = v.data.numpy()

    paddle.save(paddle_state_dict, paddle_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default="weights/hg/mpnet-base/pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the PyTorch checkpoint.", )
    parser.add_argument(
        "--paddle_dump_path",
        default="weights/pd/mpnet-base/model_state.pdparams",
        type=str,
        required=False,
        help="Path to the output Paddle model.", )
    args = parser.parse_args()
    convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path,
                                         args.paddle_dump_path)
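A usage sketch: the conversion can be driven through the CLI arguments above or by calling the function directly. The module name `convert` below is an assumption about how the file is saved; the paths are the script's own defaults.

```python
# Hypothetical direct invocation of the conversion function defined above.
from convert import convert_pytorch_checkpoint_to_paddle

convert_pytorch_checkpoint_to_paddle(
    pytorch_checkpoint_path="weights/hg/mpnet-base/pytorch_model.bin",
    paddle_dump_path="weights/pd/mpnet-base/model_state.pdparams",
)
```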
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
# task name ["cola","sst-2","mrpc","sts-b","qqp","mnli", "rte", "qnli"]

python run_predict.py --task_name qqp --ckpt_path qqp/best-qqp_ft_model_106000.pdparams
