@@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
278
278
279
279
@pytest.mark.parametrize("use_v1", [True, False])
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
    """Test that an invalid attention backend name raises ValueError.

    Parametrized over both engine paths (``VLLM_USE_V1`` set to "1" or
    "0") to verify that setting the backend-selection env var
    (``STR_BACKEND_ENV_VAR``) to an invalid value makes
    ``get_attn_backend`` fail loudly rather than silently fall back to
    another backend.
    """
    # Pin the platform to CUDA so backend selection takes the CUDA path
    # regardless of the machine running the test.
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Should raise ValueError for invalid backend
        with pytest.raises(ValueError) as exc_info:
            get_attn_backend(32, torch.float16, None, 16, False)
        assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)
0 commit comments