
Commit 12ed388

Add FP8 AllGather correctness test
Adds test_fp8_allgather_pass_correctness to validate that the FP8 AllGather optimization produces outputs identical to the baseline. The test compares:

- WITH optimization: enable_fp8_allgather_opt=True
- WITHOUT optimization: enable_fp8_allgather_opt=False (baseline)

It uses the RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 model with TP=2 so that the optimization is actually triggered, and checks that the results are numerically equivalent.

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 159745c commit 12ed388
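
For context, the optimization under test is toggled purely through the pass_config section of the compilation config shown in the diff below; everything else (level-3 compilation, compile sizes, async TP) is held constant between the two runs, so any output divergence is attributable to the FP8 AllGather pass. The following is a minimal offline sketch of the optimized setting, not part of this commit; it assumes at least two GPUs and that the Python LLM entry point accepts the same dict-shaped compilation_config the test forwards as JSON on the CLI (the enable_fp8_allgather_opt key exists only with this change applied):

    # Sketch only, mirroring fp8_allgather_compilation_config from the new test.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
        dtype="bfloat16",
        max_model_len=2048,
        tensor_parallel_size=2,
        compilation_config={
            "level": 3,
            "compile_sizes": [2, 4, 8],
            "splitting_ops": [],
            "pass_config": {
                "enable_async_tp": True,
                "enable_fp8_allgather_opt": True,  # flag added by this commit
            },
        },
    )
    print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))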

1 file changed (+92 lines, -0 lines)

tests/compile/test_async_tp.py

Lines changed: 92 additions & 0 deletions
@@ -413,3 +413,95 @@ def test_async_tp_pass_correctness(
    compare_two_settings(
        model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
    )


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", [
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("fp8_allgather_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False])
def test_fp8_allgather_pass_correctness(
    model_id: str,
    tp_size: int,
    fp8_allgather_enabled: bool,
    distributed_backend: str,
    eager_mode: bool,
    num_gpus_available: int,
):
    """Test FP8 AllGather optimization correctness on FP8 models"""
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")
    model_info.check_available_online(on_fail="skip")

    pp_size = 1
    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

    common_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if eager_mode:
        common_args.append("--enforce-eager")

    # Configuration WITH FP8 AllGather optimization
    fp8_allgather_compilation_config = {
        'level': 3,
        'compile_sizes': [2, 4, 8],
        'splitting_ops': [],
        'pass_config': {
            'enable_async_tp': True,
            'enable_fp8_allgather_opt': fp8_allgather_enabled
        },
    }

    # Configuration WITHOUT FP8 AllGather optimization (baseline)
    baseline_compilation_config = {
        'level': 3,
        'compile_sizes': [2, 4, 8],
        'splitting_ops': [],
        'pass_config': {
            'enable_async_tp': True,
            'enable_fp8_allgather_opt': False
        },
    }

    fp8_allgather_env = baseline_env = {
        "VLLM_USE_V1": "1",
    }

    fp8_allgather_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(fp8_allgather_compilation_config),
    ]

    baseline_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(baseline_compilation_config),
    ]

    compare_two_settings(
        model_id,
        fp8_allgather_args,
        baseline_args,
        fp8_allgather_env,
        baseline_env,
        method="generate",
    )
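
To reproduce the comparison locally, the new test can be selected by name. The snippet below is a sketch, assuming it is run from the repository root on a machine with at least two visible GPUs (otherwise the test skips itself):

    # Sketch: run only the new FP8 AllGather correctness test.
    import sys

    import pytest

    sys.exit(pytest.main([
        "tests/compile/test_async_tp.py",
        "-k", "test_fp8_allgather_pass_correctness",
        "-v",
    ]))

The same selection also works directly from the command line via pytest's -k option.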
