This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1b72eba

Fixup generation utils for prototype release (#2065)
* Fixup generation utils for prototype release
* Move test under integration b/c relies on T5
* fix lint
1 parent 7c63ea7 commit 1b72eba

File tree

2 files changed (+28, -25 lines)

test/torchtext_unittest/prototype/test_generate.py renamed to test/integration_tests/test_generate.py

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,7 @@

 import torch
 from torchtext.models import T5_BASE_GENERATION
-from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil
+from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtils
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase


@@ -26,9 +26,9 @@ def setUp(self) -> None:
         torch.manual_seed(0)

     def test_greedy_generate_with_t5(self) -> None:
-        generation_model = GenerationUtil(self.model)
+        generation_model = GenerationUtils(self.model)

-        tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
+        tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
         generated_text = self.transform.decode(tokens.tolist())

         expected_generated_text = [
@@ -42,13 +42,13 @@ def test_greedy_generate_with_t5(self) -> None:
         self.assertEqual(generated_text, expected_generated_text)

     def test_generate_errors_with_incorrect_beams(self) -> None:
-        generation_model = GenerationUtil(self.model, is_encoder_decoder=True)
+        generation_model = GenerationUtils(self.model, is_encoder_decoder=True)

         with self.assertRaises(ValueError):
             generation_model.generate(self.inputs, num_beams=0)

     @patch("logging.Logger.warning")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
-        generation_model = GenerationUtil(self.model)
+        generation_model = GenerationUtils(self.model)
         generation_model.generate(self.inputs)
-        mock.assert_called_with(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
+        mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")

torchtext/prototype/generate.py

Lines changed: 22 additions & 19 deletions
@@ -10,13 +10,13 @@

 DEFAULT_MAX_SEQ_LEN = 256


-class GenerationUtil:
+class GenerationUtils:
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.

     Example:
         >>> model = T5_BASE_GENERATION.get_model()
-        >>> generative_model = GenerationUtil(model=model)
-        >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
+        >>> generative_model = GenerationUtils(model=model)
+        >>> generative_model.generate(input_ids, num_beams=1, max_length=100)

     The wrapper can work with *any* model as long as it meets the following requirements:
         1. Is an encoder/decoder or decoder based model.
@@ -26,15 +26,18 @@ class GenerationUtil:
         >>> from transformers import T5Model
         >>> model = T5Model.from_pretrained("t5-base")
         >>> generative_model = GenerationUtils(model=model, is_huggingface_model=True)
-        >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
+        >>> generative_model.generate(input_ids, num_beams=1, max_length=100)
+
+    `Note`: We cannot make any claims about the stability of APIs from HuggingFace so all models used from the `transformers`
+    library are marked 'experimental.'

     More examples can be found in the `notebooks` directory of this repository.
     """

-    def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None:
+    def __init__(self, model: nn.Module, **kwargs) -> None:
         self.model = model
-        self.is_encoder_decoder = is_encoder_decoder
-        self.is_huggingface_model = is_huggingface_model
+        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
+        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)

     def _prepare_decoder_ids_for_generation(
         self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
@@ -45,13 +48,13 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

     def greedy_search(
-        self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.

         Inputs:
             input_ids (Tensor): Text prompt(s) for greedy generation.
-            max_len (int): Max length to generate responses.
+            max_length (int): Max length to generate responses.
             eos_idx (int): End of sequence index.
             pad_idx (int): Padding index.
             **model_kwargs
@@ -87,20 +90,20 @@ def greedy_search(
             if eos_idx is not None:
                 unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())

-            # Stop iterating once all sequences are finished or exceed the max_len
-            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_len:
+            # Stop iterating once all sequences are finished or exceed the max_length
+            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_length:
                 break

         return input_ids

-    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
+    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_length: Optional[int]) -> torch.Tensor:
         raise NotImplementedError()

     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         num_beams: Optional[int] = None,
-        max_len: Optional[int] = None,
+        max_length: Optional[int] = None,
         pad_idx: int = 0,
         eos_idx: int = 1,
     ) -> torch.Tensor:
@@ -112,7 +115,7 @@ def generate(
         Args:
             input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
             num_beams (int): If provided, specifies the number of beams to use in beam search generation.
-            max_len (int): Max length to generate responses.
+            max_length (int): Max length to generate responses.
             pad_idx (int): Padding index. Defaults to 0.
             eos_idx (int): End of sequence index. Defaults to 1.

@@ -128,14 +131,14 @@ def generate(
             model_kwargs["encoder_outputs"] = encoder(inputs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)

-        if max_len is None:
+        if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
-            max_len = DEFAULT_MAX_SEQ_LEN
+            logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
+            max_length = DEFAULT_MAX_SEQ_LEN

         if num_beams == 1 or num_beams is None:
-            return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs)
+            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
-            return self.beam_search(inputs, num_beams, max_len)
+            return self.beam_search(inputs, num_beams, max_length)
         else:
             raise ValueError("`num_beams` must be >= 1.")
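
For reference, a minimal end-to-end sketch of the renamed API after this commit, assembled from the docstring example and the updated test above. It is not part of the diff: the T5_BASE_GENERATION get_model()/transform() calls mirror the test's setUp (which lies outside the hunks shown), and the example prompt is illustrative only.

# A minimal sketch, assuming T5_BASE_GENERATION exposes get_model() and transform()
# as the test's setUp relies on; not part of this commit.
import torch
from torchtext.models import T5_BASE_GENERATION
from torchtext.prototype.generate import GenerationUtils

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

# is_encoder_decoder defaults to True; pass is_huggingface_model=True for `transformers` models.
generation_model = GenerationUtils(model)

inputs = transform(["summarize: studies have shown that owning a dog is good for you"])
with torch.no_grad():
    # num_beams=1 (or None) takes the greedy_search path; beam_search still raises NotImplementedError.
    tokens = generation_model.generate(inputs, num_beams=1, max_length=30)
print(transform.decode(tokens.tolist()))

Omitting max_length now logs the updated warning and falls back to DEFAULT_MAX_SEQ_LEN (256 tokens), as exercised by test_warns_when_no_max_len_provided.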
