
Commit 8958217

h-sugi and Woosuk Kwon authored
[Bugfix] Fix use_cascade_attention handling for Alibi-based models on vllm/v1 (#15211)
Signed-off-by: h-sugi <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent ac5bc61 commit 8958217

File tree

vllm/utils.py
vllm/v1/worker/gpu_model_runner.py

2 files changed: +18 -3 lines changed

vllm/utils.py

Lines changed: 13 additions & 1 deletion

@@ -61,7 +61,7 @@
from vllm.logger import enable_trace_function_call, init_logger

if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig

logger = init_logger(__name__)

@@ -2498,6 +2498,18 @@ def wrapper(*args, **kwargs):
    return decorator


+# Only relevant for models using ALiBi (e.g, MPT)
+def check_use_alibi(model_config: ModelConfig) -> bool:
+    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
+            or ("BloomForCausalLM" in getattr(model_config.hf_config,
+                                              "architectures", []))  # Bloom
+            or getattr(model_config.hf_text_config, "position_encoding_type",
+                       "") == "alibi"  # codellm_1b_alibi
+            or
+            (hasattr(model_config.hf_text_config, "attn_config")  # MPT
+             and model_config.hf_text_config.attn_config.get("alibi", False)))
+
+
def sha256(input) -> int:
    """Hash any picklable Python object using SHA-256.
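For context, a minimal usage sketch of the new helper (not part of the diff; the SimpleNamespace stand-ins below are hypothetical substitutes for a real ModelConfig and its Hugging Face config objects):

# Illustrative sketch only, not part of this commit: the stand-in objects
# mimic just the attributes that check_use_alibi inspects on a ModelConfig.
from types import SimpleNamespace

from vllm.utils import check_use_alibi

# MPT-style config: ALiBi is declared inside attn_config.
mpt_like = SimpleNamespace(
    hf_config=SimpleNamespace(architectures=["MPTForCausalLM"]),
    hf_text_config=SimpleNamespace(attn_config={"alibi": True}),
)

# Config with none of the ALiBi markers (no Falcon-style `alibi` flag, no Bloom
# architecture, no `position_encoding_type`, no MPT `attn_config`).
rope_like = SimpleNamespace(
    hf_config=SimpleNamespace(architectures=["LlamaForCausalLM"]),
    hf_text_config=SimpleNamespace(),
)

print(check_use_alibi(mpt_like))   # True
print(check_use_alibi(rope_like))  # False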

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 2 deletions

@@ -25,7 +25,7 @@
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv,
+                        LayerBlockType, LazyLoader, cdiv, check_use_alibi,
                        is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -223,6 +223,9 @@ def __init__(
            device="cpu",
            pin_memory=self.pin_memory)

+        # Only relevant for models using ALiBi (e.g, MPT)
+        self.use_alibi = check_use_alibi(model_config)
+
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
@@ -689,7 +692,7 @@ def _compute_cascade_attn_prefix_len(
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
-            use_alibi=False,  # FIXME
+            use_alibi=self.use_alibi,
            use_sliding_window=self.window_size is not None,
            num_sms=self.num_sms,
        )
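Net effect: the cascade-attention decision previously received a hard-coded use_alibi=False, so ALiBi models could be considered for the shared-prefix cascade path; the runner now caches check_use_alibi(model_config) once and forwards it. A hedged sketch of the kind of gate this flag feeds (an assumption for illustration, not the actual vLLM implementation, which also weighs factors such as shared-prefix length and SM count):

# Illustrative sketch only (assumed shape of the eligibility check): ALiBi's
# position-dependent bias and sliding windows are not modeled by the
# shared-prefix split, so either one is assumed to opt out of cascade attention.
def cascade_attention_allowed(use_alibi: bool, use_sliding_window: bool) -> bool:
    if use_alibi or use_sliding_window:
        return False
    return True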

0 commit comments
