Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7c58534

Browse files
committed
Separate encoding/decoding logic for T5 model in preparation for generation
1 parent bc57394 commit 7c58534

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

torchtext/models/t5/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818

1919
from .modules import DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE, PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder
2020

21+
# logging library is not automatically supported by Torchscript
22+
import warnings
23+
2124

22-
@dataclass
25+
@dataclass(frozen=True)
2326
class T5Conf:
2427
encoder_only: bool = False
2528
linear_head: bool = False
@@ -288,6 +291,8 @@ def forward(
288291

289292
# decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
290293
if decoder_tokens is None:
294+
batch_size = encoder_output.size()[0]
295+
encoder_output_device = encoder_output.device
291296
decoder_tokens = (
292297
torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
293298
)
@@ -317,7 +322,7 @@ def forward(
317322
# Rescale output before projecting on vocab. This happens when the encoder and decoder share the
318323
# same word embeddings, which is always the case in our t5 implementation.
319324
# See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
320-
decoder_output = decoder_output * (self.embedding_dim ** -0.5)
325+
decoder_output = decoder_output * (self.embedding_dim**-0.5)
321326
decoder_output = self.lm_head(decoder_output)
322327
decoder_outputs["decoder_output"] = decoder_output
323328

torchtext/prototype/generate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
from typing import Optional
32

43
import torch

0 commit comments

Comments (0)