@@ -281,7 +281,7 @@ def prepare_input_ids_for_generation(bos_token_id):
         if bos_token_id is None:
             raise ValueError("`bos_token_id` should be defined when no "
                              "`input_ids` are provided.")
-        return paddle.ones([1, 1]) * bos_token_id
+        return paddle.ones([1, 1], dtype="int64") * bos_token_id

     @staticmethod
     def prepare_attention_mask_for_generation(input_ids, pad_token_id,
@@ -341,7 +341,9 @@ def expand_inputs_for_generation(input_ids,
         return input_ids, model_kwargs

     @staticmethod
-    def update_model_kwargs_for_generation(outputs, model_kwargs):
+    def update_model_kwargs_for_generation(outputs,
+                                           model_kwargs,
+                                           is_encoder_decoder=False):
         # Update the model inputs during generation.
         # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
         # and they contain pad value, the result vectors updated by this method
@@ -366,7 +368,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs):
                 axis=-1)

         # update attention_mask
-        if "attention_mask" in model_kwargs:
+        if not is_encoder_decoder and "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
             # nn.Pad2D don't support the data type `bool`
             if convert_dtype(attention_mask.dtype) == 'bool':
@@ -395,6 +397,22 @@ def update_scores_for_generation(scores, next_scores, length,
         scores = paddle.where(unfinished_flag, unfinished_scores, scores)
         return scores

+    def prepare_encoder_decoder_kwargs_for_generation(self, input_ids,
+                                                      model_kwargs):
+        if "encoder_output" not in model_kwargs:
+            # retrieve encoder hidden states
+            encoder = self.get_encoder()
+            encoder_kwargs = {
+                argument: value
+                for argument, value in model_kwargs.items()
+                if not (argument.startswith("decoder_") or argument.startswith(
+                    "cross_attn"))
+            }
+
+            model_kwargs["encoder_output"] = encoder(input_ids,
+                                                     **encoder_kwargs)
+        return model_kwargs
+
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         # Implement in subclasses for custom behavior to prepare inputs in the
         # generate method.
@@ -590,14 +608,22 @@ def generate(self,
             model_kwargs[
                 "attention_mask"] = self.prepare_attention_mask_for_generation(
                     input_ids, pad_token_id, eos_token_id)
+        self.is_encoder_decoder = hasattr(self, 'encoder') and hasattr(
+            self, 'decoder')
+        if self.is_encoder_decoder:
+            model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
+                input_ids, model_kwargs)
+            # set input_ids as decoder_input_ids
+            if "decoder_input_ids" in model_kwargs:
+                input_ids = model_kwargs.pop("decoder_input_ids")
+            else:
+                input_ids = self.prepare_input_ids_for_generation(bos_token_id)

         if pad_token_id is None and eos_token_id is not None:
             print("Setting `pad_token_id` to `eos_token_id`:{} for "
                   "open-end generation.".format(eos_token_id))
             pad_token_id = eos_token_id

-        # TODO Add relevant processing for encoder_decoder model.
-
         model_kwargs["use_cache"] = use_cache
         max_length += input_ids.shape[-1]
         min_length += input_ids.shape[-1]
@@ -671,7 +697,6 @@ def greedy_search(self, input_ids, logits_processors, max_length,
             logits = outputs[0] if isinstance(outputs, tuple) else outputs
             # [batch_size, vocab_size]
             logits = logits[:, -1, :]
-
             # pre-process distribution
             logits = self.adjust_logits_during_generation(logits)
             logits = logits_processors(input_ids, logits)
@@ -700,8 +725,10 @@ def greedy_search(self, input_ids, logits_processors, max_length,
             if not paddle.any(unfinished_flag):
                 break

-            model_kwargs = self.update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs)
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.is_encoder_decoder)

         return input_ids[:, origin_len:], scores

     def sample(self,
@@ -801,8 +828,10 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
             # Stop when there is a </s> in all sentences
             if not paddle.any(unfinished_flag):
                 break
-            model_kwargs = self.update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs)
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.is_encoder_decoder)
         return input_ids[:, origin_len:], scores

     def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
@@ -876,8 +905,10 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,

             if beam_scorer.is_done:
                 break
-            model_kwargs = self.update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs)
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.is_encoder_decoder)
             if model_kwargs["cache"] is not None:
                 # reorder the cache
                 model_kwargs["cache"] = map_structure(
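As a side note (not part of the commit itself), the core of the new `prepare_encoder_decoder_kwargs_for_generation` is the argument filtering: everything starting with `decoder_` or `cross_attn` is held back for the decoder, while the remaining kwargs are forwarded to the encoder and its output is cached under `encoder_output`. A minimal standalone sketch of just that filtering step, with made-up placeholder values:

```python
# Illustrative only: mirrors the dict comprehension added in
# prepare_encoder_decoder_kwargs_for_generation. Arguments whose names start
# with "decoder_" or "cross_attn" are excluded from the encoder call.


def filter_encoder_kwargs(model_kwargs):
    return {
        argument: value
        for argument, value in model_kwargs.items()
        if not (argument.startswith("decoder_") or
                argument.startswith("cross_attn"))
    }


model_kwargs = {
    "attention_mask": "encoder attention mask",  # forwarded to the encoder
    "decoder_input_ids": "decoder prompt ids",   # kept for the decoder
    "cross_attn_head_mask": "cross-attn mask",   # kept for cross-attention
}
print(filter_encoder_kwargs(model_kwargs))
# {'attention_mask': 'encoder attention mask'}
```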