Skip to content

Commit 44f6b0d

Browse files
authored
Consolidate example script tests into single parametrized test (#1801)
SUMMARY: Combines the logic from `tests/examples/test_*.py` into `tests/examples/test_example_scripts.py` which has a parametrized test fn that runs them instead. Notes: - Removed README parsing: some tests previously parsed a README file and then ran a code block from that file. However, these code blocks were always just `python3 some_script.py`, so I replaced these by just calling the script instead. - A handful of tests had additional extra handling. I maintained the behavior of these by adding additional options to the `TestCase` namedtuple / making the test function more flexible. (i.e. pre-processing, flags, and post-processing verification) TEST PLAN: All of the example tests essentially boil down to "run the example script and check if it crashes". A few also have additional checks or preprocessing steps but that is the main idea. To run the scripts, the tests all (both before and after changes) call `run_cli_command` in `tests/testing_utils.py` which does the actual `python ...` call. Therefore, to test this change I replaced that function with a dummy function that just prints the command and returns a success. Then I verified that the printed commands matched before and after changes (excluding reorderings of the calls). ```python def run_cli_command(cmd: List[str], cwd: Optional[Union[str, Path]] = None): print() print(" ".join(cmd), "in", str(cwd).split("examples/")[-1]) class DummyReturn: returncode = 0 return DummyReturn # return run(cmd, stdout=PIPE, stderr=STDOUT, check=False, encoding="utf-8", cwd=cwd) ``` In addition, to verify that the special handling works correctly I ran the full examples test suite and confirmed that the tests with special handling still passed. --------- Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 0711630 commit 44f6b0d

18 files changed

+271
-731
lines changed

tests/e2e/vLLM/test_vllm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
from llmcompressor.core import active_session
1414
from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
15-
from tests.examples.utils import requires_gpu_count
1615
from tests.test_timer.timer_utils import get_singleton_manager, log_time
16+
from tests.testing_utils import requires_gpu
1717

1818
HF_MODEL_HUB_NAME = "nm-testing"
1919

@@ -35,7 +35,7 @@
3535

3636
# Will run each test case in its own process through run_tests.sh
3737
# emulating vLLM CI testing
38-
@requires_gpu_count(1)
38+
@requires_gpu(1)
3939
@pytest.mark.parametrize(
4040
"test_data_file", [pytest.param(TEST_DATA_FILE, id=TEST_DATA_FILE)]
4141
)

tests/examples/test_awq.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

tests/examples/test_big_models_with_sequential_onloading.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tests/examples/test_compressed_inference.py

Lines changed: 0 additions & 31 deletions
This file was deleted.
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import shlex
5+
import shutil
6+
import sys
7+
from pathlib import Path
8+
from typing import Callable, List, NamedTuple
9+
10+
import pytest
11+
from transformers import AutoConfig
12+
13+
from tests.testing_utils import requires_gpu, run_cli_command
14+
15+
16+
def replace_2of4_w4a16_recipe(content: str) -> str:
    """Swap the channel-wise 2:4 w4a16 recipe filename for the group-128 variant.

    Used as a ``preprocess_fn`` to rewrite an example script's source text
    before it is executed.
    """
    old_name = "2of4_w4a16_recipe.yaml"
    new_name = "2of4_w4a16_group-128_recipe.yaml"
    return content.replace(old_name, new_name)
18+
19+
20+
def verify_2of4_w4a16_output(tmp_path: Path, example_dir: str):
    """Check the three-stage 2:4 sparse w4a16 pipeline output.

    Each stage directory under ``output_llama7b_2of4_w4a16_channel`` must
    contain a recipe and a loadable config declaring the expected
    compression format.
    """
    run_root = tmp_path / example_dir / Path("output_llama7b_2of4_w4a16_channel")

    # (stage name, {subdirectory, expected compression format}) in pipeline order
    expected_stages = (
        ("quantization", {"path": Path("quantization_stage"), "format": "marlin-24"}),
        ("sparsity", {"path": Path("sparsity_stage"), "format": "sparse-24-bitmask"}),
        ("finetuning", {"path": Path("finetuning_stage"), "format": "sparse-24-bitmask"}),
    )

    for stage, stage_info in expected_stages:
        stage_path = run_root / stage_info["path"]
        recipe_path = stage_path / "recipe.yaml"
        config_path = stage_path / "config.json"

        assert recipe_path.exists(), f"Missing recipe file in {stage}: {recipe_path}"
        assert config_path.exists(), f"Missing config file in {stage}: {config_path}"

        config = AutoConfig.from_pretrained(stage_path)
        assert config is not None, f"Failed to load config in {stage}"

        quant_config = getattr(config, "quantization_config", {})
        # quantization stage stores the format at the top level; the sparsity
        # and finetuning stages nest it under sparsity_config
        if stage == "quantization":
            actual_format = quant_config.get("format")
        else:
            actual_format = quant_config.get("sparsity_config", {}).get("format")

        assert actual_format, f"Missing expected format field in {stage} config"
        assert actual_format == stage_info["format"], (
            f"Unexpected format in {stage}: got '{actual_format}', "
            f"expected '{stage_info['format']}'"
        )
60+
61+
62+
def verify_w4a4_fp4_output(tmp_path: Path, example_dir: str):
    """Confirm a single ``*-NVFP4`` model directory was produced under
    ``tmp_path`` and that its config declares the nvfp4 packed format."""
    # exactly one generated output directory is expected
    nvfp4_dirs = [d for d in tmp_path.rglob("*-NVFP4") if d.is_dir()]
    assert (
        len(nvfp4_dirs)
    ) == 1, f"did not find exactly one generated folder: {nvfp4_dirs}"

    # the saved config must record the nvfp4 compression format
    saved_config = json.loads((nvfp4_dirs[0] / "config.json").read_text())
    saved_format = saved_config["quantization_config"]["format"]
    assert saved_format == "nvfp4-pack-quantized"
73+
74+
75+
class TestCase(NamedTuple):
    """Parameters for running one example script in the parametrized test.

    Only ``path`` is required; the remaining fields enable the special
    handling a few examples need (extra CLI flags, source preprocessing,
    post-run output verification).
    """

    # script path relative to the examples/ directory, e.g. "awq/llama_example.py"
    path: str
    # extra CLI flags appended to the script invocation
    # (fix: tuple[str] means a 1-element tuple; the variadic form is
    # tuple[str, ...] — callers may also pass a list)
    flags: tuple[str, ...] = ()
    # optional transform applied to the script's source text before running
    preprocess_fn: None | Callable[[str], str] = None
    # optional post-run check, called as verify_fn(tmp_path, example_dir)
    verify_fn: Callable[[Path, str], None] | None = None

    def __repr__(self):
        # compact repr used for pytest test ids: path plus only the
        # non-default fields (callables shown by name)
        values = [f"'{self.path}'"]
        for attr_name in ["flags", "preprocess_fn", "verify_fn"]:
            attr = getattr(self, attr_name)
            if attr:
                if callable(attr):
                    attr_repr = attr.__name__
                else:
                    attr_repr = repr(attr)

                values.append(f"{attr_name}={attr_repr}")

        return f"{self.__class__.__name__}({', '.join(values)})"
95+
96+
97+
@pytest.mark.example
@requires_gpu(1)
@pytest.mark.parametrize(
    "test_case",
    [
        # plain strings are wrapped into TestCase(path) inside the test;
        # TestCase entries carry extra flags / preprocessing / verification
        "awq/llama_example.py",
        "awq/qwen3_moe_example.py",
        "big_models_with_sequential_onloading/llama3.3_70b.py",
        "compressed_inference/fp8_compressed_inference.py",
        "quantization_kv_cache/llama3_fp8_kv_example.py",
        "quantization_w4a16/llama3_example.py",
        "quantization_w8a8_fp8/gemma2_example.py",
        "quantization_w8a8_fp8/fp8_block_example.py",
        "quantization_w8a8_fp8/llama3_example.py",
        "quantization_w8a8_int8/llama3_example.py",
        "quantization_w8a8_int8/gemma2_example.py",
        "quantizing_moe/mixtral_example.py",
        # same script again, run as a dedicated multi-GPU case
        pytest.param(
            "quantizing_moe/mixtral_example.py",
            marks=(requires_gpu(2), pytest.mark.multi_gpu),
        ),
        "quantizing_moe/qwen_example.py",
        # sparse_2of4
        "sparse_2of4_quantization_fp8/llama3_8b_2of4.py",
        TestCase(
            "sparse_2of4_quantization_fp8/llama3_8b_2of4.py",
            flags=["--fp8"],
        ),
        TestCase(
            "quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py",
            preprocess_fn=replace_2of4_w4a16_recipe,
        ),
        TestCase(
            "quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py",
            verify_fn=verify_2of4_w4a16_output,
        ),
        # w4a4_fp4
        TestCase(
            "quantization_w4a4_fp4/llama3_example.py", verify_fn=verify_w4a4_fp4_output
        ),
        TestCase(
            "quantization_w4a4_fp4/llama4_example.py", verify_fn=verify_w4a4_fp4_output
        ),
        TestCase(
            "quantization_w4a4_fp4/qwen_30b_a3b.py", verify_fn=verify_w4a4_fp4_output
        ),
        # skips
        pytest.param(
            "quantizing_moe/deepseek_r1_example.py",
            marks=pytest.mark.skip("exceptionally long run time"),
        ),
        pytest.param(
            "trl_mixin/ex_trl_constant.py",
            marks=pytest.mark.skip("disabled until further updates"),
        ),
        pytest.param(
            "trl_mixin/ex_trl_distillation.py",
            marks=(
                pytest.mark.skip("disabled until further updates"),
                pytest.mark.multi_gpu,
            ),
        ),
    ],
    # TestCase.__repr__ / str repr gives readable, unique test ids
    ids=repr,
)
def test_example_scripts(test_case: str | TestCase, tmp_path: Path):
    """Run one example script end-to-end and assert it exits successfully.

    The example's directory is copied into ``tmp_path`` and the script is
    executed there with the current interpreter, so any files it writes stay
    out of the repository. Optional per-case hooks: ``preprocess_fn`` rewrites
    the script source before running, ``flags`` appends CLI arguments, and
    ``verify_fn`` checks the produced output afterwards.
    """
    # normalize: bare path strings become a default TestCase
    if isinstance(test_case, str):
        test_case = TestCase(test_case)

    example_subdir, filename = test_case.path.rsplit("/", 1)
    example_dir = f"examples/{example_subdir}"

    # sys.executable keeps the subprocess on the same interpreter/venv
    command = [sys.executable, filename]
    if test_case.flags:
        command.extend(test_case.flags)

    # run from a copy of the example dir so generated artifacts land in tmp_path
    script_working_dir = tmp_path / example_dir
    shutil.copytree(Path.cwd() / example_dir, script_working_dir)

    if test_case.preprocess_fn:
        # rewrite the copied script's source text before execution
        path = script_working_dir / filename
        content = path.read_text(encoding="utf-8")
        content = test_case.preprocess_fn(content)
        path.write_text(content, encoding="utf-8")

    result = run_cli_command(command, cwd=script_working_dir)

    assert result.returncode == 0, (
        f"command failed with exit code {result.returncode}:\n"
        f"Command:\n{shlex.join(command)}\nOutput:\n{result.stdout}"
    )

    if test_case.verify_fn:
        test_case.verify_fn(tmp_path, example_dir)

0 commit comments

Comments
 (0)