3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -400,7 +400,6 @@ steps:
- pytest -v -s compile/test_functionalization.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
@@ -920,6 +919,7 @@ steps:
- vllm/worker/worker_base.py
- vllm/v1/engine/
- vllm/v1/worker/
- tests/compile/test_async_tp.py
- tests/compile/test_basic_correctness.py
- tests/compile/test_wrapper.py
- tests/distributed/
@@ -933,6 +933,7 @@
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_async_tp.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
92 changes: 92 additions & 0 deletions tests/compile/test_async_tp.py
@@ -413,3 +413,95 @@ def test_async_tp_pass_correctness(
compare_two_settings(
model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
)


@create_new_process_for_each_test()
@pytest.mark.parametrize(
"model_id",
[
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
],
)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("fp8_allgather_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False])
def test_fp8_allgather_pass_correctness(
model_id: str,
tp_size: int,
fp8_allgather_enabled: bool,
distributed_backend: str,
eager_mode: bool,
num_gpus_available: int,
):
"""Test FP8 AllGather optimization correctness on FP8 models"""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
model_info.check_available_online(on_fail="skip")

pp_size = 1
if num_gpus_available < tp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

common_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if eager_mode:
common_args.append("--enforce-eager")

# Configuration WITH FP8 AllGather optimization
fp8_allgather_compilation_config = {
"level": 3,
"compile_sizes": [2, 4, 8],
"splitting_ops": [],
"pass_config": {
"enable_async_tp": True,
"enable_fp8_allgather_opt": fp8_allgather_enabled,
},
}

# Configuration WITHOUT FP8 AllGather optimization (baseline)
baseline_compilation_config = {
"level": 3,
"compile_sizes": [2, 4, 8],
"splitting_ops": [],
"pass_config": {"enable_async_tp": False, "enable_fp8_allgather_opt": False},
}

fp8_allgather_env = baseline_env = {
"VLLM_USE_V1": "1",
}

fp8_allgather_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
json.dumps(fp8_allgather_compilation_config),
]

baseline_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
json.dumps(baseline_compilation_config),
]

compare_two_settings(
model_id,
fp8_allgather_args,
baseline_args,
fp8_allgather_env,
baseline_env,
method="generate",
)
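
Note (illustrative sketch, not part of this diff): the test above drives the new pass through the CLI via --compilation_config JSON. For intuition, the same pass_config could presumably be set on an offline vllm.LLM instance as below. The assumption that the LLM constructor accepts a compilation_config dict (as recent vLLM versions do), and the enable_fp8_allgather_opt key itself, are taken from this PR rather than verified against a specific release.

    from vllm import LLM, SamplingParams

    # Mirrors fp8_allgather_compilation_config from the test above.
    llm = LLM(
        model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
        tensor_parallel_size=2,
        dtype="bfloat16",
        max_model_len=2048,
        compilation_config={
            "level": 3,
            "compile_sizes": [2, 4, 8],
            "splitting_ops": [],
            "pass_config": {
                "enable_async_tp": True,
                "enable_fp8_allgather_opt": True,  # new flag exercised by this PR
            },
        },
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
    print(outputs[0].outputs[0].text)
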
94 changes: 94 additions & 0 deletions tests/utils.py
@@ -397,6 +397,30 @@ def _test_completion(
}
)

# test logprobs (for numerical correctness)
completion = client.completions.create(
model=model,
prompt=prompt,
max_tokens=5,
temperature=0.0,
logprobs=5, # Request top 5 logprobs per token
)

# Extract logprobs for each token
logprobs_per_token = []
if completion.choices[0].logprobs and completion.choices[0].logprobs.top_logprobs:
for token_logprobs in completion.choices[0].logprobs.top_logprobs:
if token_logprobs:
logprobs_per_token.append(dict(token_logprobs))

results.append(
{
"test": "logprobs_check",
"text": completion.choices[0].text,
"logprobs": logprobs_per_token,
}
)

return results
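
Note (illustrative, not part of this diff): the "logprobs_check" entry collected above is what compare_all_settings consumes later in this file. A sketch of its shape with made-up values, assuming the OpenAI-style completions response where choices[0].logprobs.top_logprobs is a list of per-token dicts:

    # Made-up values showing the shape of one "logprobs_check" entry appended
    # to results above: one {token: logprob} dict per generated token.
    example_result = {
        "test": "logprobs_check",
        "text": " Paris, the capital",
        "logprobs": [
            {" Paris": -0.01, " paris": -5.2, " France": -6.9, " P": -8.3, " the": -9.0},
            {",": -0.03, " is": -3.9, " was": -5.1, " has": -6.2, ".": -7.4},
            # ... one dict for each of the max_tokens=5 completion tokens
        ],
    }
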


@@ -686,6 +710,76 @@ def compare_all_settings(
)
del ref_result["embedding"]
del compare_result["embedding"]

# Compare logprobs with tolerance for numerical precision
if (
"logprobs" in ref_result
and ref_result.get("test") == "logprobs_check"
):
ref_logprobs = ref_result["logprobs"]
compare_logprobs = compare_result["logprobs"]

assert len(ref_logprobs) == len(compare_logprobs), (
f"Logprobs length mismatch: "
f"{len(ref_logprobs)} vs "
f"{len(compare_logprobs)}"
)

# Track statistics for logging
max_diff = 0.0
min_overlap = 1.0
total_diff = 0.0
total_comparisons = 0

# Compare logprobs for each token position
for token_idx, (
ref_token_logprobs,
compare_token_logprobs,
) in enumerate(zip(ref_logprobs, compare_logprobs)):
# Check that the same tokens appear in top-k
ref_tokens = set(ref_token_logprobs.keys())
compare_tokens = set(compare_token_logprobs.keys())

# Allow for minor differences in top-k tokens
# due to numerical precision. Require at least
# 95% overlap (at most 1 token differs in top-5)
overlap = len(ref_tokens & compare_tokens) / len(ref_tokens)
min_overlap = min(min_overlap, overlap)

assert overlap >= 0.95, (
[Review comment — Contributor, severity: high]
The assertion for top-k token overlap seems overly strict and contradicts the comment on lines 712-714. The comment states 'Require at least 95% overlap (at most 1 token differs in top-5)', but for a top-5 list, allowing one token to differ corresponds to an 80% overlap (4/5). The current threshold of 0.95 requires a perfect match (5/5 common tokens), which might lead to flaky tests due to minor numerical precision differences inherent with FP8. I suggest lowering the threshold to 0.8 to align with the intention of allowing one differing token.

Suggested change:
-                        assert overlap >= 0.95, (
+                        assert overlap >= 0.8, (

f"Token {token_idx}: Top-k tokens differ "
f"too much. Only {overlap * 100:.1f}% overlap.\n"
f"Ref tokens: {ref_tokens}\n"
f"Compare tokens: {compare_tokens}"
)

# Compare logprob values for common tokens
for token in ref_tokens & compare_tokens:
ref_logprob = ref_token_logprobs[token]
compare_logprob = compare_token_logprobs[token]
diff = abs(ref_logprob - compare_logprob)

max_diff = max(max_diff, diff)
total_diff += diff
total_comparisons += 1

# Allow up to 0.02 difference in logprobs
# (~2% probability difference). This is
# consistent with FP8 precision and the
# pattern equivalence test
assert diff <= 0.02, (
f"Token {token_idx}, token '{token}': "
f"Logprob difference too large: "
f"{diff:.4f}\n"
f"Ref: {ref_logprob:.4f}, "
f"Compare: {compare_logprob:.4f}"
)

# Remove logprobs from comparison dict so they
# don't fail exact equality check
del ref_result["logprobs"]
del compare_result["logprobs"]

assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"
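
Note (illustrative, not part of this diff): a small worked example of the arithmetic in the review comment and the tolerance check above — with top-5 logprobs, one differing token is 4/5 = 0.8 overlap, and the value check allows up to 0.02 absolute difference in log space:

    # With top-5 logprobs, one differing token gives 4/5 = 0.8 overlap, so the
    # 0.95 threshold above effectively requires a perfect top-5 match.
    ref_tokens = {" the", " a", " an", " this", " that"}
    compare_tokens = {" the", " a", " an", " this", " its"}   # one token differs
    overlap = len(ref_tokens & compare_tokens) / len(ref_tokens)
    print(overlap)  # 0.8 -> would pass the reviewer's 0.8 threshold, fail 0.95

    # The per-token value check allows |ref - compare| <= 0.02 in log space,
    # e.g. logprobs of -0.693 vs -0.706 differ by 0.013 and would pass.
    diff = abs(-0.693 - (-0.706))
    print(round(diff, 3))  # 0.013
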
28 changes: 22 additions & 6 deletions vllm/compilation/collective_fusion.py
@@ -169,15 +169,23 @@ def replacement(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim=0,
out_dtype=self.dtype,
group_name=self.tp.device_group.group_name,
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)

return gemm_rs
@@ -296,15 +304,23 @@ def replacement(
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim=0,
out_dtype=self.dtype,
group_name=self.tp.device_group.group_name,
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)

return gemm_rs
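
Note (illustrative sketch, not part of this diff): a rough single-process emulation of the semantics the new positional call encodes — each rank's scaled matmul output has shape [*input.shape[:-1], mat2.shape[1]], the partials are combined with the "avg" reduce op, and the result is scattered along dim 0 so each rank keeps a slice. This is only for intuition about output_shape and scatter_dim; it is not the symm_mem kernel, and the sizes and values are placeholders.

    import torch

    # Placeholder sizes; real runs use the TP world size and model dims.
    world_size, seq, k, n = 2, 8, 16, 32
    inputs = [torch.randn(seq, k) for _ in range(world_size)]  # one input per rank
    mat2 = torch.randn(k, n)

    # output_shape = [*input.shape[:-1], mat2.shape[1]]  ->  [seq, n] before scatter
    partials = [inp @ mat2 for inp in inputs]            # each rank's local matmul
    reduced = torch.stack(partials).mean(dim=0)          # the "avg" reduce op

    # reduce_scatter along scatter_dim=0: each rank keeps seq // world_size rows
    shards = reduced.chunk(world_size, dim=0)
    assert shards[0].shape == (seq // world_size, n)     # rank 0's output slice
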