@staticmethod
def prepare_input_ids_for_generation(bos_token_id):
    """Build a default input_ids tensor when the caller supplied none.

    Returns a [1, 1] int64 tensor holding a single `bos_token_id`, i.e. one
    sequence that starts generation from the BOS token.

    Raises:
        ValueError: if `bos_token_id` is None, since no fallback exists.
    """
    if bos_token_id is None:
        raise ValueError("`bos_token_id` should be defined when no "
                         "`input_ids` are provided.")
    # int64 dtype so the result is usable directly as token ids.
    seed = paddle.ones([1, 1], dtype="int64")
    return seed * bos_token_id
285285
286286 @staticmethod
287287 def prepare_attention_mask_for_generation (input_ids , pad_token_id ,
@@ -341,7 +341,9 @@ def expand_inputs_for_generation(input_ids,
341341 return input_ids , model_kwargs
342342
343343 @staticmethod
344- def update_model_kwargs_for_generation (outputs , model_kwargs ):
344+ def update_model_kwargs_for_generation (outputs ,
345+ model_kwargs ,
346+ is_encoder_decoder = False ):
345347 # Update the model inputs during generation.
346348 # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
347349 # and they contain pad value, the result vectors updated by this method
@@ -366,7 +368,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs):
366368 axis = - 1 )
367369
368370 # update attention_mask
369- if "attention_mask" in model_kwargs :
371+ if not is_encoder_decoder and "attention_mask" in model_kwargs :
370372 attention_mask = model_kwargs ["attention_mask" ]
371373 # nn.Pad2D don't support the data type `bool`
372374 if convert_dtype (attention_mask .dtype ) == 'bool' :
@@ -395,6 +397,22 @@ def update_scores_for_generation(scores, next_scores, length,
395397 scores = paddle .where (unfinished_flag , unfinished_scores , scores )
396398 return scores
397399
def prepare_encoder_decoder_kwargs_for_generation(self, input_ids,
                                                  model_kwargs):
    """Run the encoder once and cache its output in `model_kwargs`.

    If `model_kwargs` already contains "encoder_output", the dict is
    returned untouched. Otherwise the encoder is called on `input_ids`
    with every entry of `model_kwargs` except decoder-side arguments
    (names starting with "decoder_" or "cross_attn"), and the result is
    stored under "encoder_output".
    """
    if "encoder_output" in model_kwargs:
        return model_kwargs

    encoder = self.get_encoder()
    # Forward only encoder-relevant kwargs; decoder/cross-attention
    # arguments are consumed later by the decoder.
    encoder_kwargs = {}
    for name, value in model_kwargs.items():
        if name.startswith("decoder_") or name.startswith("cross_attn"):
            continue
        encoder_kwargs[name] = value

    model_kwargs["encoder_output"] = encoder(input_ids, **encoder_kwargs)
    return model_kwargs
415+
398416 def prepare_inputs_for_generation (self , input_ids , ** kwargs ):
399417 # Implement in subclasses for custom behavior to prepare inputs in the
400418 # generate method.
@@ -590,14 +608,22 @@ def generate(self,
590608 model_kwargs [
591609 "attention_mask" ] = self .prepare_attention_mask_for_generation (
592610 input_ids , pad_token_id , eos_token_id )
611+ self .is_encoder_decoder = hasattr (self , 'encoder' ) and hasattr (
612+ self , 'decoder' )
613+ if self .is_encoder_decoder :
614+ model_kwargs = self .prepare_encoder_decoder_kwargs_for_generation (
615+ input_ids , model_kwargs )
616+ # set input_ids as decoder_input_ids
617+ if "decoder_input_ids" in model_kwargs :
618+ input_ids = model_kwargs .pop ("decoder_input_ids" )
619+ else :
620+ input_ids = self .prepare_input_ids_for_generation (bos_token_id )
593621
594622 if pad_token_id is None and eos_token_id is not None :
595623 print ("Setting `pad_token_id` to `eos_token_id`:{} for "
596624 "open-end generation." .format (eos_token_id ))
597625 pad_token_id = eos_token_id
598626
599- # TODO Add relevant processing for encoder_decoder model.
600-
601627 model_kwargs ["use_cache" ] = use_cache
602628 max_length += input_ids .shape [- 1 ]
603629 min_length += input_ids .shape [- 1 ]
@@ -671,7 +697,6 @@ def greedy_search(self, input_ids, logits_processors, max_length,
671697 logits = outputs [0 ] if isinstance (outputs , tuple ) else outputs
672698 # [batch_size, vocab_size]
673699 logits = logits [:, - 1 , :]
674-
675700 # pre-process distribution
676701 logits = self .adjust_logits_during_generation (logits )
677702 logits = logits_processors (input_ids , logits )
@@ -700,8 +725,10 @@ def greedy_search(self, input_ids, logits_processors, max_length,
700725 if not paddle .any (unfinished_flag ):
701726 break
702727
703- model_kwargs = self .update_model_kwargs_for_generation (outputs ,
704- model_kwargs )
728+ model_kwargs = self .update_model_kwargs_for_generation (
729+ outputs ,
730+ model_kwargs ,
731+ is_encoder_decoder = self .is_encoder_decoder )
705732 return input_ids [:, origin_len :], scores
706733
707734 def sample (self ,
@@ -801,8 +828,10 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
801828 # Stop when there is a </s> in all sentences
802829 if not paddle .any (unfinished_flag ):
803830 break
804- model_kwargs = self .update_model_kwargs_for_generation (outputs ,
805- model_kwargs )
831+ model_kwargs = self .update_model_kwargs_for_generation (
832+ outputs ,
833+ model_kwargs ,
834+ is_encoder_decoder = self .is_encoder_decoder )
806835 return input_ids [:, origin_len :], scores
807836
808837 def beam_search (self , input_ids , beam_scorer , logits_processors , max_length ,
@@ -876,8 +905,10 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
876905
877906 if beam_scorer .is_done :
878907 break
879- model_kwargs = self .update_model_kwargs_for_generation (outputs ,
880- model_kwargs )
908+ model_kwargs = self .update_model_kwargs_for_generation (
909+ outputs ,
910+ model_kwargs ,
911+ is_encoder_decoder = self .is_encoder_decoder )
881912 if model_kwargs ["cache" ] is not None :
882913 # reorder the cache
883914 model_kwargs ["cache" ] = map_structure (
0 commit comments