
Commit 150ca1f
Commit message: wip
1 parent: 7a95b93

File tree: 3 files changed, +114 −41 lines

test/torchtext_unittest/prototype/test_generate.py

Lines changed: 38 additions & 4 deletions
@@ -51,12 +51,13 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtil(self.model)
         generation_model.generate(self.inputs)
         mock.assert_called_with("`max_len` was not specified. Defaulting to 256 tokens.")
-
-    def test_beam_search(self) -> None:
-        generation_model = GenerationUtil(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
+    def test_warns_when_mp_with_greedy(self, mock) -> None:
+        pass
 
+    def test_beam_search_with_t5_(self) -> None:
+        generation_model = GenerationUtil(self.model)
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
@@ -70,3 +71,36 @@ def test_beam_search(self) -> None:
         self.assertEqual(generated_text, expected_generated_text)
 
 
+
+    def test_hf_DELETE(self) -> None:
+        from transformers import T5ForConditionalGeneration, T5Tokenizer
+        from torchtext.prototype.generate import GenerationUtil
+
+        t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
+        test_sequence = ["summarize: studies have shown that owning a dog is good for you"]  # , "Q: what is the capital of Alaska?"]
+        generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)
+        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
+        test_sequence_tk = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids
+        import time
+
+        start = time.time()
+        tokens = generative_hf_t5.generate(
+            test_sequence_tk,
+            max_len=100,
+            pad_idx=t5.config.pad_token_id,
+            num_beams=10,
+
+
+        )
+        end = time.time() - start
+        print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end)
+        exit()
+
+    def test_jit_generate(self) -> None:
+        generation_model = GenerationUtil(self.model)
+        torch.jit.script(generation_model)
+
+
+
+    def test_beam_search_speed(self) -> None:
+        pass
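
Note: the `test_hf_DELETE` test above doubles as a usage example for wrapping a Hugging Face model with `GenerationUtil`. Below is a cleaned-up sketch of that pattern with the WIP debugging calls (timing, `exit()`) removed; the model name, tokenizer, and generate arguments are taken from the test, and `beam_size_token` is added here per the docstring guidance (the WIP test omits it), so treat it as an illustrative assumption rather than the committed test code.

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    from torchtext.prototype.generate import GenerationUtil

    # Wrap a pretrained Hugging Face T5 so GenerationUtil drives decoding.
    t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

    # Tokenize the seed text, as in the test.
    input_ids = t5_tokenizer(
        ["summarize: studies have shown that owning a dog is good for you"],
        padding=True,
        return_tensors="pt",
    ).input_ids

    # Beam search over 10 beams, as in the test.
    tokens = generative_hf_t5.generate(
        input_ids,
        max_len=100,
        pad_idx=t5.config.pad_token_id,
        num_beams=10,
        beam_size_token=t5.config.vocab_size,  # assumption: vocab size, per the generate() docstring
    )
    print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))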

torchtext/prototype/generate.py

Lines changed: 76 additions & 36 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -13,7 +13,7 @@
 )
 
 import logging
-
+import warnings
 logger = logging.getLogger(__name__)
 
 
@@ -47,34 +47,52 @@ class GenerationUtil(nn.Module):
     More examples can be found in the `notebooks` directory of this repository.
     """
 
+    _huggingface_model_input_values = {
+        "return_dict": True,
+        "use_cache": True,
+        "output_hidden_states": True
+    }
+
     def __init__(self, model: nn.Module, **kwargs) -> None:
         super().__init__()
         self.model = model
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
 
-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs, model_kwargs):
-        """Modified from."""
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
+
+        Args:
+            inputs (Tensor): Tokenized starting sequence(s).
+            model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
+
+        Returns:
+            Modified model_kwargs with addition of encoded input sequence(s).
+        """
         # Get encoder
         encoder = self.model.get_encoder()
 
-        # Prepare encoder args and encoder kwargs from model kwargs
-        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
-        encoder_kwargs = {}
-        for argument, value in model_kwargs.items():
-            if not any([argument.startswith(p) for p in irrelevant_prefix]):
-                encoder_kwargs[argument] = value
+        # Create copy of encoder kwargs
+        encoder_kwargs = model_kwargs.copy()
 
         # Forward pass
         if self.is_huggingface_model:
             encoder_kwargs["return_dict"] = True
+
+        # import pdb
+        # pdb.set_trace()
+        # print(encoder_kwargs.keys())
+
+        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
+
         model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
 
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
         self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
     ):
+        """Prepare decoder IDs for generation."""
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
@@ -84,16 +102,23 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-        is_encoder_decoder: bool = False,
     ) -> Dict[str, Any]:
-        """Modified from."""
+        """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
+
+        Args:
+            outputs (Dict[str, Any]): LM output.
+            model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
+
+        Returns:
+            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
+        """
         # Update past
         if "past_key_values" in outputs:
-            model_kwargs["past"] = outputs.past_key_values
+            model_kwargs["past"] = outputs["past_key_values"]
         elif "mems" in outputs:
-            model_kwargs["past"] = outputs.mems
+            model_kwargs["past"] = outputs["mems"]
         elif "past_buckets_states" in outputs:
-            model_kwargs["past"] = outputs.past_buckets_states
+            model_kwargs["past"] = outputs["past_buckets_states"]
         else:
             model_kwargs["past"] = None
 
@@ -102,13 +127,19 @@ def _update_model_kwargs_for_generation(
             token_type_ids = model_kwargs["token_type_ids"]
             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
 
-        # Update attention mask
-        if not is_encoder_decoder:
+        if not self.is_encoder_decoder:
             if "attention_mask" in model_kwargs:
                 attention_mask = model_kwargs["attention_mask"]
                 model_kwargs["attention_mask"] = torch.cat(
                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )
+        else:
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
 
         return model_kwargs
 
@@ -132,9 +163,7 @@ def greedy_search(
         while True:
             model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
             if self.is_huggingface_model:
-                model_inputs["return_dict"] = True
-                model_inputs["use_cache"] = True
-                model_inputs["output_hidden_states"] = True
+                model_inputs.update(self._huggingface_model_input_values)
 
             # Get model output
             outputs = self.model(**model_inputs)
@@ -174,7 +203,7 @@ def beam_search(
         eos_idx: int,
         num_python_workers: int,
         max_inference_batch_size: int,
-        model_kwargs,
+        model_kwargs: Dict[str, Any],
     ) -> torch.Tensor:
         """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).
 
@@ -257,26 +286,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                     num_samples if timestep > 0 else 1, -1, -1
                 )
 
-                # Forward pass
+                # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
-                print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
-
                 if self.is_huggingface_model:
-                    model_inputs["return_dict"] = True
-                    model_inputs["use_cache"] = True
-                    model_inputs["output_hidden_states"] = True
+                    model_inputs.update(self._huggingface_model_input_values)
 
-                print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
+                from types import MappingProxyType
 
+                model_inputs = MappingProxyType(model_inputs)
+                # Forward pass
                 outputs = self.model(**model_inputs)
+
+                # Collect outputs
                 output_key = "logits" if self.is_huggingface_model else "decoder_output"
                 lm_scores = outputs[output_key]
 
                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
                     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
                     if new_model_kwargs["past"] is not None:
-                        new_model_kwargs["past"] = self.model._reorder_cache(new_model_kwargs["past"], torch.Tensor(num_samples).to(dtype=torch.int32, device=self.model.device))
+                        import pdb
+                        pdb.set_trace()
+                        beam_indices += [start for _ in range(num_samples)]
+                        new_model_kwargs["past"] = self.model._reorder_cache(
+                            new_model_kwargs["past"],
+                            torch.Tensor(beam_indices).to(dtype=torch.int32)  # I think this is correct?
+                        )
 
                 # Keep track of probabilities over vocab for this pairing
                 # TODO: clean up duplicate code in these branches
@@ -302,7 +337,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                             )
                         )
                     )
-
+
                 start += step
 
         return out_probs, model_states
@@ -375,6 +410,8 @@ def forward(
         num_python_workers: int = 1,
         max_inference_batch_size: int = 16,
     ):
+        """Calls self.generate() method."""
+        warnings.warn("Forward method simply calls `GenerationUtils.generate()`. Please use generate method directly.")
         return self.generate(
             inputs=inputs,
             num_beams=num_beams,
@@ -388,41 +425,42 @@ def forward(
             max_inference_batch_size=max_inference_batch_size,
         )
 
-
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         num_beams: Optional[int] = None,
         max_len: Optional[int] = None,
         pad_idx: int = 0,
         eos_idx: int = 1,
+        num_python_workers: int = 1,
         beam_threshold: int = 100,
         beam_size_token: Optional[int] = None,
         eos_score: float = -1.0,
-        num_python_workers: int = 1,
         max_inference_batch_size: int = 16,
     ) -> torch.Tensor:
-        """Generation method.
+        """Entrypoint generation method.
 
         Args:
             input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
             num_beams (int): If provided, specifies the number of beams to use in beam search generation.
             max_len (int): Max length to generate responses.
             pad_idx (int): Padding index. Defaults to 0.
             eos_idx (int): End-of-sequence index. Defaults to 1.
+            num_python_workers (int): If > 1, uses multiprocessing on CPU.
             beam_size_token (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model.
             beam_threshold (int): Threshold before pruning; specific to beam search.
             eos_score (float): Score to input when `eos_idx` is generated; specific to beam search.
+            max_inference_batch_size (int): In beam search, to avoid OOMs, can choose to batch smaller numbers of hypotheses; defaults to 16.
 
         Returns:
             Tensor of Tensors containing output sequences as ids.
 
-        Conditions for generation: \
-            1. `num_beams` == 1 or `num_beams` is None -> greedy search \
+        Conditions for generation:
+            1. `num_beams` == 1 or `num_beams` is None -> greedy search
             2. `num_beams` > 1 -> beam search
         """
         model_kwargs = {}
-
+
         if self.is_encoder_decoder:
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
@@ -433,6 +471,8 @@ def generate(
             max_len = 256
 
         if num_beams == 1 or num_beams is None:
+            if num_python_workers > 1:
+                logger.warning(f"Multiprocessing is not implemented for greedy search.")
             return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:
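
Note: per the updated docstring, `generate` dispatches on `num_beams`: `None` or `1` runs greedy search (with a warning if `num_python_workers > 1`), while `num_beams > 1` runs the Flashlight Text beam search and expects `beam_size_token`. A minimal sketch of both paths, assuming `model`, `transform`, and `inputs` are the torchtext T5 model, text transform, and tokenized inputs used by the unit tests above (placeholders, not defined in this snippet):

    from torchtext.prototype.generate import GenerationUtil

    generation_model = GenerationUtil(model)

    # 1. `num_beams` is None or 1 -> greedy search.
    #    Omitting max_len warns and defaults to 256 tokens.
    greedy_tokens = generation_model.generate(inputs, max_len=30)

    # 2. `num_beams` > 1 -> beam search; beam_size_token is typically the model's vocab size.
    beam_tokens = generation_model.generate(
        inputs,
        num_beams=3,
        max_len=30,
        beam_size_token=model.config.vocab_size,
    )
    print(transform.decode(beam_tokens.tolist()))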

torchtext/prototype/models/t5/model.py

Lines changed: 0 additions & 1 deletion
@@ -137,7 +137,6 @@ def __init__(
     def prepare_inputs_for_generation(self, input_ids, encoder_outputs):
         return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs}
 
-    @torch.jit.ignore
     def get_encoder(self) -> T5Encoder:
         return self.encoder
 
