@@ -277,11 +277,14 @@ class GenerationMixin(object):
     """

     @staticmethod
-    def prepare_input_ids_for_generation(bos_token_id):
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
         if bos_token_id is None:
             raise ValueError("`bos_token_id` should be defined when no "
                              "`input_ids` are provided.")
-        return paddle.ones([1, 1], dtype="int64") * bos_token_id
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+        return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id

     @staticmethod
     def prepare_attention_mask_for_generation(input_ids, pad_token_id,
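The hunk above infers the batch size from `encoder_output`, so seq2seq models can start decoding with one `bos_token_id` row per source sequence instead of a fixed `[1, 1]` tensor. A minimal standalone sketch of the same logic; the `demo_` name and the tensor shapes are illustrative assumptions, not part of this PR:

    import paddle

    def demo_prepare_input_ids(bos_token_id, encoder_output=None):
        batch_size = 1
        if bos_token_id is None:
            raise ValueError("`bos_token_id` should be defined when no "
                             "`input_ids` are provided.")
        if encoder_output is not None:
            # One decoder start token per encoded source sequence.
            batch_size = encoder_output.shape[0]
        return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id

    encoder_output = paddle.randn([4, 16, 256])  # hypothetical [batch, src_len, hidden]
    print(demo_prepare_input_ids(1, encoder_output).shape)  # [4, 1]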
@@ -338,6 +341,11 @@ def expand_inputs_for_generation(input_ids,
             seq_len = model_kwargs["seq_len"]
             model_kwargs["seq_len"] = paddle.index_select(seq_len, index)

+        if "encoder_output" in model_kwargs:
+            encoder_output = model_kwargs["encoder_output"]
+            model_kwargs["encoder_output"] = paddle.index_select(encoder_output,
+                                                                 index)
+
         return input_ids, model_kwargs

     @staticmethod
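Here `index` is built earlier in `expand_inputs_for_generation`; the new branch reuses it so encoder states are replicated once per beam, in the same row order as the expanded `input_ids`. A toy sketch of that replication, with the shapes and the rebuilt `index` as illustrative assumptions:

    import paddle

    batch_size, expand_size = 2, 3  # e.g. expand_size == num_beams
    encoder_output = paddle.arange(8, dtype="float32").reshape([2, 4])  # toy [batch, hidden]

    # Equivalent of the index used by expand_inputs_for_generation:
    # [0, 0, 0, 1, 1, 1] repeats each batch row expand_size times.
    index = paddle.tile(
        paddle.arange(batch_size).unsqueeze(-1),
        repeat_times=[1, expand_size]).reshape([-1])

    expanded = paddle.index_select(encoder_output, index)
    print(expanded.shape)  # [6, 4]: row i of encoder_output appears expand_size times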
@@ -441,6 +449,7 @@ def generate(self,
                  eos_token_id=None,
                  pad_token_id=None,
                  num_return_sequences=1,
+                 diversity_rate=0.0,
                  use_cache=True,
                  **model_kwargs):
         r"""
@@ -489,6 +498,9 @@ def generate(self,
                 None.
             num_return_sequences (int, optional): The number of returned
                 sequences for each sequence in the batch. Default to 1.
+            diversity_rate (float, optional): The diversity rate for diverse
+                sibling search. See https://arxiv.org/abs/1611.08562 for more
+                details. Default to 0.0, which disables the penalty.
             use_cache (bool, optional): Whether or not to use the model cache
                 to speed up decoding. Default to True.
             model_kwargs (dict): It can be used to specify additional kwargs
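A hypothetical call showing where the new argument fits. The model object and the other argument values are placeholders, and `decode_strategy` is assumed from the surrounding code in this file, which dispatches `generate()` to `beam_search()`:

    # Assumes `model` mixes in GenerationMixin and `input_ids` is an int64
    # paddle tensor; generate() returns (token ids, scores) as beam_search does.
    output_ids, scores = model.generate(
        input_ids,
        max_length=50,
        decode_strategy="beam_search",
        num_beams=4,
        diversity_rate=0.5)  # > 0.0 switches beam_search() to the diverse-sibling branch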
@@ -617,7 +629,8 @@ def generate(self,
         if "decoder_input_ids" in model_kwargs:
             input_ids = model_kwargs.pop("decoder_input_ids")
         else:
-            input_ids = self.prepare_input_ids_for_generation(bos_token_id)
+            input_ids = self.prepare_input_ids_for_generation(
+                bos_token_id, model_kwargs["encoder_output"])

         if pad_token_id is None and eos_token_id is not None:
             print("Setting `pad_token_id` to `eos_token_id`:{} for "
@@ -673,8 +686,8 @@ def generate(self,
                 input_ids, expand_size=num_beams, **model_kwargs)

             return self.beam_search(input_ids, beam_scorer, logits_processors,
-                                    max_length, pad_token_id, eos_token_id,
-                                    **model_kwargs)
+                                    max_length, diversity_rate, pad_token_id,
+                                    eos_token_id, **model_kwargs)

         else:
             raise ValueError(
@@ -835,7 +848,7 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
         return input_ids[:, origin_len:], scores

     def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
-                    pad_token_id, eos_token_id, **model_kwargs):
+                    diversity_rate, pad_token_id, eos_token_id, **model_kwargs):
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams

@@ -871,15 +884,50 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
             next_scores = paddle.log(next_scores)

             next_scores = next_scores + beam_scores.unsqueeze(-1)
-            # reshape for beam search
+
             vocab_size = next_scores.shape[-1]
-            next_scores = next_scores.reshape(
-                [batch_size, num_beams * vocab_size])
+            if diversity_rate == 0.0:
+                # reshape for beam search
+                next_scores = next_scores.reshape(
+                    [batch_size, num_beams * vocab_size])

-            next_scores, next_tokens = paddle.topk(
-                next_scores, 2 * num_beams, axis=1)
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * num_beams, axis=1)
+
+                next_indices = next_tokens // vocab_size
+
+            else:
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * num_beams, axis=1)
+
+                sibling_score = paddle.tile(
+                    paddle.arange(1, 2 * num_beams + 1),
+                    repeat_times=[batch_size * num_beams, 1]) * diversity_rate
+
+                diversed_score = next_scores - sibling_score
+                next_scores = next_scores.reshape(
+                    [batch_size, 2 * num_beams * num_beams])
+                next_tokens = next_tokens.reshape(
+                    [batch_size, 2 * num_beams * num_beams])
+
+                diversed_score = diversed_score.reshape(
+                    [batch_size, 2 * num_beams * num_beams])
+                diversed_score, diversed_tokens = paddle.topk(
+                    diversed_score, 2 * num_beams, axis=1)
+
+                # TODO
+                # Use gather_nd() to select the original tokens and scores
+                next_scores = paddle.stack([
+                    paddle.index_select(next_scores[i], diversed_tokens[i])
+                    for i in range(next_scores.shape[0])
+                ])
+                next_tokens = paddle.stack([
+                    paddle.index_select(next_tokens[i], diversed_tokens[i])
+                    for i in range(next_tokens.shape[0])
+                ])
+
+                next_indices = diversed_tokens // (2 * num_beams)

-            next_indices = next_tokens // vocab_size
             next_tokens = next_tokens % vocab_size

             # stateless
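The diverse-sibling branch implements the penalty from Li et al. (2016, https://arxiv.org/abs/1611.08562): the k-th ranked child of each parent beam is penalized by k * diversity_rate before all candidates are re-ranked globally, which nudges different parents toward different tokens. A self-contained numeric sketch with toy sizes; the score values are made up for illustration:

    import paddle

    batch_size, num_beams, diversity_rate = 1, 2, 0.5
    # Toy log-prob scores, shape [batch_size * num_beams, vocab_size].
    next_scores = paddle.to_tensor([[-0.1, -0.4, -0.9, -1.2, -2.0],
                                    [-0.2, -0.3, -0.5, -1.5, -1.8]])

    # Per-parent top 2 * num_beams candidates.
    top_scores, top_tokens = paddle.topk(next_scores, 2 * num_beams, axis=1)

    # Rank-based penalty: the k-th best sibling of each parent loses k * diversity_rate.
    sibling_score = paddle.tile(
        paddle.arange(1, 2 * num_beams + 1, dtype="float32"),
        repeat_times=[batch_size * num_beams, 1]) * diversity_rate
    diversed_score = top_scores - sibling_score

    # Re-rank all parents' candidates together; with the penalty applied, a
    # parent's 3rd or 4th choice can now lose to another parent's 1st choice.
    flat = diversed_score.reshape([batch_size, 2 * num_beams * num_beams])
    best_scores, best_positions = paddle.topk(flat, 2 * num_beams, axis=1)
    parent_beam = best_positions // (2 * num_beams)  # which beam each pick came from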