This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit db364c4
Commit message: format
1 parent: de07b69

File tree: 2 files changed (+55, -41 lines)

test/torchtext_unittest/prototype/test_generate.py

Lines changed: 12 additions & 14 deletions
@@ -57,27 +57,29 @@ def test_warns_when_mp_with_greedy(self, mock) -> None:
 
     def test_beam_search_with_t5_(self) -> None:
         generation_model = GenerationUtil(self.model)
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
+        tokens = generation_model.generate(
+            self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size
+        )
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
-            'kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for',
-            'Das ist gut.',
-            'acceptable',
-            '4.0',
-            'a tornado ripped through a swath of a lake in st. louis . a s'
+            "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for",
+            "Das ist gut.",
+            "acceptable",
+            "4.0",
+            "a tornado ripped through a swath of a lake in st. louis . a s",
         ]
 
         self.assertEqual(generated_text, expected_generated_text)
 
-
-
     def test_hf_DELETE(self) -> None:
         from transformers import T5ForConditionalGeneration, T5Tokenizer
         from torchtext.prototype.generate import GenerationUtil
 
         t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
-        test_sequence = ["summarize: studies have shown that owning a dog is good for you"]#, "Q: what is the capital of Alaska?"]
+        test_sequence = [
+            "summarize: studies have shown that owning a dog is good for you"
+        ]  # , "Q: what is the capital of Alaska?"]
         generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)
         t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
         test_sequence_tk = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids
@@ -89,18 +91,14 @@ def test_hf_DELETE(self) -> None:
             max_len=100,
             pad_idx=t5.config.pad_token_id,
             num_beams=10,
-
-
         )
         end = time.time() - start
         print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end)
         exit()
-
+
     def test_jit_generate(self) -> None:
         generation_model = GenerationUtil(self.model)
         torch.jit.script(generation_model)
-
 
-
     def test_beam_search_speed(self) -> None:
         pass
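
For context, the `test_hf_DELETE` case above exercises the prototype generation wrapper against a Hugging Face model. A minimal standalone sketch of that flow, assuming the same `t5-base` checkpoint and the `GenerationUtil` import used in the test (this is not an official example, just a restatement of the test body):

import time

from transformers import T5ForConditionalGeneration, T5Tokenizer
from torchtext.prototype.generate import GenerationUtil

# Wrap a Hugging Face encoder-decoder model with the prototype generation utility,
# mirroring test_hf_DELETE above.
t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

test_sequence = ["summarize: studies have shown that owning a dog is good for you"]
input_ids = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids

start = time.time()
tokens = generative_hf_t5.generate(
    input_ids,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=10,
)
# Decode the beam-search output and report wall-clock time, as the test does.
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), time.time() - start)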

torchtext/prototype/generate.py

Lines changed: 43 additions & 27 deletions
@@ -14,6 +14,7 @@
 
 import logging
 import warnings
+
 logger = logging.getLogger(__name__)
 
 DEFAULT_MAX_SEQ_LEN = 256
@@ -52,25 +53,23 @@ class GenerationUtils(nn.Module):
    More examples can be found in the `notebooks` directory of this repository.
    """
 
-    _huggingface_model_input_values = {
-        "return_dict": True,
-        "use_cache": True,
-        "output_hidden_states": True
-    }
+    _huggingface_model_input_values = {"return_dict": True, "use_cache": True, "output_hidden_states": True}
 
    def __init__(self, model: nn.Module, **kwargs) -> None:
        super().__init__()
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
-
-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
        """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
 
        Args:
            inputs: (Tensor): Tokenized startings sequence(s).
            model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding.
-
+
        Returns:
            Modified model_kwargs with addition of encoded input sequence(s).
        """
@@ -83,19 +82,23 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor, m
        # Forward pass
        if self.is_huggingface_model:
            encoder_kwargs["return_dict"] = True
-
+
        # import pdb
        # pdb.set_trace()
        # print(encoder_kwargs.keys())
-
+
        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
+
        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
 
        return model_kwargs
 
    def _prepare_decoder_ids_for_generation(
-        self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, model_kwargs: Optional[Dict[str, Any]] = None
+        self,
+        batch_size: int,
+        pad_idx: int = 0,
+        device: Optional[torch.device] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Prepare decoder IDs for generation."""
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
@@ -113,7 +116,7 @@ def _update_model_kwargs_for_generation(
        Args:
            outputs (Dict[str, Any]): LM output.
            model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
+
        Returns:
            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
        """
@@ -149,7 +152,12 @@ def _update_model_kwargs_for_generation(
        return model_kwargs
 
    def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = {}
+        self,
+        input_ids: torch.Tensor,
+        max_length: int,
+        eos_idx: int,
+        pad_idx: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
    ) -> torch.Tensor:
        """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -222,7 +230,7 @@ def beam_search(
            eos_idx (int): End-of-sequence index.
            num_python_workers (int): Number of python workers to use for multiprocessing.
            model_kwargs
-
+
        Returns:
            Tensor of the generated sequences.
        """
@@ -232,9 +240,9 @@
 
        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
            # `emissions` and `N` are unused in this current implementation
-
+
            i = T # Hacky access to the current seq in inputs
-
+
            # Copy over the `model_kwargs` in order to modify
            new_model_kwargs = model_kwargs.copy()
 
@@ -259,18 +267,22 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                max_inference_batch_size, 1000 / (timestep + 1)
            ) # many hypotheses will EOS, so increase the batch size gradually
            curr_beam_size = len(prev_step_token_idxs)
-
+
            # 2. Batched inference to get next tokens
            while start < curr_beam_size: # catch the remainder
                end = start + step
                if end > curr_beam_size:
                    end = curr_beam_size
 
-                num_samples = end - start # Is this always just gunna be equal to curr_beam_size?
+                num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
 
                if prev_step_token_idxs != [-1]:
                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=self.model.device).reshape(num_samples, 1)
+                    token_indices = (
+                        torch.Tensor(prev_step_token_idxs[start:end])
+                        .to(dtype=torch.long, device=self.model.device)
+                        .reshape(num_samples, 1)
+                    )
 
                    state_and_tokens = torch.cat(
                        [state_sequences, token_indices], dim=-1
@@ -308,14 +320,17 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
 
                    # HF optimizations to reduce overhead in future `forward` calls
                    if self.is_huggingface_model:
-                        new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder)
+                        new_model_kwargs = self._update_model_kwargs_for_generation(
+                            outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+                        )
                        if new_model_kwargs["past"] is not None:
                            import pdb
+
                            pdb.set_trace()
                            beam_indices += [start for _ in range(num_samples)]
                            new_model_kwargs["past"] = self.model._reorder_cache(
                                new_model_kwargs["past"],
-                                torch.Tensor(beam_indices).to(dtype=torch.int32) # I think this is correct?
+                                torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
                            )
 
                    # Keep track of probabilities over vocab for this pairing
@@ -342,7 +357,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, t
                        )
                    )
                )
-
+
                start += step
 
            return out_probs, model_states
@@ -397,11 +412,10 @@ def is_not_neg_one(elem: int) -> bool:
            logger.warning("Multiprocessing has not yet been implemented.")
 
        all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
-
+
        # 5. Return top hypotheses for all input sequences
        return torch.stack(all_final_tokens, dim=0)
 
-
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
@@ -465,10 +479,12 @@ def generate(
            2. `num_beams` > 1 -> beam search
        """
        model_kwargs = {}
-
+
        if self.is_encoder_decoder:
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
-            inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, model_kwargs=model_kwargs)
+            inputs = self._prepare_decoder_ids_for_generation(
+                len(inputs), device=inputs.device, model_kwargs=model_kwargs
+            )
 
        if max_length is None:
            # Too hard to try to figure out the exact max_seq_length for each model