@@ -74,8 +74,6 @@ def __init__(
         else:
             self.relative_attention_bias = None
 
-        self.device = device
-
     def forward(
         self,
         query: Tensor,
@@ -257,9 +255,7 @@ def _t5_multi_head_attention_forward(
             ).unsqueeze(0)
         else:
             position_bias = self._compute_bias(
-                tgt_len,
-                src_len,
-                bidirectional=(not self.is_decoder),
+                tgt_len, src_len, bidirectional=(not self.is_decoder), device=k.device
             )
 
         # Calculate attention and out projection
@@ -405,15 +401,12 @@ def _t5_dot_product_attention(
 
     # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
-        self,
-        query_length: int,
-        key_length: int,
-        bidirectional: bool = True,
+        self, query_length: int, key_length: int, bidirectional: bool = True, device: Optional[torch.device] = None
     ) -> Tensor:
         """Compute binned relative position bias"""
         assert self.relative_attention_bias is not None
-        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
-        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
@@ -446,7 +439,7 @@ def _relative_position_bucket(
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
+        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=relative_position.device)
         if bidirectional:
             num_buckets = num_buckets // 2
             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
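
The common thread across these hunks is dropping the device cached on the module in __init__ and instead deriving the device from the tensors in flight (k.device, relative_position.device). A cached self.device goes stale once the module is moved with .to(...) or .cuda(); reading the device off an input tensor keeps the position-bias tensors colocated with the activations. A minimal sketch of the pattern, using a hypothetical standalone helper _relative_positions that is not part of the patched module:

    from typing import Optional

    import torch
    from torch import Tensor


    def _relative_positions(query_length: int, key_length: int, device: Optional[torch.device] = None) -> Tensor:
        # Build the (query_length, key_length) matrix of relative offsets on the
        # caller-supplied device, mirroring the patched _compute_bias above.
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        return memory_position - context_position


    # Derive the device from a live tensor (here a stand-in for the key projection k)
    # rather than from a value stored at construction time, so the bias follows
    # the module across .to(device) / .cuda() moves.
    k = torch.randn(4, 8)
    bias_positions = _relative_positions(4, 4, device=k.device)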