Commit 729198e

Author: gongenlei

[FasterTransformer] Support BART (PaddlePaddle#1021)

* feat: add FT bart
* refactor: remove print
* feat: add other files
* refactor: remove print
* refactor: add pos_offset attribute
* refactor: rename attribute
* refactor: add ActivationType attribute
* refactor: mv encoder to FasterBART
1 parent 97b20d4 commit 729198e
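
In effect, the commit adds a FasterBART wrapper that runs BART generation through the FasterTransformer decoding ops. A minimal sketch of the resulting API, mirroring the sample script included in this commit (the decoding settings below are illustrative values, not defaults):

import paddle
from paddlenlp.ops import FasterBART
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
model = BartForConditionalGeneration.from_pretrained("bart-base")
model.eval()

# Wrap the pretrained model; beam_size/max_out_len are illustrative.
faster_bart = FasterBART(model=model, decoding_strategy="beam_search",
                         beam_size=4, max_out_len=50)
faster_bart.eval()

input_ids = paddle.to_tensor(
    [tokenizer("Nothing's gonna <mask> my love for you.")["input_ids"]])
with paddle.no_grad():
    finished_seq = faster_bart(input_ids)  # token ids; see the sample below for decoding to text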

File tree

13 files changed: +2719 −48 lines changed

paddlenlp/ops/CMakeLists.txt

Lines changed: 14 additions & 2 deletions

@@ -31,6 +31,7 @@ option(WITH_SP "Compiled with sentencepiece. Only works when WITH_GPT a
 option(WITH_DECODER "Compile with Transformer Decoder" ON)
 option(WITH_ENCODER "Compile with Transformer Encoder" ON)
 option(WITH_STATIC_LIB "Compile static lib" OFF)
+option(WITH_BART "Compile with BART" ON)
 
 if(NOT WITH_GPU)
   message(FATAL_ERROR "Faster transformer custom op doesn't support CPU. Please add the flag -DWITH_GPU=ON to use GPU. ")
@@ -57,8 +58,12 @@ if(WITH_DECODER)
   list(APPEND decoder_op_files fusion_decoder_op.cc fusion_decoder_op.cu)
 endif()
 
-if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER)
-  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON must be set to use FasterTransformer. ")
+if(WITH_BART)
+  list(APPEND decoding_op_files fusion_bart_decoding_op.cc fusion_bart_decoding_op.cu)
+endif()
+
+if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART)
+  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON must be set to use FasterTransformer. ")
 endif()
 
 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
@@ -177,6 +182,11 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
 file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src)
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)
 
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/open_decoder.h open_decoder_h_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/open_decoder.h open_decoder_h_dst)
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/decoding_sampling.h decoding_sampling_h_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/decoding_sampling.h decoding_sampling_h_dst)
+
 set(OPT_OPEN_ATTN_COMMAND sed -i -e "370,392d" -e "410,454d" -e "229d" ${open_attention_h_dst})
 
 # TODO(guosheng): `find` seems meeting errors missing argument to `-exec', fix it
@@ -192,6 +202,8 @@ set(FT_PATCH_COMMAND
     && cp ${sampling_h_src} ${trans_dst}
     && cp ${arguments_h_src} ${trans_dst}
     && cp ${bert_encoder_transformer_h_src} ${bert_encoder_transformer_h_dst}
+    && cp ${open_decoder_h_src} ${open_decoder_h_dst}
+    && cp ${decoding_sampling_h_src} ${decoding_sampling_h_dst}
     && cat ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
     && cat ${lightseq_kernels_cu_src} >> ${topk_kernels_dst}
     && cat ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst}
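
The new WITH_BART option appends fusion_bart_decoding_op.cc/.cu to decoding_op_files, so the BART decoding op is compiled into the same libdecoding_op.so this build tree already produces, and two additional patched FasterTransformer headers (open_decoder.h, decoding_sampling.h) are copied in before compilation. A minimal sketch of registering the prebuilt library from Python, assuming the load_op_meta_info_and_register_op helper that paddlenlp.ops uses internally (the .so path is illustrative and matches the sample's default):

# Sketch only, not part of the commit: register a prebuilt custom-op library.
from paddle.utils.cpp_extension import load_op_meta_info_and_register_op

load_op_meta_info_and_register_op("../../build/lib/libdecoding_op.so")

In the sample script below this happens implicitly: the path is simply passed to FasterBART via the decoding_lib argument.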
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
from pprint import pprint

import paddle
from paddlenlp.ops import FasterBART
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.data import Pad
from paddlenlp.utils.log import logger

def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq

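# Illustrative example (added in this write-up, with assumed token ids):
#   post_process_seq([0, 9064, 16, 2, 1, 1], bos_idx=0, eos_idx=2)
#   returns [9064, 16]: the sequence is truncated at the first eos_idx, and
#   bos/eos ids are dropped since output_bos/output_eos default to False.
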
def prepare_input(tokenizer, sentences, pad_id):
    word_pad = Pad(pad_id, dtype="int64")
    tokenized = tokenizer(sentences)
    inputs = word_pad([i["input_ids"] for i in tokenized])
    input_ids = paddle.to_tensor(inputs)
    return input_ids

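# Illustrative note (added in this write-up): Pad right-pads every id list to
# the longest in the batch, so two inputs of lengths 5 and 8 become one int64
# tensor of shape [2, 8], filled with pad_id.
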
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default="bart-base",
        type=str,
        help="The model name to specify which BART model to use. Can be one of ['bart-base', 'bart-large']. "
    )
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size. ")
    parser.add_argument(
        "--decoding_strategy",
        default='beam_search',
        type=str,
        help="The decoding strategy. Can be one of [beam_search, beam_search_v2, topk_sampling, topp_sampling]"
    )
    parser.add_argument(
        "--beam_size",
        default=1,
        type=int,
        help="The beam size for beam search. ")
    parser.add_argument(
        "--topk",
        default=1,
        type=int,
        help="The number of candidates kept for top-k sampling. ")
    parser.add_argument(
        "--topp",
        default=0.0,
        type=float,
        help="The probability threshold for top-p sampling. ")
    parser.add_argument(
        "--max_out_len", default=50, type=int, help="Maximum output length. ")
    parser.add_argument(
        "--beam_search_diversity_rate",
        default=0.0,
        type=float,
        help="The diversity rate for beam search. ")
    parser.add_argument(
        "--n_best",
        default=1,
        type=int,
        help="The number of decoded sentences to output. ")
    parser.add_argument(
        "--rel_len",
        action="store_true",
        help="Indicates whether max_out_len is a length relative to that of the source text. Only works in `v2` temporarily. ")
    parser.add_argument(
        "--alpha",
        default=0.6,
        type=float,
        help="The power in the length penalty calculation. Only works in `v2` temporarily. "
    )
    parser.add_argument(
        "--use_fp16_decoding",
        action="store_true",
        help="Whether to use fp16 decoding to predict. ")
    parser.add_argument(
        "--decoding_lib",
        default="../../build/lib/libdecoding_op.so",
        type=str,
        help="Path of libdecoding_op.so. ")
    args = parser.parse_args()
    return args

def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    logger.info('Loading the model parameters, please wait...')
    model = BartForConditionalGeneration.from_pretrained(
        args.model_name_or_path)
    # Set evaluate mode
    model.eval()
    sentences = [
        "I love that girl, but <mask> does not <mask> me.",
        "She is so <mask> that I can not help glance at <mask>.",
        "Nothing's gonna <mask> my love for you.",
        "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
    ]

    bos_id = model.bart.config['bos_token_id']
    eos_id = model.bart.config['eos_token_id']
    pad_id = model.bart.config['pad_token_id']
    input_ids = prepare_input(tokenizer, sentences, pad_id)

    # Define model
    faster_bart = FasterBART(
        model=model,
        decoding_strategy=args.decoding_strategy,
        beam_size=args.beam_size,
        topk=args.topk,
        topp=args.topp,
        max_out_len=args.max_out_len,
        beam_search_diversity_rate=args.beam_search_diversity_rate,
        decoding_lib=args.decoding_lib,
        use_fp16_decoding=args.use_fp16_decoding,
        rel_len=args.rel_len,
        alpha=args.alpha)

    # Set evaluate mode
    faster_bart.eval()

    with paddle.no_grad():
        for i in range(100):
            # For warmup.
            if 50 == i:
                paddle.fluid.core._cuda_synchronize(place)
                start = time.perf_counter()
            finished_seq = faster_bart(input_ids)
        paddle.fluid.core._cuda_synchronize(place)
        logger.info("Average test time for decoding is %f ms" % (
            (time.perf_counter() - start) / 50 * 1000))

    # Output
    if args.decoding_strategy.startswith('beam_search'):
        finished_seq = finished_seq.numpy().transpose([1, 2, 0])
        for ins in finished_seq:
            for beam_idx, beam in enumerate(ins):
                if beam_idx >= args.n_best:
                    break
                generated_ids = post_process_seq(beam, bos_id, eos_id)
                print(tokenizer.convert_ids_to_string(generated_ids))
    elif args.decoding_strategy in ['topk_sampling', 'topp_sampling']:
        finished_seq = finished_seq.numpy().transpose([1, 0])
        for ins in finished_seq:
            generated_ids = post_process_seq(ins, bos_id, eos_id)
            print(tokenizer.convert_ids_to_string(generated_ids))

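# Illustrative note (added in this write-up), inferred from the transposes
# above: beam search returns ids shaped [max_len, batch, beam], rearranged to
# [batch, beam, max_len]; sampling returns [max_len, batch], rearranged to
# [batch, max_len].
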
if __name__ == "__main__":
    args = parse_args()
    pprint(args)
    do_predict(args)
