# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
from pprint import pprint

import paddle
from paddlenlp.ops import FasterBART
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.data import Pad
from paddlenlp.utils.log import logger


def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence: truncate everything after the first
    <eos> token and strip the <bos>/<eos> tokens themselves unless
    output_bos/output_eos is set.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq
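

# Illustrative example (the token ids below are made up for the sketch, not
# taken from the original): with bos_idx=0 and eos_idx=2,
#     post_process_seq([0, 313, 232, 2, 2], bos_idx=0, eos_idx=2) -> [313, 232]
# i.e. everything after the first EOS is dropped, and the BOS/EOS markers
# themselves are stripped.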


def prepare_input(tokenizer, sentences, pad_id):
    # Tokenize the batch and right-pad each sequence to the longest one.
    word_pad = Pad(pad_id, dtype="int64")
    tokenized = tokenizer(sentences)
    inputs = word_pad([i["input_ids"] for i in tokenized])
    input_ids = paddle.to_tensor(inputs)
    return input_ids
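
# Illustrative example (ids are assumptions): for two inputs tokenized to
# [0, 100, 2] and [0, 100, 200, 2], prepare_input returns the dense int64
# tensor [[0, 100, 2, pad_id], [0, 100, 200, 2]] of shape
# [batch_size, max_seq_len].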


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default="bart-base",
        type=str,
        help="The model name to specify which BART to use. Can be one of ['bart-base', 'bart-large'].")
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size.")
    parser.add_argument(
        "--decoding_strategy",
        default='beam_search',
        type=str,
        help="The decoding strategy. Can be one of ['beam_search', 'beam_search_v2', 'topk_sampling', 'topp_sampling'].")
    parser.add_argument(
        "--beam_size",
        default=1,
        type=int,
        help="The beam size for beam search.")
    parser.add_argument(
        "--topk",
        default=1,
        type=int,
        help="The number of highest-probability candidates to keep for top-k sampling.")
    parser.add_argument(
        "--topp",
        default=0.0,
        type=float,
        help="The cumulative probability threshold for top-p sampling.")
    parser.add_argument(
        "--max_out_len", default=50, type=int, help="Maximum output length.")
    parser.add_argument(
        "--beam_search_diversity_rate",
        default=0.0,
        type=float,
        help="The diversity rate for diverse beam search.")
    parser.add_argument(
        "--n_best",
        default=1,
        type=int,
        help="The number of decoded sentences to output per input.")
    parser.add_argument(
        "--rel_len",
        action="store_true",
        help="Whether max_out_len is relative to the length of the source text. Only works with beam_search_v2 temporarily.")
    parser.add_argument(
        "--alpha",
        default=0.6,
        type=float,
        help="The exponent in the length-penalty calculation. Only works with beam_search_v2 temporarily.")
    parser.add_argument(
        "--use_fp16_decoding",
        action="store_true",
        help="Whether to use fp16 decoding to predict.")
    parser.add_argument(
        "--decoding_lib",
        default="../../build/lib/libdecoding_op.so",
        type=str,
        help="Path of libdecoding_op.so.")
    args = parser.parse_args()
    return args


def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    logger.info('Loading the model parameters, please wait...')
    model = BartForConditionalGeneration.from_pretrained(
        args.model_name_or_path)
    # Set evaluate mode
    model.eval()
    sentences = [
        "I love that girl, but <mask> does not <mask> me.",
        "She is so <mask> that I can not help glance at <mask>.",
        "Nothing's gonna <mask> my love for you.",
        "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
    ]

    # Read the special-token ids from the model config so padding and
    # post-processing match the checkpoint's vocabulary.
    bos_id = model.bart.config['bos_token_id']
    eos_id = model.bart.config['eos_token_id']
    pad_id = model.bart.config['pad_token_id']
    input_ids = prepare_input(tokenizer, sentences, pad_id)

    # Wrap the trained model with FasterBART, which runs generation with the
    # FasterTransformer custom decoding op instead of the Python-side loop.
    faster_bart = FasterBART(
        model=model,
        decoding_strategy=args.decoding_strategy,
        beam_size=args.beam_size,
        topk=args.topk,
        topp=args.topp,
        max_out_len=args.max_out_len,
        beam_search_diversity_rate=args.beam_search_diversity_rate,
        decoding_lib=args.decoding_lib,
        use_fp16_decoding=args.use_fp16_decoding,
        rel_len=args.rel_len,
        alpha=args.alpha)
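
    # Note (assumption about the build setup): if the custom op has not been
    # compiled into Paddle, the shared library given by --decoding_lib is
    # loaded here; see the repository's build instructions for producing
    # libdecoding_op.so.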

    # Set evaluate mode
    faster_bart.eval()

    with paddle.no_grad():
        for i in range(100):
            # The first 50 iterations warm up the GPU; timing starts at the
            # 51st run so one-off setup and caching costs are excluded.
            if i == 50:
                paddle.fluid.core._cuda_synchronize(place)
                start = time.perf_counter()
            finished_seq = faster_bart(input_ids)
        paddle.fluid.core._cuda_synchronize(place)
        logger.info("Average test time for decoding is %f ms" % (
            (time.perf_counter() - start) / 50 * 1000))
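
    # Note: paddle.fluid.core._cuda_synchronize is an internal API kept from
    # the original script; on recent Paddle releases the public
    # paddle.device.cuda.synchronize() should be an equivalent way to wait for
    # the GPU (an assumption about your Paddle version).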

    # Output
    if args.decoding_strategy.startswith('beam_search'):
        # Beam search returns ids of shape [max_len, batch_size, beam_size];
        # transpose to [batch_size, beam_size, max_len] before decoding.
        finished_seq = finished_seq.numpy().transpose([1, 2, 0])
        for ins in finished_seq:
            for beam_idx, beam in enumerate(ins):
                if beam_idx >= args.n_best:
                    break
                generated_ids = post_process_seq(beam, bos_id, eos_id)
                print(tokenizer.convert_ids_to_string(generated_ids))
    elif args.decoding_strategy in ['topk_sampling', 'topp_sampling']:
        # Sampling returns ids of shape [max_len, batch_size]; transpose to
        # [batch_size, max_len].
        finished_seq = finished_seq.numpy().transpose([1, 0])
        for ins in finished_seq:
            generated_ids = post_process_seq(ins, bos_id, eos_id)
            print(tokenizer.convert_ids_to_string(generated_ids))


if __name__ == "__main__":
    args = parse_args()
    pprint(args)
    do_predict(args)
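
# Example invocation (illustrative; the script filename here is hypothetical
# and the library path depends on where you built FasterTransformer):
#     python bart_decoding_sample.py --model_name_or_path bart-base \
#         --decoding_strategy beam_search --beam_size 4 --max_out_len 50 \
#         --decoding_lib ../../build/lib/libdecoding_op.so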