Commit c6f2097

Add FP8 AllGather correctness test
Adds test_fp8_allgather_pass_correctness to validate that the FP8 AllGather optimization produces outputs identical to the baseline. The test compares:

- WITH optimization: enable_fp8_allgather_opt=True
- WITHOUT optimization: enable_fp8_allgather_opt=False (baseline)

It uses the RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 model with TP=2 so that the optimization is actually triggered and the results can be checked for numerical equivalence.

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 07c0dc6 commit c6f2097
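
Note: the two settings the test compares differ only in the enable_fp8_allgather_opt flag inside pass_config. A minimal sketch of that toggle, mirroring the compilation configs in the diff below (illustrative only, not additional code in this commit):

    # Sketch: both settings keep async TP enabled; only the FP8 AllGather flag differs.
    with_fp8_allgather = {"level": 3, "pass_config": {"enable_async_tp": True, "enable_fp8_allgather_opt": True}}
    baseline = {"level": 3, "pass_config": {"enable_async_tp": True, "enable_fp8_allgather_opt": False}}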

File tree

1 file changed: +90 -0 lines changed

tests/compile/test_async_tp.py

Lines changed: 90 additions & 0 deletions
@@ -378,3 +378,93 @@ def test_async_tp_pass_correctness(
                         async_tp_env,
                         tp_env,
                         method="generate")


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", [
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("fp8_allgather_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False])
def test_fp8_allgather_pass_correctness(
    model_id: str,
    tp_size: int,
    fp8_allgather_enabled: bool,
    distributed_backend: str,
    eager_mode: bool,
    num_gpus_available: int,
):
    """Test FP8 AllGather optimization correctness on FP8 models"""
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")
    model_info.check_available_online(on_fail="skip")

    pp_size = 1
    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

    common_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if eager_mode:
        common_args.append("--enforce-eager")

    # Configuration WITH FP8 AllGather optimization
    fp8_allgather_compilation_config = {
        'level': 3,
        'compile_sizes': [2, 4, 8],
        'splitting_ops': [],
        'pass_config': {
            'enable_async_tp': True,
            'enable_fp8_allgather_opt': fp8_allgather_enabled
        },
    }

    # Configuration WITHOUT FP8 AllGather optimization (baseline)
    baseline_compilation_config = {
        'level': 3,
        'compile_sizes': [2, 4, 8],
        'splitting_ops': [],
        'pass_config': {
            'enable_async_tp': True,
            'enable_fp8_allgather_opt': False
        },
    }

    fp8_allgather_env = baseline_env = {
        "VLLM_USE_V1": "1",
    }

    fp8_allgather_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(fp8_allgather_compilation_config),
    ]

    baseline_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(baseline_compilation_config),
    ]

    compare_two_settings(model_id,
                         fp8_allgather_args,
                         baseline_args,
                         fp8_allgather_env,
                         baseline_env,
                         method="generate")
