25 | 25 |   from vllm.sampling_params import SamplingType
26 | 26 |   from vllm.sequence import IntermediateTensors
27 | 27 |   from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
28 |    | -                         LayerBlockType, LazyLoader, cdiv,
   | 28 | +                         LayerBlockType, LazyLoader, cdiv, check_use_alibi,
29 | 29 |                            is_pin_memory_available)
30 | 30 |   from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
31 | 31 |   from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -223,6 +223,9 @@ def __init__(
223 | 223 |             device="cpu",
224 | 224 |             pin_memory=self.pin_memory)
225 | 225 |
    | 226 | +        # Only relevant for models using ALiBi (e.g., MPT)
    | 227 | +        self.use_alibi = check_use_alibi(model_config)
    | 228 | +
226 | 229 |         self.inputs_embeds = torch.zeros(
227 | 230 |             (self.max_num_tokens, self.hidden_size),
228 | 231 |             dtype=self.dtype,
@@ -689,7 +692,7 @@ def _compute_cascade_attn_prefix_len(
689 | 692 |             query_lens=num_scheduled_tokens,
690 | 693 |             num_query_heads=self.num_query_heads,
691 | 694 |             num_kv_heads=self.num_kv_heads,
692 |     | -            use_alibi=False,  # FIXME
    | 695 | +            use_alibi=self.use_alibi,
693 | 696 |             use_sliding_window=self.window_size is not None,
694 | 697 |             num_sms=self.num_sms,
695 | 698 |         )
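With the flag stored on the runner, _compute_cascade_attn_prefix_len now passes the model's actual ALiBi status instead of the hard-coded False flagged by the FIXME. Below is a minimal sketch of how the consuming check could treat the flag, assuming (as the change implies) that the cascade/shared-prefix path is simply ruled out for ALiBi models; the real use_cascade_attention decision in the FlashAttention backend also weighs query lengths, head counts, and SM count.

# Simplified sketch of the consumer side, not the actual backend code.
def use_cascade_attention_sketch(common_prefix_len: int, use_alibi: bool,
                                 use_sliding_window: bool) -> bool:
    # Assumption: the shared-prefix (cascade) kernel path does not apply
    # ALiBi biases or sliding-window masks, so it is skipped for such models.
    if use_alibi or use_sliding_window:
        return False
    # Otherwise cascade attention only pays off when there is a shared
    # prefix to split out.
    return common_prefix_len > 0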