
Commit e64f1ef
wip feb 8
1 parent db364c4

3 files changed: +40 -47 lines

test/torchtext_unittest/prototype/test_generate.py

Lines changed: 6 additions & 1 deletion
@@ -46,7 +46,7 @@ def test_generate_errors_with_incorrect_beams(self) -> None:
         with self.assertRaises(ValueError):
             generation_model.generate(self.inputs, num_beams=0)
 
-    @patch("logging.Logger.warning")
+    @patch("warnings.warn")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtil(self.model)
         generation_model.generate(self.inputs)
@@ -91,12 +91,17 @@ def test_hf_DELETE(self) -> None:
             max_len=100,
             pad_idx=t5.config.pad_token_id,
             num_beams=10,
+            beam_size_token=t5.config.vocab_size,
         )
         end = time.time() - start
         print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end)
         exit()
 
     def test_jit_generate(self) -> None:
+        # jitted_model = torch.jit.script(self.model)
+        # encoder = jitted_model.get_encoder()
+
+
         generation_model = GenerationUtil(self.model)
         torch.jit.script(generation_model)
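
The patch target above changes from logging.Logger.warning to warnings.warn, matching the switch in torchtext/prototype/generate.py below from the logging module to warnings (warnings.warn is scriptable by TorchScript, which appears to be the motivation for this commit). A minimal, self-contained sketch of the patching pattern, using a stand-in generate function rather than the real GenerationUtil fixture:

import warnings
from unittest import TestCase
from unittest.mock import patch


class WarnsWithoutMaxLen(TestCase):
    @patch("warnings.warn")
    def test_warns_when_no_max_len_provided(self, mock_warn) -> None:
        # Stand-in for GenerationUtil.generate defaulting max_length and warning about it.
        def generate(inputs, max_length=None):
            if max_length is None:
                warnings.warn("`max_len` was not specified. Defaulting to 256 tokens.")
                max_length = 256
            return inputs[:max_length]

        generate([1, 2, 3])
        mock_warn.assert_called_once()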

torchtext/models/t5/model.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def prepare_inputs_for_generation(
             "return_past_key_values": return_past_key_values,
         }
 
+    @torch.jit.export
     def get_encoder(self) -> T5Encoder:
         return self.encoder
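
For context on this one-line change: torch.jit.script compiles forward and whatever forward reaches; other methods stay uncompiled unless marked with @torch.jit.export, which keeps them callable on the resulting ScriptModule. A toy sketch of the pattern (TinyEncoder and TinyModel are illustrative stand-ins, not torchtext code):

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = TinyEncoder()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    @torch.jit.export
    def get_encoder(self) -> TinyEncoder:
        return self.encoder


scripted = torch.jit.script(TinyModel())
encoder = scripted.get_encoder()         # callable on the ScriptModule because the method is exported
out = encoder.forward(torch.ones(2, 3))  # explicit .forward() call, as generate.py now does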

torchtext/prototype/generate.py

Lines changed: 33 additions & 46 deletions
@@ -12,10 +12,10 @@
     get_obj_from_emitting_model_state,
 )
 
-import logging
 import warnings
 
-logger = logging.getLogger(__name__)
+
+MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]]
 
 DEFAULT_MAX_SEQ_LEN = 256
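
The logging import and module-level logger are dropped (warnings.warn is used later in this file instead), and the new MODEL_KWARGS_TYPE alias gives model_kwargs a concrete, TorchScript-checkable type in place of the earlier Dict[str, Any] annotations. A simplified sketch of why that matters, using a trimmed-down stand-in for the value type rather than the full union above:

import torch
from typing import Dict, Optional

# Trimmed-down stand-in for MODEL_KWARGS_TYPE's value type: concrete types instead of Any.
KWARGS_VALUE = Optional[torch.Tensor]


@torch.jit.script
def count_cached_tensors(model_kwargs: Dict[str, KWARGS_VALUE]) -> int:
    # TorchScript can compile this loop because every dict value has a declared type.
    count = 0
    for value in model_kwargs.values():
        if value is not None:
            count += 1
    return count


print(count_cached_tensors({"encoder_output": torch.zeros(1, 4), "past": None}))  # prints 1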

@@ -61,9 +61,7 @@ def __init__(self, model: nn.Module, **kwargs) -> None:
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
         self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
 
-    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, inputs: torch.Tensor, model_kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -> MODEL_KWARGS_TYPE:
         """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592.
 
         Args:
@@ -77,40 +75,36 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         encoder = self.model.get_encoder()
 
         # Create copy of encoder kwargs
-        encoder_kwargs = model_kwargs.copy()
+        encoder_kwargs: Dict[str, bool] = {}
 
-        # Forward pass
         if self.is_huggingface_model:
             encoder_kwargs["return_dict"] = True
 
-        # import pdb
-        # pdb.set_trace()
-        # print(encoder_kwargs.keys())
-
-        # assert torch.jit.isinstance(encoder_kwargs, Optional[Dict[str, bool]])
-
-        model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_kwargs)
-
+        # Forward pass
+        # Explicitly call forward method to assert this is a ScriptModule if JITted
+        model_kwargs = {"encoder_outputs": encoder.forward(inputs)}  # , **encoder_kwargs)
         return model_kwargs
 
     def _prepare_decoder_ids_for_generation(
         self,
         batch_size: int,
         pad_idx: int = 0,
         device: Optional[torch.device] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = None,
+    ) -> torch.Tensor:
         """Prepare decoder IDs for generation."""
         if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
-            return model_kwargs.pop("decoder_input_ids")
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+            assert torch.jit.isinstance(decoder_input_ids, torch.Tensor)
+            return decoder_input_ids
         else:
             return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> Dict[str, Any]:
+    ) -> MODEL_KWARGS_TYPE:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.
 
         Args:
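
With the dict values now union-typed, the popped decoder_input_ids has to be narrowed before it can be returned as a Tensor, which is what the added torch.jit.isinstance assert does. A small sketch of that narrowing in isolation (the value union here is a simplified stand-in for MODEL_KWARGS_TYPE's inner type):

import torch
from typing import Dict, List, Union


@torch.jit.script
def pop_decoder_input_ids(model_kwargs: Dict[str, Union[torch.Tensor, List[torch.Tensor]]]) -> torch.Tensor:
    value = model_kwargs.pop("decoder_input_ids")
    # torch.jit.isinstance narrows the union so the declared return type checks out.
    if torch.jit.isinstance(value, torch.Tensor):
        return value
    raise RuntimeError("decoder_input_ids must be a Tensor")


ids = pop_decoder_input_ids({"decoder_input_ids": torch.ones(2, 1, dtype=torch.long)})
print(ids.shape)  # torch.Size([2, 1])
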
@@ -157,7 +151,7 @@ def greedy_search(
         max_length: int,
         eos_idx: int,
         pad_idx: Optional[int] = None,
-        model_kwargs: Optional[Dict[str, Any]] = {},
+        model_kwargs: Optional[MODEL_KWARGS_TYPE] = {},
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
@@ -189,9 +183,8 @@ def greedy_search(
             _, next_tokens = torch.topk(probs, 1)
 
             # For any finished sequences, padding idx should be the last token
-            if eos_idx is not None:
-                if pad_idx is not None:
-                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
+            if eos_idx is not None and pad_idx is not None:
+                next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
 
             # Append the next tokens to the previous tokens
             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
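
The collapsed condition keeps the same masking arithmetic: rows whose sequence has already produced eos (unfinished_sequences == 0) get pad_idx instead of the newly sampled token. A tiny worked example with made-up values:

import torch

pad_idx = 0
next_tokens = torch.tensor([[5], [7], [9]])           # most likely next token per sequence
unfinished_sequences = torch.tensor([[1], [0], [1]])  # 0 = sequence already emitted eos

next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
print(next_tokens)  # tensor([[5], [0], [9]]) -- the finished row is padded
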
@@ -238,7 +231,7 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
 
-        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
+        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
 
             i = T  # Hacky access to the current seq in inputs
@@ -274,7 +267,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             if end > curr_beam_size:
                 end = curr_beam_size
 
-            num_samples = end - start  # Is this always just gunna be equal to curr_beam_size?
+            num_samples = end - start
 
             if prev_step_token_idxs != [-1]:
                 state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
@@ -308,9 +301,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
 
-                from typing import MappingProxyType
-
-                model_inputs = MappingProxyType(model_inputs)
                 # Forward pass
                 outputs = self.model(**model_inputs)

@@ -320,17 +310,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
 
                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
-                    new_model_kwargs = self._update_model_kwargs_for_generation(
-                        outputs, new_model_kwargs, is_encoder_decoder=self.is_encoder_decoder
-                    )
-                    if new_model_kwargs["past"] is not None:
-                        import pdb
-
-                        pdb.set_trace()
-                        beam_indices += [start for _ in range(num_samples)]
+                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
+                    if new_model_kwargs["past"] is not None and len(prev_step_hyp_idxs) > 1:
+                        if len(prev_step_hyp_idxs) == 9:
+                            import pdb
+                            pdb.set_trace()
                         new_model_kwargs["past"] = self.model._reorder_cache(
                             new_model_kwargs["past"],
-                            torch.Tensor(beam_indices).to(dtype=torch.int32),  # I think this is correct?
+                            torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),  # I think this is correct?
                         )
 
                 # Keep track of probabilities over vocab for this pairing
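
For context on the _reorder_cache call above: in Hugging Face models, _reorder_cache typically rearranges every cached key/value tensor along the batch-of-hypotheses dimension so the cache lines up with the beams kept for the next step; the indices passed in (here prev_step_hyp_idxs instead of the accumulated beam_indices) select which hypotheses' caches survive. A schematic stand-in, not the T5 implementation:

import torch
from typing import List, Tuple


def reorder_cache(past: List[Tuple[torch.Tensor, torch.Tensor]], beam_idx: torch.Tensor):
    # Keep the cached keys/values of the selected hypotheses, in beam order.
    return [
        (keys.index_select(0, beam_idx), values.index_select(0, beam_idx))
        for keys, values in past
    ]


# One layer's cache for 4 hypotheses (made-up shapes: heads=2, seq=3, head_dim=8).
past = [(torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8))]
reordered = reorder_cache(past, torch.tensor([2, 0, 0, 1]))
print(reordered[0][0].shape)  # torch.Size([4, 2, 3, 8])
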
@@ -409,7 +396,7 @@ def is_not_neg_one(elem: int) -> bool:
             return final_tokens_as_tensors
 
         if num_python_workers > 1:
-            logger.warning("Multiprocessing has not yet been implemented.")
+            warnings.warn("Multiprocessing has not yet been implemented.")
 
         all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))]
@@ -478,28 +465,28 @@ def generate(
             1. `num_beams` == 1 or `num_beams` is None -> greedy search
             2. `num_beams` > 1 -> beam search
         """
-        model_kwargs = {}
+        model_kwargs: MODEL_KWARGS_TYPE = {}
 
         if self.is_encoder_decoder:
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs, model_kwargs)
+            assert torch.jit.isinstance(inputs, torch.Tensor)
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs)
             inputs = self._prepare_decoder_ids_for_generation(
                 len(inputs), device=inputs.device, model_kwargs=model_kwargs
             )
 
         if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
-            max_length = DEFAULT_MAX_SEQ_LEN
+            warnings.warn("`max_len` was not specified. Defaulting to 256 tokens.")
+            max_length = 256
 
-        if num_beams == 1 or num_beams is None:
+        if num_beams is None or num_beams == 1:
             if num_python_workers > 1:
-                logger.warning(f"Multiprocessing is not implemented for greedy search.")
+                warnings.warn(f"Multiprocessing is not implemented for greedy search.")
             return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs)
         elif num_beams > 1:
             if beam_size_token is None:
                 raise ValueError(
-                    "`beam_size_token` must be specified for beam search. \
-                    If confused about what to put, you can default to the vocab size of the model you are using."
+                    "`beam_size_token` must be specified for beam search. If confused about what to put, you can default to the vocab size of the model you are using."
                 )
             return self.beam_search(
                 inputs,
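
Pulling the dispatch together, a usage sketch based on the calls in test_hf_DELETE above; the Hugging Face model/tokenizer setup and the exact keyword names (max_len, pad_idx, beam_size_token) are taken from that test and this docstring, and may shift while the API is still work in progress:

from transformers import T5ForConditionalGeneration, T5Tokenizer

from torchtext.prototype.generate import GenerationUtil

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
generation_model = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

inputs = t5_tokenizer(["summarize: studies have shown that owning a dog is good for you"], return_tensors="pt").input_ids

# num_beams omitted (or 1) -> greedy search
greedy_tokens = generation_model.generate(inputs, max_len=100, pad_idx=t5.config.pad_token_id)

# num_beams > 1 -> beam search; beam_size_token is required (the model's vocab size is a safe default)
beam_tokens = generation_model.generate(
    inputs,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=10,
    beam_size_token=t5.config.vocab_size,
)
print(t5_tokenizer.batch_decode(beam_tokens, skip_special_tokens=True))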
