Skip to content

Commit 0f959c2

Browse files
committed
test: add unit and integration tests for fusion-aware file grouping
- 7 unit tests for group_files_by_fused_weights covering co-located, cross-shard, three-way split, independent layers, and edge cases.
- 3 integration tests for process_file_group_microscale_scheme verifying correct key output, preserved original sharding, and size consistency.
1 parent 134ac0e commit 0f959c2

File tree

1 file changed

+238
-0
lines changed

1 file changed

+238
-0
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import pytest
2+
import torch
3+
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
4+
from safetensors.torch import save_file
5+
6+
from llmcompressor.entrypoints.model_free.helpers import group_files_by_fused_weights
7+
from llmcompressor.entrypoints.model_free.process import (
8+
process_file_group_microscale_scheme,
9+
process_file_microscale_scheme,
10+
)
11+
12+
13+
def _make_nvfp4_scheme():
    """Build an NVFP4-style scheme: 4-bit float weights in 16-wide tensor groups.

    Targets Linear layers; static, symmetric quantization with fp8 (e4m3)
    scales, matching the microscale processing path under test.
    """
    weight_args = QuantizationArgs(
        num_bits=4,
        type="float",
        strategy="tensor_group",
        group_size=16,
        symmetric=True,
        dynamic=False,
        scale_dtype=torch.float8_e4m3fn,
    )
    return QuantizationScheme(targets=["Linear"], weights=weight_args)
26+
27+
28+
def _rand_weight(*shape):
29+
return torch.randn(*shape, dtype=torch.float16)
30+
31+
32+
class TestGroupFilesByFusedWeights:
    """Unit tests for group_files_by_fused_weights.

    Shards should be grouped together exactly when a fused weight
    (q/k/v attention or gate/up MLP) is split across them, and remain
    singleton groups otherwise.
    """

    def test_no_cross_shard_fused_weights_returns_singletons(self):
        # q/k/v all live in shard 1 and gate/up/down all live in shard 2,
        # so no fusion crosses a shard boundary.
        wmap = {
            "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors",
            "model.layers.0.self_attn.k_proj.weight": "shard-00001.safetensors",
            "model.layers.0.self_attn.v_proj.weight": "shard-00001.safetensors",
            "model.layers.0.mlp.gate_proj.weight": "shard-00002.safetensors",
            "model.layers.0.mlp.up_proj.weight": "shard-00002.safetensors",
            "model.layers.0.mlp.down_proj.weight": "shard-00002.safetensors",
        }
        result = group_files_by_fused_weights(wmap)
        assert len(result) == 2
        for group in result:
            assert len(group) == 1

    def test_cross_shard_qkv_grouped_together(self):
        # q_proj sits apart from k/v, forcing the two shards into one group.
        wmap = {
            "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors",
            "model.layers.0.self_attn.k_proj.weight": "shard-00002.safetensors",
            "model.layers.0.self_attn.v_proj.weight": "shard-00002.safetensors",
            "model.layers.0.mlp.down_proj.weight": "shard-00001.safetensors",
        }
        result = group_files_by_fused_weights(wmap)
        expected = [
            "shard-00001.safetensors",
            "shard-00002.safetensors",
        ]
        assert len(result) == 1
        assert sorted(result[0]) == expected

    def test_cross_shard_gate_up_grouped_together(self):
        # gate_proj and up_proj land in different shards, merging them.
        wmap = {
            "model.layers.0.mlp.gate_proj.weight": "shard-00001.safetensors",
            "model.layers.0.mlp.up_proj.weight": "shard-00002.safetensors",
            "model.layers.0.mlp.down_proj.weight": "shard-00002.safetensors",
        }
        result = group_files_by_fused_weights(wmap)
        expected = [
            "shard-00001.safetensors",
            "shard-00002.safetensors",
        ]
        assert len(result) == 1
        assert sorted(result[0]) == expected

    def test_independent_layers_not_merged(self):
        # Each layer's fused weights are fully contained in a single shard,
        # so different layers must not pull shards together.
        wmap = {
            "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors",
            "model.layers.0.self_attn.k_proj.weight": "shard-00001.safetensors",
            "model.layers.0.self_attn.v_proj.weight": "shard-00001.safetensors",
            "model.layers.1.self_attn.q_proj.weight": "shard-00002.safetensors",
            "model.layers.1.self_attn.k_proj.weight": "shard-00002.safetensors",
            "model.layers.1.self_attn.v_proj.weight": "shard-00002.safetensors",
        }
        result = group_files_by_fused_weights(wmap)
        assert len(result) == 2
        for group in result:
            assert len(group) == 1

    def test_three_way_cross_shard_group(self):
        # One q/k/v fusion spread over three shards collapses all three
        # into a single group.
        wmap = {
            "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors",
            "model.layers.0.self_attn.k_proj.weight": "shard-00002.safetensors",
            "model.layers.0.self_attn.v_proj.weight": "shard-00003.safetensors",
        }
        result = group_files_by_fused_weights(wmap)
        expected = [
            "shard-00001.safetensors",
            "shard-00002.safetensors",
            "shard-00003.safetensors",
        ]
        assert len(result) == 1
        assert sorted(result[0]) == expected

    def test_empty_weight_map(self):
        # Degenerate input: no weights means no groups.
        assert group_files_by_fused_weights({}) == []

    def test_single_file_no_fused_weights(self):
        # A lone shard with unfusable weights stays as one singleton group.
        wmap = {
            "model.embed_tokens.weight": "model.safetensors",
            "lm_head.weight": "model.safetensors",
        }
        result = group_files_by_fused_weights(wmap)
        assert len(result) == 1
        assert result[0] == ["model.safetensors"]
112+
113+
114+
class TestProcessFileGroupMicroscaleScheme:
    """Integration tests for process_file_group_microscale_scheme.

    Processes a q/k/v weight split across two shards as a group and checks
    the result against single-file processing of the same tensors merged
    into one shard: identical output keys, preserved shard layout, and
    matching total size.
    """

    @pytest.fixture
    def qkv_tensors(self):
        # Fused q/k/v attention weights plus an unrelated MLP weight.
        keys = (
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.self_attn.k_proj.weight",
            "model.layers.0.self_attn.v_proj.weight",
            "model.layers.0.mlp.down_proj.weight",
        )
        return {key: _rand_weight(32, 32) for key in keys}

    def _mkdirs(self, tmp_path, *names):
        # Create one subdirectory per name under tmp_path and return them.
        created = []
        for name in names:
            directory = tmp_path / name
            directory.mkdir()
            created.append(directory)
        return created

    def _save_split_shards(self, tmp_path, tensors):
        # Put q_proj alone in shard 1 and everything else in shard 2,
        # splitting the q/k/v fusion across the shard boundary.
        q_key = "model.layers.0.self_attn.q_proj.weight"
        first = {q_key: tensors[q_key]}
        second = {key: value for key, value in tensors.items() if key != q_key}
        shard1_path = tmp_path / "shard-00001.safetensors"
        shard2_path = tmp_path / "shard-00002.safetensors"
        save_file(first, shard1_path)
        save_file(second, shard2_path)
        return shard1_path, shard2_path

    def _save_merged_shard(self, tmp_path, tensors):
        # Write every tensor into one shard for the single-file baseline.
        merged_path = tmp_path / "merged.safetensors"
        save_file(tensors, merged_path)
        return merged_path

    def test_group_processing_produces_same_keys_as_single_shard(
        self, qkv_tensors, tmp_path
    ):
        scheme = _make_nvfp4_scheme()
        split_dir, merged_dir, group_out_dir, merged_out_dir = self._mkdirs(
            tmp_path, "split", "merged", "group_out", "merged_out"
        )

        shard1_path, shard2_path = self._save_split_shards(split_dir, qkv_tensors)
        merged_path = self._save_merged_shard(merged_dir, qkv_tensors)

        group_saves = [
            group_out_dir / "shard-00001.safetensors",
            group_out_dir / "shard-00002.safetensors",
        ]
        _, group_map = process_file_group_microscale_scheme(
            file_paths=[shard1_path, shard2_path],
            save_paths=group_saves,
            scheme=scheme,
            ignore=[],
            device="cpu",
        )

        _, merged_map = process_file_microscale_scheme(
            file_path=merged_path,
            save_path=merged_out_dir / "merged.safetensors",
            scheme=scheme,
            ignore=[],
            device="cpu",
        )

        # Grouped processing must emit exactly the keys single-file
        # processing would for the same tensors.
        assert set(group_map.keys()) == set(merged_map.keys())

    def test_group_processing_preserves_original_sharding(
        self, qkv_tensors, tmp_path
    ):
        scheme = _make_nvfp4_scheme()
        split_dir, out_dir = self._mkdirs(tmp_path, "split", "out")

        shard1_path, shard2_path = self._save_split_shards(split_dir, qkv_tensors)
        save_paths = [
            out_dir / "shard-00001.safetensors",
            out_dir / "shard-00002.safetensors",
        ]
        process_file_group_microscale_scheme(
            file_paths=[shard1_path, shard2_path],
            save_paths=save_paths,
            scheme=scheme,
            ignore=[],
            device="cpu",
        )

        # One non-empty output file per input shard: sharding is preserved.
        for save_path in save_paths:
            assert save_path.exists()
            assert save_path.stat().st_size > 0

    def test_group_processing_total_size_matches_merged(
        self, qkv_tensors, tmp_path
    ):
        scheme = _make_nvfp4_scheme()
        split_dir, merged_dir, group_out_dir, merged_out_dir = self._mkdirs(
            tmp_path, "split", "merged", "group_out", "merged_out"
        )

        shard1_path, shard2_path = self._save_split_shards(split_dir, qkv_tensors)
        merged_path = self._save_merged_shard(merged_dir, qkv_tensors)

        group_saves = [
            group_out_dir / "shard-00001.safetensors",
            group_out_dir / "shard-00002.safetensors",
        ]
        total_size_group, _ = process_file_group_microscale_scheme(
            file_paths=[shard1_path, shard2_path],
            save_paths=group_saves,
            scheme=scheme,
            ignore=[],
            device="cpu",
        )
        total_size_merged, _ = process_file_microscale_scheme(
            file_path=merged_path,
            save_path=merged_out_dir / "merged.safetensors",
            scheme=scheme,
            ignore=[],
            device="cpu",
        )

        # Reported total size must not depend on how tensors were sharded.
        assert total_size_group == total_size_merged

0 commit comments

Comments
 (0)