
Commit 14b6f89

Generate API support for encoder-decoder models (PaddlePaddle#966)
* Add offset mapping doc
* Fix eval hang because of unique endpoint
* Generate API support encoder-decoder
1 parent 034cd53 commit 14b6f89

File tree

2 files changed (+97, -14 lines)

paddlenlp/transformers/bart/modeling.py

Lines changed: 54 additions & 2 deletions
@@ -310,6 +310,12 @@ def __init__(self,
                                    max_position_embeddings, init_std)
         self.apply(self.init_weights)
 
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
     def forward(self,
                 input_ids,
                 attention_mask=None,
@@ -323,7 +329,6 @@ def forward(self,
         if input_ids is None and encoder_output is None:
             raise ValueError(
                 "You have to specify either input_ids or encoder_output")
-
         if decoder_input_ids is None:
             assert input_ids is not None, "input_ids should be " \
                                           "specified when generating decoder_input_ids"
@@ -450,6 +455,12 @@ def __init__(self, bart):
             paddle.zeros((1, self.bart.config['vocab_size'])))
         self.apply(self.init_weights)
 
+    def get_encoder(self):
+        return self.bart.get_encoder()
+
+    def get_decoder(self):
+        return self.bart.get_decoder()
+
     def forward(self,
                 input_ids,
                 attention_mask=None,
@@ -465,5 +476,46 @@ def forward(self,
             output[0] if use_cache else output,
             self.lm_head_weight,
             transpose_y=True) + self.final_logits_bias
+        if use_cache:
+            cache = output[1]
+            return lm_logits, cache
+        else:
+            return lm_logits
+
+    def prepare_inputs_for_generation(self,
+                                      decoder_input_ids,
+                                      attention_mask=None,
+                                      decoder_attention_mask=None,
+                                      cache=None,
+                                      use_cache=False,
+                                      encoder_output=None,
+                                      **kwargs):
+        # cut decoder_input_ids if past is used
+        if cache is not None:
+            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = decoder_attention_mask[:, :,
+                                                                -1, :].unsqueeze(
+                                                                    2)
+
+        return {
+            "input_ids": None,
+            "decoder_input_ids": decoder_input_ids,
+            "encoder_output": encoder_output,
+            "decoder_attention_mask": decoder_attention_mask,
+            "attention_mask": attention_mask,
+            "use_cache": use_cache,
+            "cache": cache
+        }
 
-        return lm_logits
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError as e:
+            try:
+                return getattr(getattr(self, self.base_model_prefix), name)
+            except AttributeError:
+                try:
+                    return getattr(self, self.base_model_prefix).config[name]
+                except KeyError:
+                    raise e
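
With get_encoder/get_decoder and prepare_inputs_for_generation defined on the BART classes, the generic generate() API can drive the seq2seq model directly. A minimal usage sketch follows; it is illustrative only, and the checkpoint name "bart-base", the tokenizer call, and the decoding arguments are assumptions rather than part of this diff.

import paddle
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer

# Assumed checkpoint name; substitute whatever BART weights are available.
tokenizer = BartTokenizer.from_pretrained("bart-base")
model = BartForConditionalGeneration.from_pretrained("bart-base")
model.eval()

# Encode a source sentence and wrap it as a [1, seq_len] int64 tensor.
encoded = tokenizer("PaddleNLP provides pre-trained models for NLP tasks.")
input_ids = paddle.to_tensor([encoded["input_ids"]])

# generate() runs the encoder once, then decodes step by step using
# prepare_inputs_for_generation() and the cache returned by forward().
output_ids, scores = model.generate(
    input_ids=input_ids,
    max_length=20,
    decode_strategy="greedy_search",
    use_cache=True)
print(tokenizer.convert_ids_to_tokens(output_ids[0].numpy().tolist()))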

paddlenlp/transformers/generation_utils.py

Lines changed: 43 additions & 12 deletions
@@ -281,7 +281,7 @@ def prepare_input_ids_for_generation(bos_token_id):
         if bos_token_id is None:
             raise ValueError("`bos_token_id` should be defined when no "
                              "`input_ids` are provided.")
-        return paddle.ones([1, 1]) * bos_token_id
+        return paddle.ones([1, 1], dtype="int64") * bos_token_id
 
     @staticmethod
     def prepare_attention_mask_for_generation(input_ids, pad_token_id,
@@ -341,7 +341,9 @@ def expand_inputs_for_generation(input_ids,
         return input_ids, model_kwargs
 
     @staticmethod
-    def update_model_kwargs_for_generation(outputs, model_kwargs):
+    def update_model_kwargs_for_generation(outputs,
+                                           model_kwargs,
+                                           is_encoder_decoder=False):
         # Update the model inputs during generation.
         # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
         # and they contain pad value, the result vectors updated by this method
@@ -366,7 +368,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs):
                                          axis=-1)
 
         # update attention_mask
-        if "attention_mask" in model_kwargs:
+        if not is_encoder_decoder and "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
             # nn.Pad2D don't support the data type `bool`
             if convert_dtype(attention_mask.dtype) == 'bool':
@@ -395,6 +397,22 @@ def update_scores_for_generation(scores, next_scores, length,
         scores = paddle.where(unfinished_flag, unfinished_scores, scores)
         return scores
 
+    def prepare_encoder_decoder_kwargs_for_generation(self, input_ids,
+                                                      model_kwargs):
+        if "encoder_output" not in model_kwargs:
+            # retrieve encoder hidden states
+            encoder = self.get_encoder()
+            encoder_kwargs = {
+                argument: value
+                for argument, value in model_kwargs.items()
+                if not (argument.startswith("decoder_") or argument.startswith(
+                    "cross_attn"))
+            }
+
+            model_kwargs["encoder_output"] = encoder(input_ids,
+                                                     **encoder_kwargs)
+        return model_kwargs
+
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         # Implement in subclasses for custom behavior to prepare inputs in the
         # generate method.
@@ -590,14 +608,22 @@ def generate(self,
             model_kwargs[
                 "attention_mask"] = self.prepare_attention_mask_for_generation(
                     input_ids, pad_token_id, eos_token_id)
+        self.is_encoder_decoder = hasattr(self, 'encoder') and hasattr(
+            self, 'decoder')
+        if self.is_encoder_decoder:
+            model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
+                input_ids, model_kwargs)
+            # set input_ids as decoder_input_ids
+            if "decoder_input_ids" in model_kwargs:
+                input_ids = model_kwargs.pop("decoder_input_ids")
+            else:
+                input_ids = self.prepare_input_ids_for_generation(bos_token_id)
 
         if pad_token_id is None and eos_token_id is not None:
             print("Setting `pad_token_id` to `eos_token_id`:{} for "
                   "open-end generation.".format(eos_token_id))
             pad_token_id = eos_token_id
 
-        # TODO Add relevant processing for encoder_decoder model.
-
         model_kwargs["use_cache"] = use_cache
         max_length += input_ids.shape[-1]
         min_length += input_ids.shape[-1]
@@ -671,7 +697,6 @@ def greedy_search(self, input_ids, logits_processors, max_length,
             logits = outputs[0] if isinstance(outputs, tuple) else outputs
             # [batch_size, vocab_size]
             logits = logits[:, -1, :]
-
             # pre-process distribution
             logits = self.adjust_logits_during_generation(logits)
             logits = logits_processors(input_ids, logits)
@@ -700,8 +725,10 @@ def greedy_search(self, input_ids, logits_processors, max_length,
             if not paddle.any(unfinished_flag):
                 break
 
-            model_kwargs = self.update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs)
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.is_encoder_decoder)
         return input_ids[:, origin_len:], scores
 
     def sample(self,
@@ -801,8 +828,10 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
             # Stop when there is a </s> in all sentences
             if not paddle.any(unfinished_flag):
                 break
-            model_kwargs = self.update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs)
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.is_encoder_decoder)
         return input_ids[:, origin_len:], scores
 
     def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
@@ -876,8 +905,10 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length,
 
             if beam_scorer.is_done:
                 break
-            model_kwargs = self.update_model_kwargs_for_generation(outputs,
-                                                                    model_kwargs)
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.is_encoder_decoder)
             if model_kwargs["cache"] is not None:
                 # reorder the cache
                 model_kwargs["cache"] = map_structure(
