This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 62f5ca8

format
1 parent 150ca1f commit 62f5ca8

2 files changed: +55 −41 lines changed


test/torchtext_unittest/prototype/test_generate.py

Lines changed: 12 additions & 14 deletions
@@ -57,27 +57,29 @@ def test_warns_when_mp_with_greedy(self, mock) -> None:
 
     def test_beam_search_with_t5_(self) -> None:
         generation_model = GenerationUtil(self.model)
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
+        tokens = generation_model.generate(
+            self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size
+        )
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
-            'kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for',
-            'Das ist gut.',
-            'acceptable',
-            '4.0',
-            'a tornado ripped through a swath of a lake in st. louis . a s'
+            "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for",
+            "Das ist gut.",
+            "acceptable",
+            "4.0",
+            "a tornado ripped through a swath of a lake in st. louis . a s",
         ]
 
         self.assertEqual(generated_text, expected_generated_text)
 
-
-
     def test_hf_DELETE(self) -> None:
         from transformers import T5ForConditionalGeneration, T5Tokenizer
         from torchtext.prototype.generate import GenerationUtil
 
         t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
-        test_sequence = ["summarize: studies have shown that owning a dog is good for you"]#, "Q: what is the capital of Alaska?"]
+        test_sequence = [
+            "summarize: studies have shown that owning a dog is good for you"
+        ]  # , "Q: what is the capital of Alaska?"]
         generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)
         t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
         test_sequence_tk = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids
@@ -89,18 +91,14 @@ def test_hf_DELETE(self) -> None:
             max_len=100,
             pad_idx=t5.config.pad_token_id,
             num_beams=10,
-
-
         )
         end = time.time() - start
         print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end)
         exit()
-
+
     def test_jit_generate(self) -> None:
         generation_model = GenerationUtil(self.model)
         torch.jit.script(generation_model)
-
 
-
     def test_beam_search_speed(self) -> None:
         pass
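For reference, the beam-search call exercised in `test_beam_search_with_t5_` amounts to the following usage sketch. The `model`, `inputs`, and `transform` objects stand in for the test fixtures, which are set up elsewhere in the test class and are not shown in this diff:

    from torchtext.prototype.generate import GenerationUtil

    # Wrap an encoder-decoder model and decode with beam search.
    generation_model = GenerationUtil(model)
    tokens = generation_model.generate(
        inputs, num_beams=3, max_len=30, beam_size_token=model.config.vocab_size
    )
    generated_text = transform.decode(tokens.tolist())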

torchtext/prototype/generate.py

Lines changed: 43 additions & 27 deletions
@@ -14,6 +14,7 @@
 
 import logging
 import warnings
+
 logger = logging.getLogger(__name__)
 
 
@@ -47,25 +48,23 @@ class GenerationUtil(nn.Module):
     More examples can be found in the `notebooks` directory of this repository.
     """
 
-    _huggingface_model_input_values = {
-        "return_dict": True,
-        "use_cache": True,
-        "output_hidden_states": True
-    }
+    _huggingface_model_input_values = {"return_dict": True, "use_cache": True, "output_hidden_states": True}
 
     def __init__(self, model: nn.Module, **kwargs) -> None:
         super().__init__()
         self.model = model
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
-
-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
         """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
 
         Args:
             inputs: (Tensor): Tokenized startings sequence(s).
             model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
-
+
         Returns:
             Modified model_kwargs with addition of encoded input sequence(s).
         """
@@ -78,19 +77,23 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, m
         # Forward pass
         if self.is_huggingface_model:
             encoder_kwargs["return_dict"] = True
-
+
         # import pdb
         # pdb.set_trace()
         # print(encoder_kwargs.keys())
-
+
         # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
+
         model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
 
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
-        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
+        self,
+        batch_size: int,
+        pad_idx: int = 0,
+        device: Optional[torch.device] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """Prepare decoder IDs for generation."""
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
@@ -108,7 +111,7 @@ def _update_model_kwargs_for_generation(
         Args:
             outputs (Dict[str, Any]): LM output.
             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
+
         Returns:
             Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
         """
@@ -144,7 +147,12 @@ def _update_model_kwargs_for_generation(
         return model_kwargs
 
     def greedy_search(
-        self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = {}
+        self,
+        input_ids: torch.Tensor,
+        max_len: int,
+        eos_idx: int,
+        pad_idx: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -217,7 +225,7 @@ def beam_search(
             eos_idx (int): End-of-sequence index.
             num_python_workers (int): Number of python workers to use for multiprocessing.
             model_kwargs
-
+
         Returns:
             Tensor of the generated sequences.
         """
@@ -227,9 +235,9 @@ def beam_search(
 
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
-
+
             i = T  # Hacky access to the current seq in inputs
-
+
             # Copy over the `model_kwargs` in order to modify
             new_model_kwargs = model_kwargs.copy()
 
@@ -254,18 +262,22 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                max_inference_batch_size, 1000 / (timestep + 1)
            )  # many hypotheses will EOS, so increase the batch size gradually
            curr_beam_size = len(prev_step_token_idxs)
-
+
            # 2. Batched inference to get next tokens
            while start < curr_beam_size:  # catch the remainder
                end = start + step
                if end > curr_beam_size:
                    end = curr_beam_size
 
-                num_samples = end - start # Is this always just gunna be equal to curr_beam_size?
+                num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
 
                if prev_step_token_idxs != [-1]:
                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=self.model.device).reshape(num_samples, 1)
+                    token_indices = (
+                        torch.Tensor(prev_step_token_idxs[start:end])
+                        .to(dtype=torch.long, device=self.model.device)
+                        .reshape(num_samples, 1)
+                    )
 
                    state_and_tokens = torch.cat(
                        [state_sequences, token_indices], dim=-1
@@ -303,14 +315,17 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
 
                # HF optimizations to reduce overhead in future `forward` calls
                if self.is_huggingface_model:
-                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                    new_model_kwargs = self._update_model_kwargs_for_generation(
+                        outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+                    )
                    if new_model_kwargs["past"] is not None:
                        import pdb
+
                        pdb.set_trace()
                        beam_indices += [start for _ in range(num_samples)]
                        new_model_kwargs["past"] = self.model._reorder_cache(
                            new_model_kwargs["past"],
-                            torch.Tensor(beam_indices).to(dtype=torch.int32) # I think this is correct?
+                            torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
                        )
 
                # Keep track of probabilities over vocab for this pairing
@@ -337,7 +352,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                            )
                        )
                    )
-
+
                start += step
 
            return out_probs, model_states
@@ -392,11 +407,10 @@ def is_not_neg_one(elem: int) -> bool:
            logger.warning("Multiprocessing has not yet been implemented.")
 
        all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
-
+
        # 5. Return top hypotheses for all input sequences
        return torch.stack(all_final_tokens, dim=0)
 
-
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
@@ -460,10 +474,12 @@ def generate(
        2. `num_beams` > 1 -> beam search
        """
        model_kwargs = {}
-
+
        if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
-            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
+            inputs = self._prepare_decoder_ids_for_generation(
+                len(inputs), device=inputs.device, model_kwargs=model_kwargs
+            )
 
        if max_len is None:
            # Too hard to try to figure out the exact max_seq_length for each model
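As the `generate` docstring above notes, the wrapper dispatches to greedy search unless `num_beams` > 1, in which case it runs beam search. For intuition only, here is a minimal greedy-decoding sketch independent of this module's internals; it assumes nothing beyond a model callable that returns logits of shape (batch, seq_len, vocab), which is not the exact signature used by `GenerationUtil.greedy_search`:

    import torch

    def greedy_decode(model, input_ids, max_len, eos_idx):
        # Greedy search: append the argmax token at each step until EOS or max_len.
        tokens = input_ids
        for _ in range(max_len):
            logits = model(tokens)  # assumed shape: (batch, seq_len, vocab)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=-1)
            if (next_token == eos_idx).all():
                break
        return tokens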
