Commit 9cdde47

[BugFix] Fix fusion test and add them to CI (#16287)
Signed-off-by: luka <[email protected]>
1 parent b1eb4ca commit 9cdde47

File tree

3 files changed: +74 -49 lines changed

.buildkite/test-pipeline.yaml
tests/compile/test_full_graph.py
tests/compile/test_fusion.py

.buildkite/test-pipeline.yaml

Lines changed: 8 additions & 1 deletion
@@ -292,6 +292,14 @@ steps:
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
   parallelism: 4
 
+- label: PyTorch Compilation Unit Tests
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_pass_manager.py
+  - pytest -v -s compile/test_fusion.py
+
 - label: PyTorch Fullgraph Smoke Test # 9min
   source_file_dependencies:
   - vllm/
@@ -301,7 +309,6 @@ steps:
   # these tests need to be separated, cannot combine
   - pytest -v -s compile/piecewise/test_simple.py
   - pytest -v -s compile/piecewise/test_toy_llama.py
-  - pytest -v -s compile/test_pass_manager.py
 
 - label: PyTorch Fullgraph Test # 18min
   source_file_dependencies:

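The new "PyTorch Compilation Unit Tests" step simply runs two compile test files with pytest. As a rough local equivalent (a sketch, not part of the commit; it assumes the working directory is the repository's tests/ directory so the relative paths match the CI commands):

import sys

import pytest

# Run the same two files as the new CI step: pass-manager and fusion tests.
sys.exit(pytest.main(["-v", "-s", "compile/test_pass_manager.py",
                      "compile/test_fusion.py"]))
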
tests/compile/test_full_graph.py

Lines changed: 58 additions & 45 deletions
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 import pytest
 import torch
@@ -15,7 +15,7 @@
 from ..utils import create_new_process_for_each_test
 
 
-def models_list(all: bool):
+def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -32,47 +32,50 @@ def models_list(all: bool):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
 
-    if not all:
-        return TEST_MODELS
+    if all:
+        if is_quant_method_supported("aqlm"):
+            TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
+                "quantization": "aqlm"
+            }))
+
+        # TODO: figure out why this fails.
+        if False and is_quant_method_supported("gguf"): # noqa: SIM223
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
+                "quantization": "gguf"
+            }))
+
+        if is_quant_method_supported("gptq"):
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
+                "quantization": "gptq"
+            }))
+
+        if is_quant_method_supported("gptq_marlin"):
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
+                "quantization": "gptq_marlin"
+            }))
 
-    if is_quant_method_supported("aqlm"):
-        TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
-            "quantization": "aqlm"
-        }))
-
-    # TODO: figure out why this fails.
-    if False and is_quant_method_supported("gguf"): # noqa: SIM223
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
-            "quantization": "gguf"
-        }))
-
-    if is_quant_method_supported("gptq"):
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
-            "quantization": "gptq"
-        }))
-
-    if is_quant_method_supported("gptq_marlin"):
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
-            "quantization": "gptq_marlin"
-        }))
-
-    if is_quant_method_supported("gptq_marlin_24"):
-        TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
-            "quantization": "gptq_marlin_24"
-        }))
-
-    if is_quant_method_supported("marlin"):
-        TEST_MODELS.append(
-            ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
-                "quantization": "marlin"
+        if is_quant_method_supported("gptq_marlin_24"):
+            TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
+                "quantization": "gptq_marlin_24"
             }))
 
-    if not current_platform.is_rocm() and is_quant_method_supported("awq"):
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
-            "quantization": "AWQ"
-        }))
+        if is_quant_method_supported("marlin"):
+            TEST_MODELS.append(
+                ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
+                    "quantization": "marlin"
+                }))
 
-    return TEST_MODELS
+        if not current_platform.is_rocm() and is_quant_method_supported("awq"):
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
+                "quantization": "AWQ"
+            }))
+
+    if keywords is None:
+        return TEST_MODELS
+
+    # filter by keywords
+    pred = lambda model: any(keyword in model[0] for keyword in keywords)
+    return list(filter(pred, TEST_MODELS))
 
 
 @pytest.mark.parametrize(
@@ -96,20 +99,30 @@ def test_full_graph(
     run_model(optimization_level, model, model_kwargs)
 
 
+PassConfig = CompilationConfig.PassConfig
+
+
 # TODO(luka) add other supported compilation config scenarios here
 @pytest.mark.parametrize(
-    "compilation_config",
-    # additional compile sizes
+    "compilation_config, model_info",
     [
-        CompilationConfig(level=CompilationLevel.PIECEWISE,
-                          compile_sizes=[1, 2])
+        # additional compile sizes, only some of the models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           compile_sizes=[1, 2]), model)
+        for model in models_list(all=False)
+    ] + [
+        # RMSNorm + quant fusion, only 8-bit quant models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           custom_ops=["+rms_norm"],
+                           pass_config=PassConfig(enable_fusion=True,
+                                                  enable_noop=True)), model)
+        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
     ])
 # only test some of the models
-@pytest.mark.parametrize("model_info", models_list(all=False))
 @create_new_process_for_each_test()
 def test_custom_compile_config(
-    model_info: tuple[str, dict[str, Any]],
     compilation_config: CompilationConfig,
+    model_info: tuple[str, dict[str, Any]],
 ):
     model, model_kwargs = model_info
     print(f"MODEL={model}")

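The core of this refactor is the new keyword filter in models_list() and the parametrization that pairs each CompilationConfig variant with its own model subset: the compile-sizes config runs against models_list(all=False), while the RMSNorm + quant fusion config runs only against 8-bit quantized models. A minimal standalone sketch of the filtering behavior (the helper name and model names below are illustrative, not taken from the commit):

from typing import Any, Optional


def filter_by_keywords(
    models: list[tuple[str, dict[str, Any]]],
    keywords: Optional[list[str]] = None,
) -> list[tuple[str, dict[str, Any]]]:
    # Same predicate as the new models_list(): keep a model when any keyword
    # is a substring of its name (the first tuple element).
    if keywords is None:
        return models
    pred = lambda model: any(keyword in model[0] for keyword in keywords)
    return list(filter(pred, models))


# Illustrative entries only; the real list is TEST_MODELS in the test module.
example = [
    ("facebook/opt-125m", {}),
    ("some-org/llama-FP8-dynamic", {}),
]
# Keeps only the FP8 entry, mirroring keywords=["FP8-dynamic", "quantized.w8a8"].
assert filter_by_keywords(example, ["FP8-dynamic", "quantized.w8a8"]) == example[1:]
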
tests/compile/test_fusion.py

Lines changed: 8 additions & 3 deletions
@@ -44,12 +44,17 @@ def forward(self, x):
         resid = torch.sqrt(x)
         y = self.norm[0](x)
 
-        x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
+        x2 = self.fp8_linear.apply(y,
+                                   self.w[0],
+                                   self.wscale[0],
+                                   input_scale=self.scale[0])
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
-                                   self.scale[1])
+        x3 = self.fp8_linear.apply(y2,
+                                   self.w[1],
+                                   self.wscale[1],
+                                   input_scale=self.scale[1])
         y3, resid = self.norm[2](x3, resid) # use resid here
         return y3
 

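The fix itself is small: the test now passes the activation scale as input_scale= instead of positionally, so the call no longer depends on the exact positional layout of fp8_linear.apply. A generic illustration of why the keyword form is the safer call style (the signature below is invented for the example and is not vLLM's actual Fp8LinearOp.apply):

# Hypothetical helper whose signature gained an out_dtype parameter over time.
def apply(x, weight, weight_scale, out_dtype=None, input_scale=None):
    return out_dtype, input_scale


# A positional call written against the older four-argument form now binds the
# scale to out_dtype instead of input_scale.
assert apply("x", "w", "ws", 0.5) == (0.5, None)

# The keyword form stays correct even if parameters are inserted before it.
assert apply("x", "w", "ws", input_scale=0.5) == (None, 0.5)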