
Commit 0082916

LiuChiachi and co-author authored
Add encoder op (PaddlePaddle#952)
* add encoder op
* support fp16
* update encoder sample; remove useless config; use pad to max seq len
* commit before pull
* update patches and cmakelist
* update the interface for encoder op
* support jit, and support replacing the forward function of nn.TransformerEncoder (and *Layer) objects
* support jit; move encoder sample to semantic index and update readme; delete fp16; add device config info; fix disable faster encoder bug
* remove unsupported argument; add docstring and warning handling for False bias_attr of EncoderLayer; fix docstring format; add \n, and add docstring for infer_transformer_encoder

Co-authored-by: LiuChiachi <[email protected]>
1 parent 61c638f commit 0082916

File tree

14 files changed: +2436 −10 lines

examples/semantic_indexing/README.md

Lines changed: 26 additions & 0 deletions
````diff
@@ -232,6 +232,32 @@ python -u -m paddle.distributed.launch --gpus "0" \
 0.9800204038619995
 ```
 
+## Fast prediction with FasterTransformer
+
+Unlike the native prediction above, FasterTransformer prediction runs a custom Paddle operator that integrates the FasterTransformer library, which, in suitable configurations, accelerates TransformerEncoder inference.
+
+```shell
+python -u -m paddle.distributed.launch --gpus "0" faster_predict.py \
+    --params_path "batch_neg_v1.0/model_state.pdparams" \
+    --output_emb_size 256 \
+    --batch_size 32 \
+    --max_seq_length 64 \
+    --text_pair_file ${your_input_file}
+```
+
+This yields results very close to native prediction: in float32, with batch_size=32 and max_seq_len=64, the maximum absolute error between the final cosine similarities of the two approaches is about 3.93e-6.
+
+Comparing different batch_size and max_seq_len settings (all other arguments at their defaults) shows that the high-performance operator integrating FasterTransformer accelerates the Encoder part of inference. Partial single-GPU timings on an NVIDIA Tesla V100 (16GB) follow; FasterTransformer's advantage is larger at smaller batch_size and max_seq_len.
+
+| batch size | max_seq_len | FT fused op (s)    | native Paddle (s)  |
+| ---------- | ----------- | ------------------ | ------------------ |
+| 16         | 16          | 22.645333290100098 | 51.55912470817566  |
+| 16         | 32          | 27.326106071472168 | 57.17143130302429  |
+| 16         | 64          | 33.31318140029907  | 52.44770574569702  |
+| 32         | 16          | 12.891342163085938 | 22.621662139892578 |
+| 32         | 32          | 17.206310987472534 | 22.18772006034851  |
+
 ## Model introduction
 A brief introduction to the ideas behind the In-batch negatives and HardestNeg strategies
````
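The pattern behind faster_predict.py (added below) is just a pair of calls around ordinary inference. A minimal standalone sketch, assuming the custom ops were built with WITH_ENCODER=ON and a GPU is available; the bare nn.TransformerEncoder here stands in for the full ERNIE model used by the real script:

```python
import paddle
import paddle.nn as nn
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

# Stand-in for the real model: anything whose sublayers include
# nn.TransformerEncoder / nn.TransformerEncoderLayer can be patched.
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072)
model = nn.TransformerEncoder(layer, num_layers=12)
model.eval()

model = enable_faster_encoder(model)    # forward now runs the fused op
src = paddle.randn([32, 64, 768])       # [batch_size, seq_len, hidden]
src_mask = paddle.ones([32, 1, 1, 64])  # non-pad mask, broadcast over heads
out = model(src, src_mask)
model = disable_faster_encoder(model)   # restore the native forward
```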

examples/semantic_indexing/data.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -26,7 +26,6 @@ def create_dataloader(dataset,
                       trans_fn=None):
     if trans_fn:
         dataset = dataset.map(trans_fn)
-
     shuffle = True if mode == 'train' else False
     if mode == 'train':
         batch_sampler = paddle.io.DistributedBatchSampler(

@@ -42,7 +41,10 @@ def create_dataloader(dataset,
         return_list=True)


-def convert_example(example, tokenizer, max_seq_length=512):
+def convert_example(example,
+                    tokenizer,
+                    max_seq_length=512,
+                    pad_to_max_seq_len=False):
     """
     Builds model inputs from a sequence.

@@ -65,11 +67,13 @@ def convert_example(example, tokenizer, max_seq_length=512):

     result = []
     for key, text in example.items():
-        encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length)
+        encoded_inputs = tokenizer(
+            text=text,
+            max_seq_len=max_seq_length,
+            pad_to_max_seq_len=pad_to_max_seq_len)
         input_ids = encoded_inputs["input_ids"]
         token_type_ids = encoded_inputs["token_type_ids"]
         result += [input_ids, token_type_ids]
-
     return result
```
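A small usage sketch of the extended convert_example; the example text and the two-field dict are hypothetical stand-ins for what read_text_pair yields for one query/title pair:

```python
from paddlenlp.transformers import ErnieTokenizer

from data import convert_example  # the function extended above

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
# Hypothetical pair; real examples come from read_text_pair.
example = {"text_a": "什么是语义索引", "text_b": "语义索引介绍"}

query_ids, query_segs, title_ids, title_segs = convert_example(
    example, tokenizer, max_seq_length=64, pad_to_max_seq_len=True)

# With pad_to_max_seq_len=True every id list has exactly max_seq_length
# entries, giving the fixed batch shapes the faster-encoder sample uses
# (the commit message notes "use pad to max seq len").
assert len(query_ids) == 64 and len(title_ids) == 64
```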
examples/semantic_indexing/faster_predict.py (new file)

Lines changed: 222 additions & 0 deletions

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
import argparse
from pprint import pprint
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from paddlenlp.transformers import ErnieTokenizer, ErnieModel
from paddlenlp.data import Pad, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

from data import read_text_pair, convert_example, create_dataloader


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_pair_file",
        type=str,
        required=True,
        help="The full path of the input file.")
    parser.add_argument(
        "--output_emb_size",
        default=None,
        type=int,
        help="Output embedding size.")
    parser.add_argument(
        "--params_path",
        type=str,
        required=True,
        help="The path to model parameters to be loaded.")
    parser.add_argument(
        "--max_seq_length",
        default=64,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded."
    )
    parser.add_argument(
        "--dropout", default=0.0, type=float, help="Dropout probability.")
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU/CPU for prediction.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument(
        "--pad_to_max_seq_len",
        action="store_true",
        help="Whether to pad to max_seq_len.")

    args = parser.parse_args()
    return args


class SemanticIndexingPredictor(nn.Layer):
    def __init__(self,
                 pretrained_model,
                 output_emb_size,
                 n_layer=12,
                 n_head=12,
                 hidden_size=768,
                 dim_feedforward=3072,
                 activation="relu",
                 bos_id=0,
                 dropout=0,
                 max_seq_len=128,
                 is_gelu=False):
        super(SemanticIndexingPredictor, self).__init__()
        size_per_head = hidden_size // n_head
        self.bos_id = bos_id
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.0)
        self.output_emb_size = output_emb_size
        if output_emb_size > 0:
            weight_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            self.emb_reduce_linear = paddle.nn.Linear(
                768, output_emb_size, weight_attr=weight_attr)
        # Rebuild the pretrained model's encoder from paddle.nn layers so
        # that enable_faster_encoder can later swap in the fused forward.
        encoder_layer = TransformerEncoderLayer(
            hidden_size, n_head, dim_feedforward, dropout=dropout)
        self.ptm.encoder = TransformerEncoder(encoder_layer, n_layer)

    def get_pooled_embedding(self,
                             input_ids,
                             token_type_ids=None,
                             position_ids=None,
                             attention_mask=None):
        # Non-pad mask in the encoder's parameter dtype, unsqueezed to
        # [batch_size, 1, 1, seq_len] so it broadcasts over attention heads.
        src_mask = (input_ids != self.bos_id).astype(
            self.ptm.encoder.layers[0].norm1.bias.dtype)
        src_mask = paddle.unsqueeze(src_mask, axis=[1, 2])
        src_mask.stop_gradient = True

        ones = paddle.ones_like(input_ids, dtype="int64")
        seq_length = paddle.cumsum(ones, axis=1)
        position_ids = seq_length - ones
        position_ids.stop_gradient = True

        embedding_output = self.ptm.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)
        sequence_output = self.ptm.encoder(embedding_output, src_mask)
        cls_embedding = self.ptm.pooler(sequence_output)

        if self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)

        return cls_embedding

    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):
        query_cls_embedding = self.get_pooled_embedding(
            query_input_ids, query_token_type_ids, query_position_ids,
            query_attention_mask)
        title_cls_embedding = self.get_pooled_embedding(
            title_input_ids, title_token_type_ids, title_position_ids,
            title_attention_mask)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)
        return cosine_sim

    def load(self, init_from_params):
        if init_from_params and os.path.isfile(init_from_params):
            state_dict = paddle.load(init_from_params)
            self.set_state_dict(state_dict)
            print("Loaded parameters from %s" % init_from_params)
        else:
            raise ValueError(
                "Please set --params_path with correct pretrained model file")


def do_predict(args):
    place = paddle.set_device("gpu")
    paddle.seed(args.seed)
    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        pad_to_max_seq_len=args.pad_to_max_seq_len)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # title_segment
    ): [data for data in fn(samples)]

    valid_ds = load_dataset(
        read_text_pair, data_path=args.text_pair_file, lazy=False)

    valid_data_loader = create_dataloader(
        valid_ds,
        mode="predict",
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    pretrained_model = ErnieModel.from_pretrained("ernie-1.0")

    model = SemanticIndexingPredictor(
        pretrained_model,
        args.output_emb_size,
        max_seq_len=args.max_seq_length,
        dropout=args.dropout)
    model.eval()
    model.load(args.params_path)
    # Swap TransformerEncoder(Layer) forwards for the fused custom op.
    model = enable_faster_encoder(model)
    cosine_sims = []
    for batch_data in valid_data_loader:
        query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch_data
        query_input_ids = paddle.to_tensor(query_input_ids)
        query_token_type_ids = paddle.to_tensor(query_token_type_ids)
        title_input_ids = paddle.to_tensor(title_input_ids)
        title_token_type_ids = paddle.to_tensor(title_token_type_ids)
        batch_cosine_sim = model(
            query_input_ids=query_input_ids,
            title_input_ids=title_input_ids,
            query_token_type_ids=query_token_type_ids,
            title_token_type_ids=title_token_type_ids).numpy()
        cosine_sims.append(batch_cosine_sim)

    cosine_sims = np.concatenate(cosine_sims, axis=0)
    for cosine in cosine_sims:
        print('{}'.format(cosine))
    # Restore the original paddle.nn forwards.
    model = disable_faster_encoder(model)


if __name__ == "__main__":
    args = parse_args()
    pprint(args)
    do_predict(args)
```
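All of the flags below are defined in parse_args above; a plausible end-to-end invocation that also exercises the new padding switch (paths and input file are illustrative, mirroring the README example):

```shell
python faster_predict.py \
    --params_path "batch_neg_v1.0/model_state.pdparams" \
    --output_emb_size 256 \
    --batch_size 32 \
    --max_seq_length 64 \
    --pad_to_max_seq_len \
    --text_pair_file ${your_input_file}
```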

examples/semantic_indexing/predict.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -38,6 +38,7 @@
 parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
 parser.add_argument("--output_emb_size", default=None, type=int, help="output_embedding_size")
 parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--pad_to_max_seq_len", action="store_true", help="Whether to pad to max seq length.")
 args = parser.parse_args()
 # yapf: enable

@@ -86,7 +87,8 @@ def predict(model, data_loader):
     trans_func = partial(
         convert_example,
         tokenizer=tokenizer,
-        max_seq_length=args.max_seq_length)
+        max_seq_length=args.max_seq_length,
+        pad_to_max_seq_len=args.pad_to_max_seq_len)

     batchify_fn = lambda samples, fn=Tuple(
         Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
```

paddlenlp/ops/CMakeLists.txt

Lines changed: 18 additions & 2 deletions
```diff
@@ -27,6 +27,7 @@ option(WITH_TRANSFORMER "Compile with Transformer"
 option(WITH_GPT "Compile with GPT" OFF)
 option(WITH_UNIFIED "Compile with Unified Transformer" ON)
 option(WITH_DECODER "Compile with Transformer Decoder" ON)
+option(WITH_ENCODER "Compile with Transformer Encoder" ON)

 if(NOT WITH_GPU)
   message(FATAL_ERROR "Faster transformer custom op doesn't support CPU. Please add the flag -DWITH_GPU=ON to use GPU. ")

@@ -44,12 +45,16 @@ if(WITH_UNIFIED)
   list(APPEND decoding_op_files fusion_unified_decoding_op.cc fusion_unified_decoding_op.cu)
 endif()

+if(WITH_ENCODER)
+  list(APPEND decoding_op_files fusion_encoder_op.cc fusion_encoder_op.cu)
+endif()
+
 if(WITH_DECODER)
   list(APPEND decoder_op_files fusion_decoder_op.cc fusion_decoder_op.cu)
 endif()

-if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER)
-  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON must be set to use FasterTransformer. ")
+if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER)
+  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON must be set to use FasterTransformer. ")
 endif()

 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})

@@ -161,6 +166,13 @@
 file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/topk_kernels.cuh topk_kernels_cuh_src)
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/topk_kernels.cuh topk_kernels_cuh_dst)

+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/open_attention.h open_attention_h_dst)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)
+
+set(OPT_OPEN_ATTN_COMMAND sed -i -e "370,392d" -e "410,454d" -e "229d" ${open_attention_h_dst})
+#set(OPT_BERT_ENCODER_COMMAND sed -i -e "552,592d" -e "118a bool is_gelu_=true;" ${bert_encoder_transformer_h_dst})

 # TODO(guosheng): `find` seems meeting errors missing argument to `-exec', fix it
 set(MUTE_COMMAND grep -rl "printf(\"\\[WARNING\\]" ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/ | xargs -i{} sed -i "s/printf(\"\\WWARNING\\W decoding[^)]\\{1,\\})/ /" {})

@@ -174,6 +186,7 @@ set(FT_PATCH_COMMAND
   && cp ${beamsearch_h_src} ${trans_dst}
   && cp ${sampling_h_src} ${trans_dst}
   && cp ${arguments_h_src} ${trans_dst}
+  && cp ${bert_encoder_transformer_h_src} ${bert_encoder_transformer_h_dst}
   && cat ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
   && cat ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst}
   && cat ${decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}

@@ -182,6 +195,7 @@ set(FT_PATCH_COMMAND
   && cat ${trans_decoder_h_src} >> ${open_decoder_h_dst}
   && cat ${trans_cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
   && cat ${trans_decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}
+  && ${OPT_OPEN_ATTN_COMMAND}
   && ${MUTE_COMMAND}
 )

@@ -282,3 +296,5 @@ if(ON_INFER AND WITH_GPT)
 endif()

 add_subdirectory(faster_transformer)
+
+
```
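Since WITH_ENCODER defaults to ON, any GPU build now compiles the fusion_encoder op into the decoding library. A sketch of a plausible configure step; only the WITH_* options come from this diff, the directory layout and the rest of the flow are assumptions:

```shell
cd paddlenlp/ops            # directory containing this CMakeLists.txt (assumed)
mkdir -p build && cd build
# WITH_ENCODER is ON by default; shown explicitly for clarity.
cmake .. -DWITH_GPU=ON -DWITH_ENCODER=ON
make -j
```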
paddlenlp/ops/__init__.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -15,6 +15,13 @@
 from .faster_transformer.transformer.decoding import *
 from .faster_transformer.transformer.faster_transformer import *
 from .faster_transformer.transformer.decoder import *
+from .faster_transformer.transformer.encoder import *
 from .einsum import *
 from .distributed import *
 from . import optimizer
+
+paddle.nn.TransformerEncoderLayer._ft_forward = encoder_layer_forward
+paddle.nn.TransformerEncoder._ft_forward = encoder_forward
+
+paddle.nn.TransformerEncoderLayer._ori_forward = paddle.nn.TransformerEncoderLayer.forward
+paddle.nn.TransformerEncoder._ori_forward = paddle.nn.TransformerEncoder.forward
```
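These module-level assignments wire up a forward-swapping scheme: the FasterTransformer implementations (encoder_forward and encoder_layer_forward, exported by the new encoder module) are stashed on the paddle.nn classes as _ft_forward, while the stock implementations are kept as _ori_forward. A minimal sketch of how enable_faster_encoder/disable_faster_encoder can then toggle an existing model between the two; this is an assumed simplification, the real helpers add checks such as the warning for bias_attr=False layers mentioned in the commit message:

```python
import paddle.nn as nn

def enable_faster_encoder_sketch(model):
    """Point every encoder (layer) at the fused FasterTransformer forward."""
    for layer in model.sublayers(include_self=True):
        if isinstance(layer, (nn.TransformerEncoder, nn.TransformerEncoderLayer)):
            # _ft_forward lives on the class, so instance attribute access
            # yields a method already bound to `layer`.
            layer.forward = layer._ft_forward
    return model

def disable_faster_encoder_sketch(model):
    """Restore the original paddle.nn forward on every patched layer."""
    for layer in model.sublayers(include_self=True):
        if isinstance(layer, (nn.TransformerEncoder, nn.TransformerEncoderLayer)):
            layer.forward = layer._ori_forward
    return model
```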
