 DEFAULT_MAX_SEQ_LEN = 256


-class GenerationUtil:
+class GenerationUtils:
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.

     Example:
         >>> model = T5_BASE_GENERATION.get_model()
-        >>> generative_model = GenerationUtil(model=model)
-        >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
+        >>> generative_model = GenerationUtils(model=model)
+        >>> generative_model.generate(input_ids, num_beams=1, max_length=100)

     The wrapper can work with *any* model as long as it meets the following requirements:
     1. Is an encoder/decoder or decoder based model.
@@ -26,15 +26,18 @@ class GenerationUtil:
         >>> from transformers import T5Model
         >>> model = T5Model.from_pretrained("t5-base")
         >>> generative_model = GenerationUtils(model=model, is_huggingface_model=True)
-        >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
+        >>> generative_model.generate(input_ids, num_beams=1, max_length=100)
+
+    `Note`: We cannot make any claims about the stability of APIs from HuggingFace, so all models used from the
+    `transformers` library are marked 'experimental.'

     More examples can be found in the `notebooks` directory of this repository.
     """

-    def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None:
+    def __init__(self, model: nn.Module, **kwargs) -> None:
         self.model = model
-        self.is_encoder_decoder = is_encoder_decoder
-        self.is_huggingface_model = is_huggingface_model
+        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
+        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)

     def _prepare_decoder_ids_for_generation(
         self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
@@ -45,13 +48,13 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

     def greedy_search(
-        self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.

         Inputs:
             input_ids (Tensor): Text prompt(s) for greedy generation.
-            max_len (int): Max length to generate responses.
+            max_length (int): Max length to generate responses.
             eos_idx (int): End of sequence index.
             pad_idx (int): Padding index.
             **model_kwargs
@@ -87,20 +90,20 @@ def greedy_search(
             if eos_idx is not None:
                 unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())

-            # Stop iterating once all sequences are finished or exceed the max_len
-            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_len:
+            # Stop iterating once all sequences are finished or exceed the max_length
+            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_length:
                 break

         return input_ids

-    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
+    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_length: Optional[int]) -> torch.Tensor:
         raise NotImplementedError()

     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         num_beams: Optional[int] = None,
-        max_len: Optional[int] = None,
+        max_length: Optional[int] = None,
         pad_idx: int = 0,
         eos_idx: int = 1,
     ) -> torch.Tensor:
@@ -112,7 +115,7 @@ def generate(
         Args:
             input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
             num_beams (int): If provided, specifies the number of beams to use in beam search generation.
-            max_len (int): Max length to generate responses.
+            max_length (int): Max length to generate responses.
             pad_idx (int): Padding index. Defaults to 0.
             eos_idx (int): End of sequence index. Defaults to 1.
@@ -128,14 +131,14 @@ def generate(
             model_kwargs["encoder_outputs"] = encoder(inputs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)

-        if max_len is None:
+        if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
-            max_len = DEFAULT_MAX_SEQ_LEN
+            logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
+            max_length = DEFAULT_MAX_SEQ_LEN

         if num_beams == 1 or num_beams is None:
-            return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs)
+            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
-            return self.beam_search(inputs, num_beams, max_len)
+            return self.beam_search(inputs, num_beams, max_length)
         else:
             raise ValueError("`num_beams` must be >= 1.")
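
For call sites being updated by this rename, here is a minimal usage sketch of the new API. The import path `torchtext.prototype.generate` and the exact token ids are assumptions based on the docstring above, not confirmed by this diff:

```python
import torch
from transformers import T5Model  # HF models are supported but marked experimental

from torchtext.prototype.generate import GenerationUtils  # assumed import path

model = T5Model.from_pretrained("t5-base")
# The encoder/decoder and HuggingFace flags now travel through **kwargs
# instead of named constructor parameters.
generative_model = GenerationUtils(model, is_encoder_decoder=True, is_huggingface_model=True)

input_ids = torch.tensor([[13959, 1566, 12, 2968, 10, 1]])  # toy token ids
output = generative_model.generate(input_ids, num_beams=1, max_length=100)
```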
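The greedy loop's stopping rule (every sequence has emitted EOS, or the sequence length has reached `max_length`) can be seen in isolation in this self-contained sketch; `fake_logits` is a hypothetical stand-in for a real model's next-token scores, and the padding of finished sequences is omitted:

```python
import torch


def toy_greedy_search(next_logits_fn, input_ids: torch.Tensor, max_length: int, eos_idx: int) -> torch.Tensor:
    # 1 while a sequence is still generating, 0 once it has produced EOS.
    unfinished_sequences = torch.ones(input_ids.size(0), dtype=torch.long)
    while True:
        logits = next_logits_fn(input_ids)   # (batch, vocab) scores for the next token
        next_tokens = logits.argmax(dim=-1)  # greedy: take the most likely token each step
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
        # Same stopping rule as greedy_search above: all finished, or max_length reached.
        if unfinished_sequences.max() == 0 or input_ids.size(1) >= max_length:
            break
    return input_ids


fake_logits = lambda ids: torch.randn(ids.size(0), 10)  # hypothetical stand-in model
print(toy_greedy_search(fake_logits, torch.zeros(2, 1, dtype=torch.long), max_length=8, eos_idx=1))
```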