Skip to content

Commit e79a12f

Browse files
authored
[UX] Fail if an invalid attention backend is specified (#22217)
Signed-off-by: mgoin <[email protected]>
1 parent cdfd687 commit e79a12f

File tree

2 files changed

+9
-15
lines changed

2 files changed

+9
-15
lines changed

tests/kernels/attention/test_attention_selector.py

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
278278

279279
@pytest.mark.parametrize("use_v1", [True, False])
280280
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
281-
281+
"""Test that invalid attention backend names raise ValueError."""
282282
with monkeypatch.context() as m, patch(
283283
"vllm.attention.selector.current_platform", CudaPlatform()):
284284
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
285285
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
286286

287-
# Test with head size 32
288-
backend = get_attn_backend(32, torch.float16, None, 16, False)
289-
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
290-
assert backend.get_name() == EXPECTED
291-
292-
# when block size == 16, backend will fall back to XFORMERS
293-
# this behavior is not yet supported on V1.
294-
if use_v1:
295-
# TODO: support fallback on V1!
296-
# https://github.com/vllm-project/vllm/issues/14524
297-
pass
298-
else:
299-
backend = get_attn_backend(16, torch.float16, None, 16, False)
300-
assert backend.get_name() == "XFORMERS"
287+
# Should raise ValueError for invalid backend
288+
with pytest.raises(ValueError) as exc_info:
289+
get_attn_backend(32, torch.float16, None, 16, False)
290+
assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)

vllm/attention/selector.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,10 @@ def _cached_get_attn_backend(
193193
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
194194
if backend_by_env_var is not None:
195195
selected_backend = backend_name_to_enum(backend_by_env_var)
196+
if selected_backend is None:
197+
raise ValueError(
198+
f"Invalid attention backend: '{backend_by_env_var}'. "
199+
f"Valid backends are: {list(_Backend.__members__.keys())}")
196200

197201
# get device-specific attn_backend
198202
attention_cls = current_platform.get_attn_backend_cls(

0 commit comments

Comments (0)