
Commit 651a033

Fix device setting for T5 model (#2007)
* Fix device setting for T5 model
* Fix lint issues
1 parent c4f1e84 commit 651a033

File tree

torchtext/prototype/models/t5/model.py
torchtext/prototype/models/t5/modules.py

2 files changed: +10 -15 lines changed
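
The diffs below drop the self.device attribute cached in __init__ and instead read the device from input tensors at call time, so tensors created inside forward land on whatever device the inputs live on. As a rough illustration of that pattern, here is a hypothetical toy module (not torchtext code), assuming only standard PyTorch:

import torch
import torch.nn as nn

class PadStart(nn.Module):
    """Toy module illustrating the pattern: no cached self.device."""

    def __init__(self, padding_idx: int) -> None:
        super().__init__()
        self.padding_idx = padding_idx

    def forward(self, encoder_tokens: torch.Tensor) -> torch.Tensor:
        # Start-of-decoding tokens are created on the same device as the input,
        # so the module keeps working after model.to("cuda") or with CPU inputs.
        return (
            torch.ones((encoder_tokens.size(0), 1), device=encoder_tokens.device, dtype=torch.long)
            * self.padding_idx
        )

tokens = torch.randint(0, 100, (2, 8))
print(PadStart(padding_idx=0)(tokens).shape)  # torch.Size([2, 1])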

torchtext/prototype/models/t5/model.py

Lines changed: 5 additions & 3 deletions

@@ -85,7 +85,6 @@ def __init__(
         self.padding_idx = config.padding_idx
         self.training = config.training
         self.dropout = config.dropout if config.training else 0.0
-        self.device = device
         self.dtype = dtype

         self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)

@@ -184,13 +183,16 @@ def forward(

         # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
         if decoder_tokens is None:
-            decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
+            decoder_tokens = (
+                torch.ones((encoder_tokens.size(0), 1), device=encoder_tokens.device, dtype=torch.long)
+                * self.padding_idx
+            )

         if decoder_mask is None:
             assert decoder_tokens is not None and decoder_tokens.dim() == 2
             tgt_len = decoder_tokens.shape[1]
             decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
-            decoder_mask = decoder_mask.to(self.device, dtype=torch.bool)
+            decoder_mask = decoder_mask.to(decoder_tokens.device, dtype=torch.bool)

         decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
         # T5 implemention uses padding idx to start sequence. Want to ignore this when masking
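
In model.py, the start-of-decoding tokens and the causal decoder mask now follow the device of the incoming tensors rather than self.device. A standalone sketch of the mask construction above (hypothetical helper name, not part of torchtext's API):

import torch

# Mirrors the fixed mask construction: the upper-triangular causal mask is built
# in float64, then moved to the device of decoder_tokens instead of a device
# captured in __init__.
def make_decoder_mask(decoder_tokens: torch.Tensor) -> torch.Tensor:
    tgt_len = decoder_tokens.shape[1]
    decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
    return decoder_mask.to(decoder_tokens.device, dtype=torch.bool)

tokens = torch.full((2, 4), 0, dtype=torch.long)  # e.g. padding-idx start tokens
print(make_decoder_mask(tokens))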

torchtext/prototype/models/t5/modules.py

Lines changed: 5 additions & 12 deletions

@@ -74,8 +74,6 @@ def __init__(
         else:
             self.relative_attention_bias = None

-        self.device = device
-
     def forward(
         self,
         query: Tensor,

@@ -257,9 +255,7 @@ def _t5_multi_head_attention_forward(
                 ).unsqueeze(0)
             else:
                 position_bias = self._compute_bias(
-                    tgt_len,
-                    src_len,
-                    bidirectional=(not self.is_decoder),
+                    tgt_len, src_len, bidirectional=(not self.is_decoder), device=k.device
                 )

         # Calculate attention and out projection

@@ -405,15 +401,12 @@ def _t5_dot_product_attention(

     # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
-        self,
-        query_length: int,
-        key_length: int,
-        bidirectional: bool = True,
+        self, query_length: int, key_length: int, bidirectional: bool = True, device: Optional[torch.device] = None
     ) -> Tensor:
         """Compute binned relative position bias"""
         assert self.relative_attention_bias is not None
-        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
-        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)

@@ -446,7 +439,7 @@ def _relative_position_bucket(
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
+        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=relative_position.device)
         if bidirectional:
             num_buckets = num_buckets // 2
             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
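
In modules.py, _compute_bias now accepts an explicit device argument (the caller passes k.device), and the relative-position index tensors are allocated on it. A minimal sketch of that indexing step, written as a free function rather than torchtext's method:

from typing import Optional

import torch

# Free-function sketch of the index construction inside _compute_bias: the arange
# tensors inherit the device passed by the caller (k.device in the diff above).
def relative_position_grid(
    query_length: int, key_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    return memory_position - context_position  # shape (query_length, key_length)

print(relative_position_grid(3, 5))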
