Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c008115

Browse files
committed
Separate encoding/decoding logic for T5 model in preparation for generation
1 parent 7c7b640 commit c008115

File tree

6 files changed

+306
-113
lines changed

6 files changed

+306
-113
lines changed

test/integration_tests/prototype/test_models.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,22 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
6666

6767
model_input = transform(test_text)
6868
if model.encoder_only:
69-
actual = model(model_input)["encoder_output"]
69+
actual = model(encoder_tokens=model_input)["encoder_output"]
70+
if not is_jit:
71+
self._t5_get_encoder(model, model_input, actual)
7072
else:
71-
actual = model(model_input)["decoder_output"]
73+
actual = model(encoder_tokens=model_input)["decoder_output"]
7274

7375
expected = torch.load(expected_asset_path)
7476
torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)
7577

78+
def _t5_get_encoder(self, model, model_input, encoder_output):
79+
encoder = model.get_encoder()
80+
# Need to set the tgt_key_padding_mask to ensure the same results
81+
encoder_padding_mask = model_input.eq(model.padding_idx)
82+
output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"]
83+
assert torch.all(output_from_get_encoder.eq(encoder_output))
84+
7685
@nested_params(["jit", "not_jit"])
7786
def test_t5_model(self, name) -> None:
7887
configuration, type = self.model_name.split("_")
@@ -93,7 +102,8 @@ def test_t5_model(self, name) -> None:
93102
],
94103
)
95104
class TestT5Wrapper(TorchtextTestCase):
96-
@parameterized.expand(["jit", "not_jit"])
105+
# No longer Torchscriptable
106+
@parameterized.expand(["no_jit"])
97107
def test_t5_wrapper(self, name) -> None:
98108
configuration = self.configuration
99109
test_text = ["translate English to French: I want to eat pizza for dinner."]
@@ -113,7 +123,8 @@ def test_t5_wrapper(self, name) -> None:
113123

114124

115125
class TestT5WrapperCheckpoint(TorchtextTestCase):
116-
@parameterized.expand(["jit", "not_jit"])
126+
# No longer Torchscriptable
127+
@parameterized.expand(["no_jit"])
117128
def test_t5_wrapper_checkpoint(self, name) -> None:
118129
test_text = ["translate English to French: I want to eat pizza for dinner."]
119130
expected_text = ["Je veux manger de la pizza pour le dîner."]
@@ -127,7 +138,7 @@ def test_t5_wrapper_checkpoint(self, name) -> None:
127138
padding_idx=0,
128139
)
129140
model = T5Wrapper(
130-
checkpoint="https://download.pytorch.org/models/text/t5.base.generation.pt",
141+
checkpoint="https://download.pytorch.org/models/text/t5.base.generation.v2.pt",
131142
t5_config=config,
132143
transform=transform,
133144
freeze_model=True,

torchtext/prototype/generate.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from typing import Optional
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from torch import nn
6+
7+
import logging
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class GenerationUtil:
13+
"""Wrapper to provide generation utils for encoder/decoder models and decoder models.
14+
15+
Example:
16+
>>> model = T5_BASE_GENERATION.get_model()
17+
>>> generative_model = GenerationUtil(model=model)
18+
>>> generative_model.generate(input_ids, num_beams=1, max_len=100)
19+
20+
The wrapper can work with *any* model as long as it meets the following requirements:
21+
1. Is an encoder/decoder or decoder based model.
22+
2. Includes a `get_encoder` method (if applicable) and a `prepare_inputs_for_generation` method.
23+
24+
This means that popular HuggingFace implementation of T5, Bart, and GPT-2 can all be used with these generation utils!
25+
>>> from transformers import T5Model
26+
>>> model = T5Model.from_pretrained("t5-base")
27+
>>> generative_model = GenerationUtil(model=model, is_huggingface_model=True)
28+
>>> generative_model.generate(input_ids, num_beams=1, max_len=100)
29+
30+
More examples can be found in the `notebooks` directory of this repository.
31+
"""
32+
def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None:
33+
self.model = model
34+
self.is_encoder_decoder = is_encoder_decoder
35+
self.is_huggingface_model = is_huggingface_model
36+
37+
def _prepare_decoder_ids_for_generation(
38+
self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
39+
):
40+
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
41+
return model_kwargs.pop("decoder_input_ids")
42+
else:
43+
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
44+
45+
def greedy_search(
46+
self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
47+
) -> torch.Tensor:
48+
"""Greedy search decoding for text generation. Takes the most likely next token every time.
49+
50+
Inputs:
51+
input_ids (Tensor): Text prompt(s) for greedy generation.
52+
max_len (int): Max length to generate responses.
53+
eos_idx (int): End of sequence index.
54+
pad_idx (int): Padding index.
55+
**model_kwargs
56+
57+
Returns:
58+
Batch of sequences decoded by greedy search.
59+
"""
60+
unfinished_sequences = torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long)
61+
62+
while True:
63+
model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
64+
if self.is_huggingface_model:
65+
model_inputs["return_dict"] = True
66+
model_inputs["output_hidden_states"] = True
67+
68+
# Get model output
69+
outputs = self.model(**model_inputs)
70+
output_key = "logits" if self.is_huggingface_model else "decoder_output"
71+
decoder_output = outputs[output_key]
72+
73+
# Calculate probabilities and take the most likely next token
74+
probs = F.log_softmax(decoder_output[:, -1], dim=-1)
75+
_, next_tokens = torch.topk(probs, 1)
76+
77+
# For any finished sequences, padding idx should be the last token
78+
if eos_idx is not None:
79+
if pad_idx is not None:
80+
next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
81+
82+
# Append the next tokens to the previous tokens
83+
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
84+
85+
if eos_idx is not None:
86+
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
87+
88+
# Stop iterating once all sequences are finished or exceed the max_len
89+
if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_len:
90+
break
91+
92+
return input_ids
93+
94+
def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
95+
raise NotImplementedError()
96+
97+
def generate(
98+
self,
99+
inputs: Optional[torch.Tensor] = None,
100+
num_beams: Optional[int] = None,
101+
max_len: Optional[int] = None,
102+
pad_idx: int = 0,
103+
eos_idx: int = 1,
104+
) -> torch.Tensor:
105+
"""Generation method.
106+
107+
`num_beams` == 1 or `num_beams` is None -> greedy search
108+
`num_beams` > 1 -> beam search
109+
110+
Args:
111+
input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
112+
num_beams (int): If provided, specifies the number of beams to use in beam search generation.
113+
max_len (int): Max length to generate responses.
114+
pad_idx (int): Padding index. Defaults to 0.
115+
eos_idx (int): End of sequence index. Defaults to 1.
116+
117+
Returns:
118+
Tensor of Tensors containing output sequences as ids.
119+
120+
`Note`: If one beam is provided or no beams are specified, the generation method will default to greedy search.
121+
"""
122+
model_kwargs = {}
123+
124+
if self.is_encoder_decoder:
125+
encoder = self.model.get_encoder()
126+
model_kwargs["encoder_outputs"] = encoder(inputs)
127+
inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
128+
129+
if max_len is None:
130+
# Too hard to try to figure out the exact max_seq_length for each model
131+
logger.warning("`max_len` was not specified. Defaulting to 256 tokens.")
132+
max_len = 256
133+
134+
if num_beams == 1 or num_beams is None:
135+
return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs)
136+
elif num_beams > 1:
137+
return self.beam_search(inputs, num_beams, max_len)
138+
else:
139+
raise ValueError("`num_beams` must be >= 1.")

torchtext/prototype/models/t5/bundler.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def build_model_from_huggingface_ckpt(
176176

177177
t5_model_state_dict = {
178178
"token_embeddings.weight": hf_weights["shared.weight"],
179-
"norm1.weight": hf_weights["encoder.final_layer_norm.weight"],
179+
"encoder.token_embeddings.weight": hf_weights["shared.weight"],
180+
"encoder.norm.weight": hf_weights["encoder.final_layer_norm.weight"],
180181
"encoder.layers.0.self_attn.relative_attention_bias.weight": hf_weights[
181182
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
182183
],
@@ -210,7 +211,7 @@ def build_model_from_huggingface_ckpt(
210211

211212
# Convert decoder layers if model is encoder-decoder
212213
if not config.encoder_only:
213-
t5_model_state_dict["norm2.weight"] = hf_weights["decoder.final_layer_norm.weight"]
214+
t5_model_state_dict["decoder.norm.weight"] = hf_weights["decoder.final_layer_norm.weight"]
214215
t5_model_state_dict["decoder.layers.0.self_attn.relative_attention_bias.weight"] = hf_weights[
215216
"decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
216217
]
@@ -331,7 +332,7 @@ def config(self) -> T5Conf:
331332
"""
332333

333334
T5_BASE_ENCODER = T5Bundle(
334-
_path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"),
335+
_path=urljoin(_TEXT_BUCKET, "t5.base.encoder.v2.pt"),
335336
_config=T5Conf(encoder_only=True),
336337
transform=lambda: T5Transform(
337338
urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"),
@@ -344,7 +345,7 @@ def config(self) -> T5Conf:
344345
T5_BASE_ENCODER.__doc__ = ENCODER_DOC.format("BASE", "base")
345346

346347
T5_BASE = T5Bundle(
347-
_path=urljoin(_TEXT_BUCKET, "t5.base.pt"),
348+
_path=urljoin(_TEXT_BUCKET, "t5.base.v2.pt"),
348349
_config=T5Conf(encoder_only=False),
349350
transform=lambda: T5Transform(
350351
urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"),
@@ -357,7 +358,7 @@ def config(self) -> T5Conf:
357358
T5_BASE.__doc__ = MODEL_DOC.format("BASE", "base")
358359

359360
T5_BASE_GENERATION = T5Bundle(
360-
_path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"),
361+
_path=urljoin(_TEXT_BUCKET, "t5.base.generation.v2.pt"),
361362
_config=T5Conf(encoder_only=False, linear_head=True),
362363
transform=lambda: T5Transform(
363364
urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"),
@@ -370,7 +371,7 @@ def config(self) -> T5Conf:
370371
T5_BASE_GENERATION.__doc__ = GENERATION_DOC.format("BASE", "base")
371372

372373
T5_SMALL_ENCODER = T5Bundle(
373-
_path=urljoin(_TEXT_BUCKET, "t5.small.encoder.pt"),
374+
_path=urljoin(_TEXT_BUCKET, "t5.small.encoder.v2.pt"),
374375
_config=T5Conf(
375376
encoder_only=True,
376377
embedding_dim=512,
@@ -391,7 +392,7 @@ def config(self) -> T5Conf:
391392

392393

393394
T5_SMALL = T5Bundle(
394-
_path=urljoin(_TEXT_BUCKET, "t5.small.pt"),
395+
_path=urljoin(_TEXT_BUCKET, "t5.small.v2.pt"),
395396
_config=T5Conf(
396397
encoder_only=False,
397398
embedding_dim=512,
@@ -411,7 +412,7 @@ def config(self) -> T5Conf:
411412
T5_SMALL.__doc__ = MODEL_DOC.format("SMALL", "small")
412413

413414
T5_SMALL_GENERATION = T5Bundle(
414-
_path=urljoin(_TEXT_BUCKET, "t5.small.generation.pt"),
415+
_path=urljoin(_TEXT_BUCKET, "t5.small.generation.v2.pt"),
415416
_config=T5Conf(
416417
encoder_only=False,
417418
linear_head=True,
@@ -432,7 +433,7 @@ def config(self) -> T5Conf:
432433
T5_SMALL_GENERATION.__doc__ = GENERATION_DOC.format("SMALL", "small")
433434

434435
T5_LARGE_ENCODER = T5Bundle(
435-
_path=urljoin(_TEXT_BUCKET, "t5.large.encoder.pt"),
436+
_path=urljoin(_TEXT_BUCKET, "t5.large.encoder.v2.pt"),
436437
_config=T5Conf(
437438
encoder_only=True,
438439
embedding_dim=1024,
@@ -452,7 +453,7 @@ def config(self) -> T5Conf:
452453
T5_LARGE_ENCODER.__doc__ = ENCODER_DOC.format("LARGE", "large")
453454

454455
T5_LARGE = T5Bundle(
455-
_path=urljoin(_TEXT_BUCKET, "t5.large.pt"),
456+
_path=urljoin(_TEXT_BUCKET, "t5.large.v2.pt"),
456457
_config=T5Conf(
457458
encoder_only=False,
458459
embedding_dim=1024,
@@ -472,7 +473,7 @@ def config(self) -> T5Conf:
472473
T5_LARGE.__doc__ = MODEL_DOC.format("LARGE", "large")
473474

474475
T5_LARGE_GENERATION = T5Bundle(
475-
_path=urljoin(_TEXT_BUCKET, "t5.large.generation.pt"),
476+
_path=urljoin(_TEXT_BUCKET, "t5.large.generation.v2.pt"),
476477
_config=T5Conf(
477478
encoder_only=False,
478479
linear_head=True,
@@ -493,7 +494,7 @@ def config(self) -> T5Conf:
493494
T5_LARGE_GENERATION.__doc__ = GENERATION_DOC.format("LARGE", "large")
494495

495496
T5_3B_ENCODER = T5Bundle(
496-
_path=urljoin(_TEXT_BUCKET, "t5.3b.encoder.pt"),
497+
_path=urljoin(_TEXT_BUCKET, "t5.3b.encoder.v2.pt"),
497498
_config=T5Conf(
498499
encoder_only=True,
499500
embedding_dim=1024,
@@ -514,7 +515,7 @@ def config(self) -> T5Conf:
514515
T5_3B_ENCODER.__doc__ = ENCODER_DOC.format("3B", "3B")
515516

516517
T5_3B = T5Bundle(
517-
_path=urljoin(_TEXT_BUCKET, "t5.3b.pt"),
518+
_path=urljoin(_TEXT_BUCKET, "t5.3b.v2.pt"),
518519
_config=T5Conf(
519520
encoder_only=False,
520521
embedding_dim=1024,
@@ -535,7 +536,7 @@ def config(self) -> T5Conf:
535536
T5_3B.__doc__ = MODEL_DOC.format("3B", "3B")
536537

537538
T5_3B_GENERATION = T5Bundle(
538-
_path=urljoin(_TEXT_BUCKET, "t5.3b.generation.pt"),
539+
_path=urljoin(_TEXT_BUCKET, "t5.3b.generation.v2.pt"),
539540
_config=T5Conf(
540541
encoder_only=False,
541542
linear_head=True,
@@ -557,7 +558,7 @@ def config(self) -> T5Conf:
557558
T5_3B_GENERATION.__doc__ = GENERATION_DOC.format("3B", "3B")
558559

559560
T5_11B_ENCODER = T5Bundle(
560-
_path=urljoin(_TEXT_BUCKET, "t5.11b.encoder.pt"),
561+
_path=urljoin(_TEXT_BUCKET, "t5.11b.encoder.v2.pt"),
561562
_config=T5Conf(
562563
encoder_only=True,
563564
embedding_dim=1024,
@@ -578,7 +579,7 @@ def config(self) -> T5Conf:
578579
T5_11B_ENCODER.__doc__ = ENCODER_DOC.format("11B", "11B")
579580

580581
T5_11B = T5Bundle(
581-
_path=urljoin(_TEXT_BUCKET, "t5.11b.pt"),
582+
_path=urljoin(_TEXT_BUCKET, "t5.11b.v2.pt"),
582583
_config=T5Conf(
583584
encoder_only=False,
584585
embedding_dim=1024,
@@ -599,7 +600,7 @@ def config(self) -> T5Conf:
599600
T5_11B.__doc__ = MODEL_DOC.format("11B", "11B")
600601

601602
T5_11B_GENERATION = T5Bundle(
602-
_path=urljoin(_TEXT_BUCKET, "t5.11b.generation.pt"),
603+
_path=urljoin(_TEXT_BUCKET, "t5.11b.generation.v2.pt"),
603604
_config=T5Conf(
604605
encoder_only=False,
605606
linear_head=True,

0 commit comments

Comments
 (0)