Skip to content

Commit 97948b6

Browse files
committed
feat: eliminate reindexing step via fusion-aware file grouping
model_free_ptq now automatically groups shards that contain cross-file fused weights (q/k/v, gate/up) for joint microscale processing, removing the requirement to run reindex_fused_weights as a preprocessing step. Closes #2497 Signed-off-by: David Zheng <dqzheng1996@gmail.com>
1 parent 026c917 commit 97948b6

File tree

4 files changed

+379
-37
lines changed

4 files changed

+379
-37
lines changed

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 175 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import shutil
34
from pathlib import Path
@@ -14,12 +15,17 @@
1415
from compressed_tensors.quantization import QuantizationScheme
1516
from loguru import logger
1617

17-
from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
18+
from llmcompressor.entrypoints.model_free.helpers import (
19+
find_safetensors_index_file,
20+
gpu_if_available,
21+
group_files_by_fused_weights,
22+
)
1823
from llmcompressor.entrypoints.model_free.microscale import (
1924
is_microscale_scheme,
2025
)
2126
from llmcompressor.entrypoints.model_free.process import (
2227
process_file,
28+
process_file_group_microscale_scheme,
2329
process_file_microscale_scheme,
2430
validate_file,
2531
)
@@ -45,9 +51,14 @@ def model_free_ptq(
4551
):
4652
"""
4753
Quantize a model without the need for a model definition. This function operates on
48-
a model stub or folder containing weights saved in safetensors files
54+
a model stub or folder containing weights saved in safetensors files.
55+
56+
For microscale schemes (NVFP4, MXFP4), fused weight sets (q/k/v, gate/up) are
57+
automatically grouped for joint processing even when split across shards, removing
58+
the need to run reindex_fused_weights as a preprocessing step.
4959
5060
:param model_stub: huggingface model hub or path to local weights files
61+
:param save_directory: directory to save quantized weights to
5162
:param scheme: weight quantization scheme or preset scheme name
5263
:param ignore: modules to ignore. Modules ending with "norm" are automatically
5364
ignored
@@ -64,32 +75,31 @@ def model_free_ptq(
6475
device = gpu_if_available(device)
6576
validate_safetensors_index(model_files, scheme)
6677

67-
# 0. collect safetensors files, copy files
68-
jobs = []
69-
job_fn = (
70-
process_file
71-
if not is_microscale_scheme(scheme)
72-
else process_file_microscale_scheme
73-
)
78+
# copy non-safetensors files (configs, tokenizers, etc.)
7479
for file_path, resolved_path in model_files.items():
75-
save_path = Path(save_directory) / file_path
76-
77-
if file_path.endswith("safetensors"):
78-
jobs.append(
79-
(job_fn, resolved_path, save_path, scheme, ignore, device, converter)
80-
)
81-
82-
else:
80+
if not file_path.endswith("safetensors"):
81+
save_path = Path(save_directory) / file_path
8382
if is_weights_file(file_path):
8483
logger.warning(f"Skip processing for weights file {file_path}")
8584
save_path.parent.mkdir(parents=True, exist_ok=True)
86-
logger.info(f"Copying {file_path} {save_path}")
85+
logger.info(f"Copying {file_path} -> {save_path}")
8786
shutil.copyfile(resolved_path, save_path)
8887

89-
# 1. validate quantizable tensors fail fast before long-running quantization
90-
exec_jobs(
91-
[(validate_file, *job[1:]) for job in jobs], max_workers, desc="Validating"
88+
# build quantization jobs
89+
if is_microscale_scheme(scheme):
90+
jobs = _build_microscale_jobs(
91+
model_files, save_directory, scheme, ignore, device, converter
92+
)
93+
else:
94+
jobs = _build_standard_jobs(
95+
model_files, save_directory, scheme, ignore, device, converter
96+
)
97+
98+
# 1. validate quantizable tensors — fail fast before long-running quantization
99+
validate_jobs = _make_validate_jobs(
100+
jobs, model_files, scheme, ignore, device, converter
92101
)
102+
exec_jobs(validate_jobs, max_workers, desc="Validating")
93103

94104
# 2-5. quantize and compress weights
95105
total_size = 0
@@ -99,6 +109,149 @@ def model_free_ptq(
99109
total_size += _total_size
100110
weight_map.update(_weight_map)
101111

102-
# 5. update config and safetensors index
112+
# 6. update config and safetensors index
103113
update_config(save_directory, scheme_name, scheme, ignore, converter)
104114
update_safetensors_index(save_directory, total_size, weight_map)
115+
116+
117+
def _build_standard_jobs(
    model_files: dict[str, str],
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Create one ``process_file`` job per safetensors shard.

    Used for non-microscale schemes, where every shard can be quantized
    independently of the others.

    :param model_files: mapping of relative file name -> resolved local path
    :param save_directory: directory the quantized shard will be written to
    :param scheme: weight quantization scheme applied to each shard
    :param ignore: modules to exclude from quantization
    :param device: device on which quantization runs
    :param converter: optional weight-name converter, or None
    :return: list of job tuples consumable by exec_jobs
    """
    return [
        (
            process_file,
            resolved_path,
            Path(save_directory) / file_path,
            scheme,
            ignore,
            device,
            converter,
        )
        for file_path, resolved_path in model_files.items()
        if file_path.endswith("safetensors")
    ]
142+
143+
144+
def _single_shard_microscale_job(
    resolved_path: str,
    save_path: Path,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> tuple:
    """Build the job tuple for one independently-processable microscale shard."""
    return (
        process_file_microscale_scheme,
        resolved_path,
        save_path,
        scheme,
        ignore,
        device,
        converter,
    )


def _build_microscale_jobs(
    model_files: dict[str, str],
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Build jobs for microscale schemes, grouping files that share fused weight sets
    so that global scale fusion works correctly across shard boundaries.

    For models where all fused weights are already co-located in single shards,
    each group will be a singleton and process_file_microscale_scheme is used.
    For models with cross-shard fused weights, multi-file groups are formed and
    process_file_group_microscale_scheme is used, eliminating the need for
    reindex_fused_weights preprocessing.

    :param model_files: mapping of relative file name -> resolved local path
    :param save_directory: directory quantized shards are written to
    :param scheme: microscale quantization scheme (e.g. NVFP4, MXFP4)
    :param ignore: modules to exclude from quantization
    :param device: device on which quantization runs
    :param converter: optional weight-name converter, or None
    :return: list of job tuples consumable by exec_jobs
    """
    index_file = find_safetensors_index_file(model_files)

    if index_file is None:
        # Single-file model (no index.json): nothing can span shards, so every
        # safetensors file takes the standard single-file microscale path.
        return [
            _single_shard_microscale_job(
                resolved_path,
                Path(save_directory) / file_path,
                scheme,
                ignore,
                device,
                converter,
            )
            for file_path, resolved_path in model_files.items()
            if file_path.endswith("safetensors")
        ]

    # Read weight map to determine cross-shard fused weight groupings.
    # NOTE(review): shards present in model_files but absent from the index's
    # weight_map receive no job — confirm the index always covers all shards.
    with open(index_file, "r") as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]

    jobs: list[tuple] = []
    for group in group_files_by_fused_weights(weight_map):
        if len(group) == 1:
            # No cross-shard fused weights — use the standard single-file path
            shard_name = group[0]
            jobs.append(
                _single_shard_microscale_job(
                    model_files[shard_name],
                    Path(save_directory) / shard_name,
                    scheme,
                    ignore,
                    device,
                    converter,
                )
            )
        else:
            # Cross-shard fused weights — load the whole group jointly
            logger.info(
                f"Grouping {len(group)} shards for joint microscale processing "
                f"(fused weights span multiple files): {group}"
            )
            jobs.append(
                (
                    process_file_group_microscale_scheme,
                    [model_files[shard] for shard in group],
                    [Path(save_directory) / shard for shard in group],
                    scheme,
                    ignore,
                    device,
                    converter,
                )
            )

    return jobs
228+
229+
230+
def _make_validate_jobs(
    jobs: list[tuple],
    model_files: dict[str, str],
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Derive validate_file jobs mirroring the quantization jobs.

    Group jobs (which carry parallel lists of input and output paths) are
    expanded into one validate_file call per member file; single-file jobs
    map one-to-one.

    :param jobs: quantization job tuples produced by the job builders
    :param model_files: mapping of relative file name -> resolved local path
    :param scheme: quantization scheme forwarded to validate_file
    :param ignore: modules to exclude, forwarded to validate_file
    :param device: device forwarded to validate_file
    :param converter: optional converter forwarded to validate_file
    :return: list of validate_file job tuples consumable by exec_jobs
    """
    validate_jobs: list[tuple] = []
    for fn, src, dst, *_rest in jobs:
        if fn is process_file_group_microscale_scheme:
            # src and dst are parallel lists covering the whole shard group
            pairs = zip(src, dst)
        else:
            pairs = [(src, dst)]
        validate_jobs.extend(
            (validate_file, file_path, save_path, scheme, ignore, device, converter)
            for file_path, save_path in pairs
        )
    return validate_jobs

src/llmcompressor/entrypoints/model_free/helpers.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"match_names_set_eager",
1313
"MatchedNamesSet",
1414
"invert_mapping",
15+
"group_files_by_fused_weights",
1516
]
1617

1718
KeyType = TypeVar("K")
@@ -96,3 +97,58 @@ def invert_mapping(
9697
inverse[value].append(key)
9798

9899
return inverse
100+
101+
102+
def group_files_by_fused_weights(
    weight_map: dict[str, str],
) -> list[list[str]]:
    """
    Partition shard file names into groups that must be processed jointly.

    Two shards land in the same group whenever a fused weight set (e.g.
    q/k/v_proj or gate/up_proj split across shards) has members stored in
    both of them; shards with no cross-file fused dependencies become
    singleton groups.

    This allows model_free_ptq to handle microscale schemes (NVFP4, MXFP4)
    without a reindexing preprocessing step, by loading all tensors in a
    fused set together at processing time.

    :param weight_map: mapping of weight name -> file name (from index.json)
    :return: list of file groups; each group is a sorted list of shard file names
    """
    # Imported lazily to avoid a circular dependency with the microscale module
    from llmcompressor.entrypoints.model_free.microscale import get_fused_names

    fused_sets, _ = get_fused_names(list(weight_map))

    # Disjoint-set forest over shard file names
    shards = sorted(set(weight_map.values()))
    parent: dict[str, str] = {shard: shard for shard in shards}

    def find(node: str) -> str:
        # locate the root, then fully compress the path behind it
        root = node
        while parent[root] != root:
            root = parent[root]
        while parent[node] != root:
            parent[node], node = root, parent[node]
        return root

    # Merge every shard that holds a member of the same fused set
    for fused_set in fused_sets:
        member_shards = {
            weight_map[name]
            for name in fused_set.values()
            if name is not None and name in weight_map
        }
        anchor = None
        for shard in member_shards:
            if anchor is None:
                anchor = shard
            else:
                parent[find(shard)] = find(anchor)

    # Bucket shards by their representative root
    clusters: dict[str, list[str]] = defaultdict(list)
    for shard in shards:
        clusters[find(shard)].append(shard)

    return [sorted(cluster) for cluster in clusters.values()]

0 commit comments

Comments
 (0)