Commit a62de9e

Fix wrong dtype in PagedAttentionWithALiBi bias (#996)
Signed-off-by: Antoni Baum <[email protected]>
1 parent 4042d19 commit a62de9e
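For context: torch.arange defaults to torch.int64 when no dtype is given, so before this commit the ALiBi bias could be built in a dtype that does not match the query tensor in half-precision runs. A minimal sketch of the root cause, using only stock PyTorch (the values are illustrative):

import torch

print(torch.arange(4).dtype)                       # torch.int64 (no dtype given)
print(torch.arange(4, dtype=torch.float16).dtype)  # torch.float16, matching an fp16 query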

File tree

1 file changed: +11 -4 lines changed


vllm/model_executor/layers/attention.py

Lines changed: 11 additions & 4 deletions
@@ -73,7 +73,12 @@ def __init__(self,
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
@@ -196,7 +201,7 @@ def forward(
         if num_prompt_tokens > 0:
             # Prompt run.
             assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata)
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -340,13 +345,14 @@ def __init__(self,
         slopes = torch.tensor(slopes, dtype=torch.float32)
         self.register_buffer("alibi_slopes", slopes, persistent=False)
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(self, input_metadata: InputMetadata,
+                      dtype: torch.dtype) -> None:
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len)
+            bias = torch.arange(prompt_len, dtype=dtype)
             # Note(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             #     here. We find that both biases give the same results, but
@@ -364,6 +370,7 @@ def set_attn_bias(self, input_metadata: InputMetadata) -> None:
                 prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
+                dtype=dtype,
             )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
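
The change threads the query's dtype through set_attn_bias so the ALiBi bias is allocated in the same dtype as the attention inputs, while the base class accepts and ignores the argument (del dtype) to keep the two signatures compatible. Below is a minimal standalone sketch of the fixed construction; the shapes, slope values, and the relative-distance step (bias[None, :] - bias[:, None]) are assumptions filled in from the surrounding code, not lines from this commit:

import torch

num_heads, prompt_len, padded_len = 8, 4, 8       # illustrative sizes
query_dtype = torch.float16                       # e.g. an fp16 model
alibi_slopes = torch.tensor(                      # placeholder slope values
    [2.0 ** -(i + 1) for i in range(num_heads)], dtype=torch.float32)

# After the fix, the bias inherits the query dtype end to end.
bias = torch.arange(prompt_len, dtype=query_dtype)
bias = bias[None, :] - bias[:, None]              # relative distances (assumed form)
bias = torch.empty(
    1, num_heads, prompt_len, padded_len,
    device=alibi_slopes.device,
    dtype=query_dtype,                            # the argument this commit adds
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
assert bias.dtype == query_dtype                  # no more float32/int64 mismatch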
