@@ -73,7 +73,12 @@ def __init__(self,
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
@@ -196,7 +201,7 @@ def forward(
         if num_prompt_tokens > 0:
             # Prompt run.
             assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata)
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -340,13 +345,14 @@ def __init__(self,
         slopes = torch.tensor(slopes, dtype=torch.float32)
         self.register_buffer("alibi_slopes", slopes, persistent=False)
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(self, input_metadata: InputMetadata,
+                      dtype: torch.dtype) -> None:
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len)
+            bias = torch.arange(prompt_len, dtype=dtype)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
@@ -364,6 +370,7 @@ def set_attn_bias(self, input_metadata: InputMetadata) -> None:
                 prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
+                dtype=dtype,
             )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
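
For illustration, here is a minimal, self-contained sketch of why the ALiBi bias is now constructed in the query's dtype. This is not the vLLM implementation; `make_alibi_bias`, the slope values, and the tensor shapes are made up for this example. The point is that the mask comes out in half precision when the model runs in fp16, instead of inheriting int64 from `torch.arange` or a default float32 buffer and mismatching the query tensor.

```python
# Minimal sketch with hypothetical names (make_alibi_bias, slopes);
# not the vLLM code, only an illustration of building the bias in the
# target dtype.
import torch


def make_alibi_bias(prompt_len: int, slopes: torch.Tensor,
                    dtype: torch.dtype) -> torch.Tensor:
    # Build the relative-position term directly in the target dtype.
    # Without dtype=, torch.arange returns int64, and intermediate
    # buffers created with the default dtype would be float32, which
    # can mismatch an fp16 query tensor downstream.
    bias = torch.arange(prompt_len, dtype=dtype)
    bias = bias[None, :] - bias[:, None]              # (L, L)
    return bias[None, :, :] * slopes[:, None, None]   # (num_heads, L, L)


slopes = torch.tensor([1.0, 0.5], dtype=torch.float16)
mask = make_alibi_bias(4, slopes, dtype=torch.float16)
print(mask.dtype)  # torch.float16 -- matches a half-precision query
```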