Commit 375df59

Expose beam search diversity rate for ft (PaddlePaddle#1072)
* expose beam search diversity rate for ft
* rename param
* fix generation api
* fix bart
1 parent af7cf94 commit 375df59

File tree

5 files changed: +48 -24 lines changed

5 files changed

+48
-24
lines changed

examples/machine_translation/transformer/faster_transformer/README.md

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ python encoder_decoding_predict.py --config ../configs/transformer.base.yaml --d
 * When using `topk_sampling`, the value of `--topk` must be specified.
 * When using `topp_sampling`, the value of `topp` must be specified, and the value of `--topk` must be 0.
 * `--beam_size`: the beam size when the decoding strategy is `beam_search`; the data type is `int`.
+* `--diversity_rate`: the diversity rate used when the decoding strategy is `beam_search`; the data type is `float`. When `diversity_rate` is set greater than 0, FasterTransformer only supports beam sizes of 1, 4, 16, and 64.
 * `--topk`: the value of k used in the top-k computation when the decoding strategy is `topk_sampling`; the data type is `int`.
 * `--topp`: the value of p when the decoding strategy is `topp_sampling`; the data type is `float`.
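For example, a beam search run that exercises the new flag could look like the following (the argument values are illustrative, not part of this commit):

    python encoder_decoding_predict.py \
        --config ../configs/transformer.base.yaml \
        --decoding_strategy beam_search \
        --beam_size 4 \
        --diversity_rate 0.3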

examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,11 @@ def parse_args():
         help="Decoding strategy. Can be one of ['beam_search', 'topk_sampling', 'topp_sampling']. "
     )
     parser.add_argument("--beam_size", default=5, type=int, help="Beam size. ")
+    parser.add_argument(
+        "--diversity_rate",
+        default=0.0,
+        type=float,
+        help="The diversity rate for beam search. ")
     parser.add_argument(
         "--topk",
         default=4,

@@ -144,6 +149,7 @@ def do_predict(args):
         decoding_strategy=args.decoding_strategy,
         beam_size=args.beam_size,
         max_out_len=args.max_out_len,
+        diversity_rate=args.diversity_rate,
         decoding_lib=args.decoding_lib,
         use_fp16_decoding=args.use_fp16_decoding)

@@ -206,6 +212,7 @@ def do_predict(args):
     args.use_fp16_decoding = ARGS.use_fp16_decoding
     args.decoding_strategy = ARGS.decoding_strategy
     args.beam_size = ARGS.beam_size
+    args.diversity_rate = ARGS.diversity_rate
     args.topk = ARGS.topk
     args.topp = ARGS.topp
     args.profile = ARGS.profile
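The `diversity_rate` argument feeds FasterTransformer's diverse sibling beam search (in the spirit of Li et al., 2016). A minimal NumPy sketch of the scoring rule, assuming the usual formulation in which the r-th ranked candidate under each parent beam is penalized by r times the rate; this is an illustration, not the CUDA kernel:

    import numpy as np

    def diverse_sibling_scores(log_probs, diversity_rate):
        # log_probs: [beam_size, vocab_size] next-token log-probabilities,
        # one row per parent beam.
        # Rank the candidates within each parent (rank 0 = best) ...
        ranks = np.argsort(np.argsort(-log_probs, axis=-1), axis=-1)
        # ... and subtract rank * diversity_rate, so siblings of the same
        # parent are pushed apart and the beams expand different parents.
        return log_probs - diversity_rate * ranks

With diversity_rate = 0.0 this reduces to ordinary beam search scoring, which is why 0.0 is the default throughout this commit.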

paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ def parse_args():
     parser.add_argument(
         "--max_out_len", default=50, type=int, help="Maximum output length. ")
     parser.add_argument(
-        "--beam_search_diversity_rate",
+        "--diversity_rate",
         default=0.0,
         type=float,
         help="The diversity of beam search. ")

@@ -144,7 +144,7 @@ def do_predict(args):
         topk=args.topk,
         topp=args.topp,
         max_out_len=args.max_out_len,
-        beam_search_diversity_rate=args.beam_search_diversity_rate,
+        diversity_rate=args.diversity_rate,
         decoding_lib=args.decoding_lib,
         use_fp16_decoding=args.use_fp16_decoding,
         rel_len=args.rel_len,
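After the rename, the sample is invoked with the new flag name, e.g. (a hypothetical command; the remaining arguments keep the sample's defaults):

    python bart_decoding_sample.py --diversity_rate 0.3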

paddlenlp/ops/faster_transformer/transformer/decoding.py

Lines changed: 17 additions & 17 deletions
@@ -36,8 +36,8 @@ def infer_transformer_decoding(
         ffn_inter_bias, ffn_out_weight, ffn_out_bias, decoder_ln_weight,
         decoder_ln_bias, linear_weight, linear_bias, pos_emb,
         _decoding_strategy, _beam_size, _topk, _topp, _n_head, _size_per_head,
-        _n_layer, _bos_id, _eos_id, _max_out_len, _beam_search_diversity_rate,
-        _rel_len, _alpha):
+        _n_layer, _bos_id, _eos_id, _max_out_len, _diversity_rate, _rel_len,
+        _alpha):
     helper = LayerHelper('fusion_decoding', **locals())

     inputs = {

@@ -88,7 +88,7 @@ def infer_transformer_decoding(
         'bos_id': _bos_id,
         'eos_id': _eos_id,
         'max_len': _max_out_len,
-        'beam_search_diversity_rate': _beam_search_diversity_rate,
+        'beam_search_diversity_rate': _diversity_rate,
         "rel_len": _rel_len,
         "alpha": _alpha
     }

@@ -175,8 +175,8 @@ def infer_unified_decoding(
         trans_weight, trans_bias, lm_ln_weight, lm_ln_bias, linear_weight,
         linear_bias, pos_emb, type_emb, _decoding_strategy, _beam_size, _topk,
         _topp, _n_head, _size_per_head, _n_layer, _bos_id, _eos_id,
-        _max_out_len, _beam_search_diversity_rate, _unk_id, _mask_id,
-        _temperature, _len_penalty, _normalize_before, _pos_bias, _hidden_act):
+        _max_out_len, _diversity_rate, _unk_id, _mask_id, _temperature,
+        _len_penalty, _normalize_before, _pos_bias, _hidden_act):
     helper = LayerHelper('fusion_unified_decoding', **locals())

     inputs = {

@@ -225,7 +225,7 @@ def infer_unified_decoding(
         "bos_id": _bos_id,
         "eos_id": _eos_id,
         "max_len": _max_out_len,
-        "beam_search_diversity_rate": _beam_search_diversity_rate,
+        "beam_search_diversity_rate": _diversity_rate,
         "unk_id": _unk_id,
         "mask_id": _mask_id,
         "temperature": _temperature,

@@ -264,8 +264,8 @@ def infer_bart_decoding(
         ffn_inter_bias, ffn_out_weight, ffn_out_bias, decoder_ln_weight,
         decoder_ln_bias, linear_weight, linear_bias, pos_emb,
         _decoding_strategy, _beam_size, _topk, _topp, _n_head, _size_per_head,
-        _n_layer, _bos_id, _eos_id, _max_out_len, _beam_search_diversity_rate,
-        _rel_len, _alpha):
+        _n_layer, _bos_id, _eos_id, _max_out_len, _diversity_rate, _rel_len,
+        _alpha):

     helper = LayerHelper('fusion_bart_decoding', **locals())

@@ -317,7 +317,7 @@ def infer_bart_decoding(
         'bos_id': _bos_id,
         'eos_id': _eos_id,
         'max_len': _max_out_len,
-        'beam_search_diversity_rate': _beam_search_diversity_rate,
+        'beam_search_diversity_rate': _diversity_rate,
         "rel_len": _rel_len,
         "alpha": _alpha
     }

@@ -391,7 +391,7 @@ def __init__(self,
                  topk=1,
                  topp=0.0,
                  max_out_len=256,
-                 beam_search_diversity_rate=0.0,
+                 diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
                  rel_len=False,

@@ -564,8 +564,8 @@ def forward(self, enc_output, memory_seq_lens):
             self._decoding_strategy, self._beam_size, self._topk, self._topp,
             self._n_head,
             int(self._d_model / self._n_head), self._num_decoder_layers,
-            self._bos_id, self._eos_id, self._max_out_len,
-            self._beam_search_diversity_rate, self._rel_len, self._alpha)
+            self._bos_id, self._eos_id, self._max_out_len, self._diversity_rate,
+            self._rel_len, self._alpha)

         ids = finalize(
             self._beam_size,

@@ -1048,7 +1048,7 @@ def forward(self,
                 eos_id=1,
                 temperature=1.0,
                 length_penalty=1.0,
-                beam_search_diversity_rate=0.0,
+                diversity_rate=0.0,
                 pos_bias=True):
         output_ids, parent_ids, sequence_length = infer_unified_decoding(
             cache_k=cache_k,

@@ -1093,7 +1093,7 @@
             _bos_id=bos_id,
             _eos_id=eos_id,
             _max_out_len=max_out_len,
-            _beam_search_diversity_rate=beam_search_diversity_rate,
+            _diversity_rate=diversity_rate,
             _unk_id=self._unk_id,
             _mask_id=self._mask_id,
             _temperature=temperature,

@@ -1120,7 +1120,7 @@ def __init__(self,
                  topk=1,
                  topp=0.0,
                  max_out_len=256,
-                 beam_search_diversity_rate=0.0,
+                 diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
                  rel_len=False,

@@ -1321,8 +1321,8 @@ def forward(self, enc_output, memory_seq_lens):
             self._decoding_strategy, self._beam_size, self._topk, self._topp,
             self._n_head,
             int(self._d_model / self._n_head), self._num_decoder_layers,
-            self._bos_id, self._eos_id, self._max_out_len,
-            self._beam_search_diversity_rate, self._rel_len, self._alpha)
+            self._bos_id, self._eos_id, self._max_out_len, self._diversity_rate,
+            self._rel_len, self._alpha)

         ids = finalize(
             self._beam_size,
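Note that the rename stops at the Python surface: `fusion_decoding`, `fusion_unified_decoding`, and `fusion_bart_decoding` still receive the attribute under its historical name, as in this sketch of the pattern repeated in the hunks above:

    # Python-side keyword (new)  ->  fused-op attribute (unchanged)
    attrs = {
        'max_len': _max_out_len,
        'beam_search_diversity_rate': _diversity_rate,
        'rel_len': _rel_len,
        'alpha': _alpha,
    }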

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 21 additions & 5 deletions
@@ -91,6 +91,8 @@ class FasterTransformer(TransformerModel):
             `topp` are kept for top-p sampling. Defaults to 4.
         max_out_len (int, optional):
             The maximum output length. Defaults to 256.
+        diversity_rate (float, optional):
+            The diversity rate for beam search. Defaults to 0.0.
         use_fp16_decoding(bool, optional): Whether to use fp16 for decoding.
         rel_len(bool, optional):
             Indicating whether `max_out_len` is the length relative to that

@@ -123,6 +125,7 @@ def __init__(self,
                  topk=1,
                  topp=0.0,
                  max_out_len=256,
+                 diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
                  rel_len=False,

@@ -141,6 +144,7 @@ def __init__(self,
         self.topk = args.pop("topk")
         self.topp = args.pop("topp")
         self.max_out_len = args.pop("max_out_len")
+        self.diversity_rate = args.pop("diversity_rate")
         self.decoding_lib = args.pop("decoding_lib")
         self.use_fp16_decoding = args.pop("use_fp16_decoding")
         self.rel_len = args.pop("rel_len")

@@ -177,6 +181,7 @@ def __init__(self,
             topk=topk,
             topp=topp,
             max_out_len=max_out_len,
+            diversity_rate=self.diversity_rate,
             decoding_lib=self.decoding_lib,
             use_fp16_decoding=self.use_fp16_decoding,
             rel_len=self.rel_len,

@@ -480,6 +485,9 @@ def __init__(self,
         self.d_model = d_model
         self.max_length = max_length
         self.output_time_major = kwargs.pop("output_time_major", True)
+        # Only works for Faster Transformer.
+        # TODO: original version supports diversity rate.
+        diversity_rate = kwargs.pop("diversity_rate", 0.0)
         use_fp16_decoding = kwargs.pop("use_fp16_decoding", False)
         use_ft = kwargs.pop("use_ft", True)
         beam_search_version = kwargs.pop("beam_search_version", "v1")

@@ -507,6 +515,7 @@ def __init__(self,
             eos_id=eos_id,
             beam_size=beam_size,
             max_out_len=max_out_len,
+            diversity_rate=diversity_rate,
             decoding_strategy=decoding_strategy,
             use_fp16_decoding=use_fp16_decoding,
             rel_len=rel_len,

@@ -786,7 +795,7 @@ def sample(self,
             temperature=temperature)

     def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
-                    pad_token_id, eos_token_id, **model_kwargs):
+                    diversity_rate, pad_token_id, eos_token_id, **model_kwargs):
         max_length -= input_ids.shape[-1]
         model_inputs = self.prepare_inputs_for_generation(input_ids,
                                                           **model_kwargs)

@@ -796,6 +805,7 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
             model_inputs=model_inputs,
             max_length=max_length,
             num_beams=beam_scorer.num_beams,
+            diversity_rate=diversity_rate,
             temperature=temperature)

     def forward(self,

@@ -804,6 +814,7 @@ def forward(self,
                 top_k=4,
                 top_p=0.0,
                 num_beams=4,
+                diversity_rate=0.0,
                 temperature=1.0,
                 model_inputs=None,
                 **model_kwargs):

@@ -823,6 +834,7 @@ def forward(self,
             cache_v=cache_v,
             memory_seq_lens=seq_len,
             beam_size=num_beams,
+            diversity_rate=diversity_rate,
             topk=top_k,
             topp=top_p,
             max_out_len=max_length,

@@ -946,7 +958,7 @@ def sample(self,
             temperature=temperature)

     def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
-                    pad_token_id, eos_token_id, **model_kwargs):
+                    diversity_rate, pad_token_id, eos_token_id, **model_kwargs):
         max_length -= input_ids.shape[-1]
         model_inputs = self.prepare_inputs_for_generation(input_ids,
                                                           **model_kwargs)

@@ -956,6 +968,7 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
             model_inputs=model_inputs,
             max_length=max_length,
             num_beams=beam_scorer.num_beams,
+            diversity_rate=diversity_rate,
             temperature=temperature)

     def forward(self,

@@ -964,6 +977,7 @@ def forward(self,
                 top_k=4,
                 top_p=0.0,
                 num_beams=4,
+                diversity_rate=0.0,
                 temperature=1.0,
                 model_inputs=None,
                 **model_kwargs):

@@ -983,6 +997,7 @@ def forward(self,
             cache_v=cache_v,
             memory_seq_lens=seq_len,
             beam_size=num_beams,
+            diversity_rate=diversity_rate,
             topk=top_k,
             topp=top_p,
             max_out_len=max_length,

@@ -1001,7 +1016,7 @@ def __init__(self,
                  topk=1,
                  topp=0.0,
                  max_out_len=256,
-                 beam_search_diversity_rate=0.0,
+                 diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
                  rel_len=False,

@@ -1023,7 +1038,7 @@ def __init__(self,
             topk=topk,
             topp=topp,
             max_out_len=max_out_len,
-            beam_search_diversity_rate=beam_search_diversity_rate,
+            diversity_rate=diversity_rate,
             decoding_lib=decoding_lib,
             use_fp16_decoding=use_fp16_decoding)

@@ -1032,7 +1047,8 @@ def forward(self, input_ids):
         mem_seq_lens = paddle.sum(paddle.cast(
             input_ids != self.pad_id, dtype="int32"),
                                   axis=-1,
-                                  keepdim=True)
+                                  keepdim=True,
+                                  dtype="int32")
         if self.use_fp16_decoding:
             encoder_output = paddle.cast(encoder_output, "float16")
         return self.decoding(encoder_output, mem_seq_lens)
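End to end, the new argument is supplied at construction time. A hedged sketch (hyperparameter values are illustrative; consult the class signature above for the full argument list):

    from paddlenlp.ops import FasterTransformer

    # Illustrative configuration only. With diversity_rate > 0,
    # FasterTransformer supports beam sizes of 1, 4, 16, and 64 only
    # (see the README change above).
    model = FasterTransformer(
        src_vocab_size=30000,
        trg_vocab_size=30000,
        max_length=256,
        num_encoder_layers=6,
        num_decoder_layers=6,
        n_head=8,
        d_model=512,
        d_inner_hid=2048,
        dropout=0.1,
        weight_sharing=True,
        decoding_strategy="beam_search",
        beam_size=4,
        diversity_rate=0.3,  # the argument exposed by this commit
        max_out_len=256)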
