This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit de07b69

committed: wip
1 parent 1def6fa commit de07b69

File tree

2 files changed: +115 -39 lines changed


test/torchtext_unittest/prototype/test_generate.py

Lines changed: 38 additions & 4 deletions
@@ -51,12 +51,13 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtil(self.model)
         generation_model.generate(self.inputs)
         mock.assert_called_with("`max_len` was not specified. Defaulting to 256 tokens.")
-
-    def test_beam_search(self) -> None:
-        generation_model = GenerationUtil(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
+    def test_warns_when_mp_with_greedy(self, mock) -> None:
+        pass
 
+    def test_beam_search_with_t5_(self) -> None:
+        generation_model = GenerationUtil(self.model)
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
@@ -70,3 +71,36 @@ def test_beam_search(self) -> None:
         self.assertEqual(generated_text, expected_generated_text)
 
 
+
+    def test_hf_DELETE(self) -> None:
+        from transformers import T5ForConditionalGeneration, T5Tokenizer
+        from torchtext.prototype.generate import GenerationUtil
+
+        t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
+        test_sequence = ["summarize: studies have shown that owning a dog is good for you"]#, "Q: what is the capital of Alaska?"]
+        generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)
+        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
+        test_sequence_tk = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids
+        import time
+
+        start = time.time()
+        tokens = generative_hf_t5.generate(
+            test_sequence_tk,
+            max_len=100,
+            pad_idx=t5.config.pad_token_id,
+            num_beams=10,
+
+
+        )
+        end = time.time() - start
+        print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end)
+        exit()
+
+    def test_jit_generate(self) -> None:
+        generation_model = GenerationUtil(self.model)
+        torch.jit.script(generation_model)
+
+
+
+    def test_beam_search_speed(self) -> None:
+        pass

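Aside: the new test_hf_DELETE case above is the core of this WIP change. It wraps a Hugging Face T5 in the prototype generation wrapper and runs a 10-beam search against it. Below is a cleaned-up, standalone sketch of that same usage with the timing code and the stray exit() dropped. The import name GenerationUtil and the max_len/pad_idx/num_beams argument names are taken from the test hunk above (note the module itself defines a class called GenerationUtils), so treat them as a snapshot of this WIP commit rather than a stable API.

# Standalone sketch of the pattern exercised by test_hf_DELETE above.
from transformers import T5ForConditionalGeneration, T5Tokenizer

from torchtext.prototype.generate import GenerationUtil

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Flags tell the wrapper to run the encoder once and use HF-style inputs/outputs.
generator = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

prompts = ["summarize: studies have shown that owning a dog is good for you"]
input_ids = t5_tokenizer(prompts, padding=True, return_tensors="pt").input_ids

tokens = generator.generate(
    input_ids,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=10,
)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))
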
torchtext/prototype/generate.py

Lines changed: 77 additions & 35 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -12,6 +12,8 @@
     get_obj_from_emitting_model_state,
 )
 
+import logging
+import warnings
 logger = logging.getLogger(__name__)
 
 DEFAULT_MAX_SEQ_LEN = 256
@@ -50,34 +52,52 @@ class GenerationUtils(nn.Module):
     More examples can be found in the `notebooks` directory of this repository.
     """
 
+    _huggingface_model_input_values = {
+        "return_dict": True,
+        "use_cache": True,
+        "output_hidden_states": True
+    }
+
     def __init__(self, model: nn.Module, **kwargs) -> None:
         super().__init__()
         self.model = model
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
 
-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs, model_kwargs):
-        """Modified from."""
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
+
+        Args:
+            inputs: (Tensor): Tokenized startings sequence(s).
+            model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
+
+        Returns:
+            Modified model_kwargs with addition of encoded input sequence(s).
+        """
         # Get encoder
         encoder = self.model.get_encoder()
 
-        # Prepare encoder args and encoder kwargs from model kwargs
-        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
-        encoder_kwargs = {}
-        for argument, value in model_kwargs.items():
-            if not any([argument.startswith(p) for p in irrelevant_prefix]):
-                encoder_kwargs[argument] = value
+        # Create copy of encoder kwargs
+        encoder_kwargs = model_kwargs.copy()
 
         # Forward pass
         if self.is_huggingface_model:
             encoder_kwargs["return_dict"] = True
+
+        # import pdb
+        # pdb.set_trace()
+        # print(encoder_kwargs.keys())
+
+        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
+
         model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
 
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
         self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
     ):
+        """Prepare decoder IDs for generation."""
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
             return model_kwargs.pop("decoder_input_ids")
         else:
@@ -87,16 +107,23 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-        is_encoder_decoder: bool = False,
     ) -> Dict[str, Any]:
-        """Modified from."""
+        """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
+
+        Args:
+            outputs (Dict[str, Any]): LM output.
+            model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
+
+        Returns:
+            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
+        """
         # Update past
         if "past_key_values" in outputs:
-            model_kwargs["past"] = outputs.past_key_values
+            model_kwargs["past"] = outputs["past_key_values"]
         elif "mems" in outputs:
-            model_kwargs["past"] = outputs.mems
+            model_kwargs["past"] = outputs["mems"]
         elif "past_buckets_states" in outputs:
-            model_kwargs["past"] = outputs.past_buckets_states
+            model_kwargs["past"] = outputs["past_buckets_states"]
         else:
             model_kwargs["past"] = None
 
@@ -105,13 +132,19 @@ def _update_model_kwargs_for_generation(
             token_type_ids = model_kwargs["token_type_ids"]
             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
 
-        # Update attention mask
-        if not is_encoder_decoder:
+        if not self.is_encoder_decoder:
             if "attention_mask" in model_kwargs:
                 attention_mask = model_kwargs["attention_mask"]
                 model_kwargs["attention_mask"] = torch.cat(
                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )
+        else:
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
 
         return model_kwargs

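The _update_model_kwargs_for_generation changes above switch the branch condition to self.is_encoder_decoder and, for encoder-decoder models, grow decoder_attention_mask by one position after each generated token, mirroring what the decoder-only branch already did for attention_mask. A minimal, self-contained illustration of that mask-extension step follows; the toy tensor values are chosen for this sketch only, while the torch.cat / new_ones pattern matches the hunk above.

# Toy illustration of extending the (decoder_)attention_mask by one visible
# position per decoding step, as done in _update_model_kwargs_for_generation.
import torch

model_kwargs = {"decoder_attention_mask": torch.ones(2, 4, dtype=torch.long)}  # batch of 2, 4 steps so far

mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = torch.cat(
    [mask, mask.new_ones((mask.shape[0], 1))], dim=-1
)

print(model_kwargs["decoder_attention_mask"].shape)  # torch.Size([2, 5])
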
@@ -135,9 +168,7 @@ def greedy_search(
         while True:
             model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
             if self.is_huggingface_model:
-                model_inputs["return_dict"] = True
-                model_inputs["use_cache"] = True
-                model_inputs["output_hidden_states"] = True
+                model_inputs.update(self._huggingface_model_input_values)
 
             # Get model output
             outputs = self.model(**model_inputs)
@@ -177,7 +208,7 @@ def beam_search(
         eos_idx: int,
         num_python_workers: int,
         max_inference_batch_size: int,
-        model_kwargs,
+        model_kwargs: Dict[str, Any],
     ) -> torch.Tensor:
         """Beam search implemented using Flashlight Text (https://github.com/flashlight/text).
 
@@ -260,26 +291,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                     num_samples if timestep > 0 else 1, -1, -1
                 )
 
-                # Forward pass
+                # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **new_model_kwargs)
-                print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
-
                 if self.is_huggingface_model:
-                    model_inputs["return_dict"] = True
-                    model_inputs["use_cache"] = True
-                    model_inputs["output_hidden_states"] = True
+                    model_inputs.update(self._huggingface_model_input_values)
 
-                print(model_inputs.get("use_cache"), model_inputs.get("past_key_values"))
+                from typing import MappingProxyType
 
+                model_inputs = MappingProxyType(model_inputs)
+                # Forward pass
                 outputs = self.model(**model_inputs)
+
+                # Collect outputs
                 output_key = "logits" if self.is_huggingface_model else "decoder_output"
                 lm_scores = outputs[output_key]
 
                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
                     new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
                     if new_model_kwargs["past"] is not None:
-                        new_model_kwargs["past"] = self.model._reorder_cache(new_model_kwargs["past"], torch.Tensor(num_samples).to(dtype=torch.int32, device=self.model.device))
+                        import pdb
+                        pdb.set_trace()
+                        beam_indices += [start for _ in range(num_samples)]
+                        new_model_kwargs["past"] = self.model._reorder_cache(
+                            new_model_kwargs["past"],
+                            torch.Tensor(beam_indices).to(dtype=torch.int32) # I think this is correct?
+                        )
 
                 # Keep track of probabilities over vocab for this pairing
                 # TODO: clean up duplicate code in these branches
@@ -305,7 +342,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                         )
                     )
                 )
-
+
                 start += step
 
             return out_probs, model_states
@@ -378,6 +415,8 @@ def forward(
         num_python_workers: int = 1,
         max_inference_batch_size: int = 16,
     ):
+        """Calls self.generate() method."""
+        warnings.warn("Forward method simply calls `GenerationUtils.generate()`. Please use generate method directly.")
         return self.generate(
             inputs=inputs,
             num_beams=num_beams,
@@ -391,41 +430,42 @@ def forward(
             max_inference_batch_size=max_inference_batch_size,
         )
 
-
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         num_beams: Optional[int] = None,
         max_length: Optional[int] = None,
         pad_idx: int = 0,
         eos_idx: int = 1,
+        num_python_workers: int = 1,
         beam_threshold: int = 100,
         beam_size_token: Optional[int] = None,
        eos_score: float = -1.0,
-        num_python_workers: int = 1,
         max_inference_batch_size: int = 16,
     ) -> torch.Tensor:
-        """Generation method.
+        """Entrypoint generation method.
 
         Args:
             input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
             num_beams (int): If provided, specifies the number of beams to use in beam search generation.
             max_length (int): Max length to generate responses.
             pad_idx (int): Padding index. Defaults to 0.
             eos_idx (int): End-of-sequence index. Defaults to 1.
+            num_python_workers (int): If > 1, using multiprocessing on CPU.
             beam_size_token (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model.
             beam_threshold (int): Threshold before pruning; specific to beam search.
             eos_score (float): Score to input when `eos_idx` is generated; specific to beam search.
+            max_inference_batch_size (int): In beam search, to avoid OOMs, can choose to batch smaller amounts of hypothesis; defaults to 16.
 
         Returns:
             Tensor of Tensors containing output sequences as ids.
 
-        Conditions for generation: \
-            1. `num_beams` == 1 or `num_beams` is None -> greedy search \
+        Conditions for generation:
+            1. `num_beams` == 1 or `num_beams` is None -> greedy search
             2. `num_beams` > 1 -> beam search
         """
         model_kwargs = {}
-
+
         if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
@@ -436,6 +476,8 @@ def generate(
             max_length = DEFAULT_MAX_SEQ_LEN
 
         if num_beams == 1 or num_beams is None:
+            if num_python_workers > 1:
+                logger.warning(f"Multiprocessing is not implemented for greedy search.")
             return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:

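The reworked generate docstring spells out the dispatch rule: num_beams of None or 1 runs greedy_search (with max_length falling back to DEFAULT_MAX_SEQ_LEN plus a warning when omitted), while num_beams > 1 runs the flashlight-text beam search, which additionally needs beam_size_token. Below is a hedged sketch of driving both paths through GenerationUtils; the helper function and its arguments are illustrative and not part of the module. One caveat worth noting from the beam-search hunk: MappingProxyType is imported from typing there, but it lives in the standard-library types module, so that WIP line would fail at runtime.

# Illustrative only: shows the two dispatch paths described in the generate()
# docstring above. `model` can be any supported encoder-decoder model set up
# elsewhere; nothing here is a fixture from the test suite.
import torch

from torchtext.prototype.generate import GenerationUtils


def run_both_paths(model, input_ids: torch.Tensor, vocab_size: int):
    generator = GenerationUtils(model, is_encoder_decoder=True)

    # num_beams unset (or 1) -> greedy search.
    greedy_tokens = generator.generate(input_ids, max_length=30)

    # num_beams > 1 -> beam search; beam_size_token is typically the model's vocab size.
    beam_tokens = generator.generate(
        input_ids,
        num_beams=3,
        max_length=30,
        beam_size_token=vocab_size,
    )
    return greedy_tokens, beam_tokens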