
Commit 64b6f0f

Fix generate api bug and add diverse sibling search (PaddlePaddle#1041)
* Add offset mapping doc
* Fix eval hang because of unique endpoint
* Generate api support encoder-decoder
* Add lightseq beam_search
* Optimize performance
* Add blockroughk kernel
* Optimize
* Minor fix
* Fix generate api bug and add diverse sibling search
* Minor fix
1 parent a3db9c9 commit 64b6f0f

File tree

1 file changed: +60 −12 lines

paddlenlp/transformers/generation_utils.py

Lines changed: 60 additions & 12 deletions
@@ -277,11 +277,14 @@ class GenerationMixin(object):
         """

     @staticmethod
-    def prepare_input_ids_for_generation(bos_token_id):
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
         if bos_token_id is None:
             raise ValueError("`bos_token_id` should be defined when no "
                              "`input_ids` are provided.")
-        return paddle.ones([1, 1], dtype="int64") * bos_token_id
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+        return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id

     @staticmethod
     def prepare_attention_mask_for_generation(input_ids, pad_token_id,
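
Note: `prepare_input_ids_for_generation` now sizes the start tokens to the encoder batch instead of hard-coding batch size 1. A minimal sketch of the new behavior (the shapes here are made up for illustration):

    import paddle

    bos_token_id = 1
    # Hypothetical encoder output: [batch_size, src_len, hidden_size]
    encoder_output = paddle.randn([4, 10, 32])
    batch_size = encoder_output.shape[0] if encoder_output is not None else 1
    input_ids = paddle.ones([batch_size, 1], dtype="int64") * bos_token_id
    print(input_ids.shape)  # [4, 1] -- one BOS token per batch item
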
@@ -338,6 +341,11 @@ def expand_inputs_for_generation(input_ids,
             seq_len = model_kwargs["seq_len"]
             model_kwargs["seq_len"] = paddle.index_select(seq_len, index)

+        if "encoder_output" in model_kwargs:
+            encoder_output = model_kwargs["encoder_output"]
+            model_kwargs["encoder_output"] = paddle.index_select(encoder_output,
+                                                                 index)
+
         return input_ids, model_kwargs

     @staticmethod
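
Note: this mirrors what the method already does for `seq_len` just above: any tensor tied to the batch dimension must be repeated once per beam so `paddle.index_select` can pick the matching rows. A hedged sketch of the expansion, assuming `index` is the repeat index the method builds (e.g. `[0, 0, 0, 1, 1, 1]` for `expand_size=3`):

    import paddle

    # Hypothetical encoder output: [batch_size, src_len, hidden_size]
    encoder_output = paddle.randn([2, 8, 16])
    expand_size = 3  # e.g. num_beams
    # Repeat each batch index expand_size times: [0, 0, 0, 1, 1, 1]
    index = paddle.tile(
        paddle.arange(encoder_output.shape[0]).unsqueeze(-1),
        repeat_times=[1, expand_size]).reshape([-1])
    expanded = paddle.index_select(encoder_output, index)
    print(expanded.shape)  # [6, 8, 16]
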
@@ -441,6 +449,7 @@ def generate(self,
                  eos_token_id=None,
                  pad_token_id=None,
                  num_return_sequences=1,
+                 diversity_rate=0.0,
                  use_cache=True,
                  **model_kwargs):
         r"""
@@ -489,6 +498,9 @@ def generate(self,
                 None.
             num_return_sequences (int, optional): The number of returned
                 sequences for each sequence in the batch. Default to 1.
+            diversity_rate (float, optional): The diversity rate for diverse
+                sibling search. See `https://arxiv.org/abs/1611.08562` for
+                more details. Default to 0.0.
             use_cache: (bool, optional): Whether or not use the model cache to
                 speed up decoding. Default to True.
             model_kwargs (dict): It can be used to specify additional kwargs
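
Note: a hedged usage sketch of the new argument; the model, the input setup, and the `decode_strategy` value below are placeholders assumed from the surrounding API, not part of this commit:

    # `model` is any PaddleNLP model that mixes in GenerationMixin.
    # diversity_rate > 0.0 turns beam search into diverse sibling search;
    # the default 0.0 keeps plain beam search.
    ids, scores = model.generate(
        input_ids=input_ids,
        max_length=20,
        decode_strategy="beam_search",
        num_beams=4,
        diversity_rate=0.5)
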
@@ -617,7 +629,8 @@ def generate(self,
         if "decoder_input_ids" in model_kwargs:
             input_ids = model_kwargs.pop("decoder_input_ids")
         else:
-            input_ids = self.prepare_input_ids_for_generation(bos_token_id)
+            input_ids = self.prepare_input_ids_for_generation(
+                bos_token_id, model_kwargs["encoder_output"])

         if pad_token_id is None and eos_token_id is not None:
             print("Setting `pad_token_id` to `eos_token_id`:{} for "
@@ -673,8 +686,8 @@
                 input_ids, expand_size=num_beams, **model_kwargs)

             return self.beam_search(input_ids, beam_scorer, logits_processors,
-                                    max_length, pad_token_id, eos_token_id,
-                                    **model_kwargs)
+                                    max_length, diversity_rate, pad_token_id,
+                                    eos_token_id, **model_kwargs)

         else:
             raise ValueError(
@@ -835,7 +848,7 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
         return input_ids[:, origin_len:], scores

     def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
-                    pad_token_id, eos_token_id, **model_kwargs):
+                    diversity_rate, pad_token_id, eos_token_id, **model_kwargs):
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams

@@ -871,15 +884,50 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
             next_scores = paddle.log(next_scores)

             next_scores = next_scores + beam_scores.unsqueeze(-1)
-            # reshape for beam search
+
             vocab_size = next_scores.shape[-1]
-            next_scores = next_scores.reshape(
-                [batch_size, num_beams * vocab_size])
+            if diversity_rate == 0.0:
+                # reshape for beam search
+                next_scores = next_scores.reshape(
+                    [batch_size, num_beams * vocab_size])

-            next_scores, next_tokens = paddle.topk(
-                next_scores, 2 * num_beams, axis=1)
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * num_beams, axis=1)
+
+                next_indices = next_tokens // vocab_size
+
+            else:
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * num_beams, axis=1)
+
+                sibling_score = paddle.tile(
+                    paddle.arange(1, 2 * num_beams + 1),
+                    repeat_times=[batch_size * num_beams, 1]) * diversity_rate
+
+                diversed_score = next_scores - sibling_score
+                next_scores = next_scores.reshape(
+                    [batch_size, 2 * num_beams * num_beams])
+                next_tokens = next_tokens.reshape(
+                    [batch_size, 2 * num_beams * num_beams])
+
+                diversed_score = diversed_score.reshape(
+                    [batch_size, 2 * num_beams * num_beams])
+                diversed_score, diversed_tokens = paddle.topk(
+                    diversed_score, 2 * num_beams, axis=1)
+
+                # TODO
+                # Use gather_nd() to select original token and score
+                next_scores = paddle.stack([
+                    paddle.index_select(next_scores[i], diversed_tokens[i])
+                    for i in range(next_scores.shape[0])
+                ])
+                next_tokens = paddle.stack([
+                    paddle.index_select(next_tokens[i], diversed_tokens[i])
+                    for i in range(next_tokens.shape[0])
+                ])
+
+                next_indices = next_tokens // (2 * num_beams)

-            next_indices = next_tokens // vocab_size
             next_tokens = next_tokens % vocab_size

             # stateless
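
Note: the new `else` branch implements the sibling rank penalty from Li et al. 2016 (`https://arxiv.org/abs/1611.08562`): within each beam, the k-th ranked candidate (its k-th "sibling") is penalized by `k * diversity_rate` before the cross-beam top-k, which pushes different beams toward different continuations. A toy illustration with made-up scores:

    import paddle

    num_beams, diversity_rate = 2, 0.5
    # Per-beam top-(2 * num_beams) candidate scores, one row per (batch * beam).
    next_scores = paddle.to_tensor([[-0.1, -0.2, -0.3, -0.4],
                                    [-0.1, -0.3, -0.5, -0.7]])
    # Penalty grows with a candidate's rank inside its own beam.
    sibling_score = paddle.arange(
        1, 2 * num_beams + 1, dtype="float32") * diversity_rate
    diversed_score = next_scores - sibling_score  # broadcasts over rows
    print(diversed_score.numpy())
    # [[-0.6 -1.2 -1.8 -2.4]
    #  [-0.6 -1.3 -2.  -2.7]]
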

Comments (0)