
Commit 1f400c5

[CI] Add batch invariant test to ci (#27842)
Signed-off-by: yewentao256 <[email protected]>
1 parent 711241c · commit 1f400c5

3 files changed: +16 −1

.buildkite/test-pipeline.yaml

Lines changed: 12 additions & 0 deletions
@@ -346,6 +346,18 @@ steps:
   commands:
   - pytest -v -s v1/attention
 
+- label: Batch Invariance Tests (H100) # 10min
+  timeout_in_minutes: 25
+  gpu: h100
+  source_file_dependencies:
+  - vllm/
+  - tests/v1/determinism/
+  commands:
+  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  - pip install pytest-timeout pytest-forked
+  - pytest -v -s v1/determinism/test_batch_invariance.py
+  - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py
+
 - label: V1 Test attention (B200) # 10min
   timeout_in_minutes: 30
   gpu: b200
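
The new step exports VLLM_WORKER_MULTIPROC_METHOD=spawn before running the tests. A minimal, self-contained sketch of the underlying constraint (illustrative only; vLLM consumes this variable internally): a fork()ed child would inherit the parent's CUDA state and can fail with "Cannot re-initialize CUDA in forked subprocess", whereas a spawned child initializes CUDA from scratch.

import multiprocessing as mp
import os


def worker() -> None:
    # torch is imported inside the child so CUDA initialization happens there,
    # not in inherited parent state.
    import torch

    print("CUDA available in worker:", torch.cuda.is_available())


if __name__ == "__main__":
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # what the CI step exports
    ctx = mp.get_context("spawn")  # mirrors the env var's intent for this demo
    p = ctx.Process(target=worker)
    p.start()
    p.join()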

tests/v1/determinism/test_batch_invariance.py

Lines changed: 2 additions & 0 deletions
@@ -190,6 +190,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
+        gpu_memory_utilization=0.9,
     )
 
     # Use more realistic prompts for better token generation
@@ -444,6 +445,7 @@ def test_logprobs_without_batch_invariance_should_fail(
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     # CRITICAL: Disable batch invariance for this test
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
     monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
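
The second hunk disables batch invariance through two channels at once. A plausible reading (the internals of vLLM's batch_invariant module are assumed here, not verified): patching the module attribute covers the current process, where the flag has already been read, while the env var covers any freshly spawned worker that re-reads the environment at import time. A stand-alone sketch of the pattern:

import os
import types

# "flags" is a hypothetical stand-in for vLLM's batch_invariant module, which
# (by assumption) evaluates VLLM_BATCH_INVARIANT once at import time.
flags = types.SimpleNamespace(
    VLLM_BATCH_INVARIANT=os.getenv("VLLM_BATCH_INVARIANT", "1") == "1"
)


def disable_batch_invariance(monkeypatch) -> None:
    # Setting only the env var would not update flags.VLLM_BATCH_INVARIANT,
    # since that value was computed at import; setting only the attribute
    # would not reach spawned workers, which see only the environment.
    # Patching both covers both cases.
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
    monkeypatch.setattr(flags, "VLLM_BATCH_INVARIANT", False)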

tests/v1/determinism/utils.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 import pytest
 import torch
 
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 from vllm.platforms import current_platform
 
 skip_unsupported = pytest.mark.skipif(
@@ -18,7 +19,7 @@
     "FLASHINFER",
 ]
 
-if current_platform.is_cuda() and current_platform.is_device_capability(90):
+if flash_attn_supports_mla():
     BACKENDS.append("FLASH_ATTN_MLA")
 
 DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
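
The utils.py change swaps a hard-coded device check (CUDA plus compute capability 9.0) for a direct feature probe, so FLASH_ATTN_MLA is only exercised where FlashAttention actually supports MLA. A sketch of how such a gated BACKENDS list is typically consumed (the probe below is a stub; the real one lives in vllm.attention.utils.fa_utils and its exact conditions are not shown here):

import pytest


def flash_attn_supports_mla() -> bool:
    # Stub standing in for the real probe; assume an unsupported
    # platform for this sketch.
    return False


BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
if flash_attn_supports_mla():
    BACKENDS.append("FLASH_ATTN_MLA")


@pytest.mark.parametrize("backend", BACKENDS)
def test_runs_per_backend(backend: str, monkeypatch: pytest.MonkeyPatch) -> None:
    # The determinism tests select a backend via this env var, as seen in
    # test_batch_invariance.py above.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)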
