@@ -353,6 +353,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
353353 simulated paged KV cache.
354354 5. Comparing the vLLM backend's output to the ground-truth SDPA output.
355355 """
356+ from vllm .v1 .attention .backends .mla .common import QueryLenSupport
357+
356358 batch_spec = BATCH_SPECS [batch_spec_name ]
357359 is_spec_decode_test = batch_spec_name .startswith ("spec_decode" )
358360 spec_decode_backends = {_Backend .FLASH_ATTN_MLA , _Backend .FLASHMLA }
@@ -459,16 +461,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
459461 for backend_idx , backend in enumerate (BACKENDS_TO_TEST ):
460462 builder_cls , _ = try_get_attention_backend (backend )
461463 if is_spec_decode_test :
462- from vllm .v1 .attention .backends .mla .common import QueryLenSupport
463-
464464 query_len_support = getattr (
465465 builder_cls , "query_len_support" , QueryLenSupport .SINGLE_ONLY
466466 )
467467 supports_spec = query_len_support != QueryLenSupport .SINGLE_ONLY
468468 is_decode .append (supports_spec )
469469 else :
470- from vllm .v1 .attention .backends .mla .common import QueryLenSupport
471-
472470 threshold = getattr (builder_cls , "reorder_batch_threshold" , None )
473471 query_len_support = getattr (
474472 builder_cls , "query_len_support" , QueryLenSupport .SINGLE_ONLY
0 commit comments