Commit 3544a0e

[Distributed] [model_free_ptq] Eliminate reindexing step via fine-grained parallelized partial reads (#2498)
## Purpose

Eliminates the `reindex_fused_weights` preprocessing step for microscale schemes (NVFP4, MXFP4) by enabling each shard to be processed independently with full parallelism, even when fused weight sets (q/k/v, gate/up) span multiple shards.

## Approach

Instead of grouping shards together (which reduces parallelism), each shard process fetches only the specific fused partner tensors it needs from other shards via targeted partial safetensors reads, computes the fused global scale locally, and writes only its own output shard. No cross-process coordination or file locking required.

## Changes

### `helpers.py`

Added `build_tensor_file_index()` — reads `index.json` once at startup and builds a flat mapping of `tensor_name → resolved_file_path`. This gives each worker process an O(1) lookup to find which file contains any fused partner tensor, without re-scanning headers at runtime.

### `process.py`

Updated `process_file_microscale_scheme()` with an optional `tensor_file_index` parameter. When provided:

- `_fetch_fused_partners()` is called to identify any fused set members missing from the current shard, then fetches only those specific tensors via partial safetensors reads (headers + target tensors only, not full files)
- Fused global scale is computed locally using all members of the fused set
- `_belongs_to_shard()` ensures only native tensors are written to the output shard — fetched partner tensors are used for scale computation only and never written to the wrong shard

### `__init__.py`

Simplified back to one job per shard — full parallelism restored. For microscale schemes, builds the `tensor_file_index` once from `index.json` and passes it to each job. No union-find, no grouping logic needed.

### `validate.py`

Removed `NotImplementedError` for cross-shard fused weights — the case is now handled natively. Replaced with `logger.debug` noting that partner tensors will be resolved via partial reads.
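The O(1) lookup idea above can be sketched in a few lines. This is a minimal illustration assuming a standard `index.json` with a `weight_map` section; the function name and signature mirror the description but are not the actual helper's implementation:

```python
import json
from pathlib import Path


def build_tensor_file_index(index_path: str, model_dir: str) -> dict[str, str]:
    """Illustrative sketch: read index.json once and build a flat
    tensor_name -> resolved_file_path mapping for O(1) lookups."""
    with open(index_path) as f:
        # weight_map maps tensor name -> shard filename
        weight_map = json.load(f)["weight_map"]
    return {name: str(Path(model_dir) / shard) for name, shard in weight_map.items()}
```

With this index in hand, a worker that needs a fused partner tensor can resolve its file directly instead of re-scanning every shard's header at runtime.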
## Latest Updates: Eliminate reindexing step via inverse_weights_map with unified job signatures

## Approach

Each shard job receives a precomputed `inverse_weights_map` specifying exactly which tensors to load from which files. For cross-shard fused weights, only the shard owning the **primary** tensor (q_proj, gate_proj) fetches its partners — preventing double reads. All jobs share a unified signature for both standard and microscale schemes.

## Changes

### `microscale.py`

- Refactor `DEFAULT_FUSED_MAPPINGS` from a list of lists to `{primary_pattern: [partner_templates]}` — only the primary-owning shard fetches its partners, preventing double reads for cross-shard fused weights
- Move `build_inverse_weights_map()` here — uses regex match on primary patterns to construct partner names and locate them in other shards

### `process.py`

- **Unified signature** for `validate_file`, `process_file`, and `process_file_microscale_scheme`: `(inverse_weights_map, save_path, scheme, ignore, device, converter)`
- All functions use `safe_open` + `f.get_tensor()` for true partial reads
- Partner tensors re-saved into requesting shard's output; caller updates safetensors index to reflect new locations

### `__init__.py`

- Single `_get_weights_map()` helper handles both single-file and multi-file models (reads `safetensors.index.json` or scans file headers via `safe_open`)
- Single `_build_quantization_jobs()` replaces separate standard/microscale builders — one job per shard with identical tuple structure for both
- Validate jobs use `*job[1:]` for future-proofing

### `helpers.py`

- Removed `build_weights_map` and `build_inverse_weights_map` (moved to `microscale.py`)

### `validate.py`

- Removed `NotImplementedError` for cross-shard fused weights — handled natively
- Updated to reflect `inverse_weights_map`-based approach

## Testing

- `pytest tests/llmcompressor/entrypoints/model_free/` — all passing locally
- `make style && make quality` — all checks pass

Closes #2497
Related to #2448

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
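The primary-owning-shard rule described above can be sketched as follows. The fused patterns, partner templates, and function body here are illustrative assumptions in the `{primary_pattern: [partner_templates]}` shape, not the actual `microscale.py` code:

```python
import re

# Hypothetical fused mapping: only the *primary* tensor (q_proj, gate_proj)
# triggers partner fetches, so each fused set is read exactly once.
FUSED_MAPPINGS = {
    r"^(.*)\.q_proj\.weight$": ["{}.k_proj.weight", "{}.v_proj.weight"],
    r"^(.*)\.gate_proj\.weight$": ["{}.up_proj.weight"],
}


def build_inverse_weights_map(shard_name, weight_map, model_files):
    """Sketch: build {source_file_path: [tensor_names]} for one shard's job.

    Native tensors come from the shard itself; fused partners living in
    other shards are added only when this shard owns the primary tensor.
    """
    own_path = model_files[shard_name]
    inverse = {own_path: []}
    for name, shard in weight_map.items():
        if shard != shard_name:
            continue
        inverse[own_path].append(name)  # native tensor
        for pattern, templates in FUSED_MAPPINGS.items():
            match = re.match(pattern, name)
            if not match:
                continue
            for template in templates:
                partner = template.format(match.group(1))
                partner_shard = weight_map.get(partner)
                if partner_shard and partner_shard != shard_name:
                    # partner lives in another shard: targeted partial read
                    inverse.setdefault(model_files[partner_shard], []).append(partner)
    return inverse
```

Note the asymmetry: the shard holding `q_proj` pulls in `k_proj`/`v_proj` from wherever they live, while the shard holding only `k_proj`/`v_proj` fetches nothing extra — this is what prevents double reads.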
1 parent 31e585c commit 3544a0e

File tree

6 files changed: +664 −77 lines changed

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 208 additions & 32 deletions
```diff
@@ -1,3 +1,4 @@
+import json
 import os
 import shutil
 from pathlib import Path
@@ -12,12 +13,15 @@
 from compressed_tensors.utils.safetensors_load import (
     get_checkpoint_files,
     is_weights_file,
-    update_safetensors_index,
 )
 from loguru import logger
 
-from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
+from llmcompressor.entrypoints.model_free.helpers import (
+    find_safetensors_index_file,
+    gpu_if_available,
+)
 from llmcompressor.entrypoints.model_free.microscale import (
+    build_inverse_weights_map,
     is_microscale_scheme,
 )
 from llmcompressor.entrypoints.model_free.process import (
@@ -46,52 +50,55 @@ def model_free_ptq(
     converter: Converter | None = None,
 ):
     """
-    Quantize a model without the need for a model definition. This function operates on
-    a model stub or folder containing weights saved in safetensors files
+    Quantize a model without the need for a model definition. This function
+    operates on a model stub or folder containing weights saved in safetensors
+    files.
+
+    For microscale schemes (NVFP4, MXFP4), fused weight sets (q/k/v, gate/up)
+    are handled correctly even when split across shards. Each shard job receives
+    a precomputed inverse_weights_map specifying exactly which tensors to load
+    from which files — enabling true partial reads with no runtime discovery
+    and no redundant tensor reads.
 
     :param model_stub: huggingface model hub or path to local weights files
+    :param save_directory: directory to save quantized weights to
     :param scheme: weight quantization scheme or preset scheme name
-    :param ignore: modules to ignore. Modules ending with "norm" are automatically
-        ignored
+    :param ignore: modules to ignore. Modules ending with "norm" are
+        automatically ignored
     :param max_workers: number of worker threads to process files with
     :param device: gpu device to accelerate quantization with
-    :param converter: optional converter to apply to the checkpoint to convert it to
-        compressed-tensors format before running model-free PTQ
-        e.g. conversion of some layers from modelopt format to compressed-tensors
-        See compressed-tensors convert_checkpoint entrypoint for more information
+    :param converter: optional converter to apply to the checkpoint to convert
+        it to compressed-tensors format before running model-free PTQ
     """
     # validate arguments
     model_files = get_checkpoint_files(model_stub)
     scheme_name, scheme = validate_scheme(scheme)
     device = gpu_if_available(device)
     validate_safetensors_index(model_files, scheme)
 
-    # 0. collect safetensors files, copy files
-    jobs = []
-    job_fn = (
-        process_file
-        if not is_microscale_scheme(scheme)
-        else process_file_microscale_scheme
-    )
+    # copy non-safetensors files (configs, tokenizers, etc.)
     for file_path, resolved_path in model_files.items():
-        save_path = Path(save_directory) / file_path
-
-        if file_path.endswith("safetensors"):
-            jobs.append(
-                (job_fn, resolved_path, save_path, scheme, ignore, device, converter)
-            )
-
-        else:
+        if not file_path.endswith("safetensors"):
+            save_path = Path(save_directory) / file_path
             if is_weights_file(file_path):
                 logger.warning(f"Skip processing for weights file {file_path}")
             save_path.parent.mkdir(parents=True, exist_ok=True)
-            logger.info(f"Copying {file_path} {save_path}")
+            logger.info(f"Copying {file_path} -> {save_path}")
             shutil.copyfile(resolved_path, save_path)
 
-    # 1. validate quantizable tensors fail fast before long-running quantization
-    exec_jobs(
-        [(validate_file, *job[1:]) for job in jobs], max_workers, desc="Validating"
-    )
+    # build quantization jobs
+    if is_microscale_scheme(scheme):
+        jobs = _build_microscale_jobs(
+            model_files, save_directory, scheme, ignore, device, converter
+        )
+    else:
+        jobs = _build_standard_jobs(
+            model_files, save_directory, scheme, ignore, device, converter
+        )
+
+    # 1. validate quantizable tensors — fail fast before long-running quantization
+    validate_jobs = _build_validate_jobs(jobs)
+    exec_jobs(validate_jobs, max_workers, desc="Validating")
 
     # 2-5. quantize and compress weights
     total_size = 0
@@ -101,6 +108,175 @@ def model_free_ptq(
         total_size += _total_size
         weight_map.update(_weight_map)
 
-    # 5. update config and safetensors index
+    # 6. update config and safetensors index
+    # weight_map may contain tensors re-located to new shards (partner tensors
+    # re-saved alongside the shard that needed them for fused scale computation)
     update_config(save_directory, scheme_name, scheme, ignore, converter)
-    update_safetensors_index(save_directory, total_size, weight_map)
+
+
+def _build_standard_jobs(
+    model_files: dict[str, str],
+    save_directory: str | os.PathLike,
+    scheme: QuantizationScheme,
+    ignore: Iterable[str],
+    device: torch.device,
+    converter: Converter | None,
+    job_fn=None,
+) -> list[tuple]:
+    """Build one job per safetensors file using the given processing function."""
+    if job_fn is None:
+        job_fn = process_file
+    jobs = []
+    for file_path, resolved_path in model_files.items():
+        if file_path.endswith("safetensors"):
+            save_path = Path(save_directory) / file_path
+            jobs.append(
+                (job_fn, resolved_path, save_path, scheme, ignore, device, converter)
+            )
+    return jobs
+
+
+def _build_microscale_jobs(
+    model_files: dict[str, str],
+    save_directory: str | os.PathLike,
+    scheme: QuantizationScheme,
+    ignore: Iterable[str],
+    device: torch.device,
+    converter: Converter | None,
+) -> list[tuple]:
+    """
+    Build microscale jobs with precomputed inverse_weights_map per shard.
+
+    For each output shard, build_inverse_weights_map() determines exactly which
+    tensors to load from which source files — including any fused partner tensors
+    from other shards. This avoids runtime fused-partner discovery inside the
+    process function and eliminates redundant tensor reads.
+
+    Job tuple format:
+        (process_file_microscale_scheme, inverse_weights_map, save_path,
+        scheme, ignore, device, converter)
+    """
+    index_file = find_safetensors_index_file(model_files)
+
+    if index_file is None:
+        # Single-file model — no cross-shard fused weights possible;
+        # create inverse_weights_map dict format for process_file_microscale_scheme
+        jobs = []
+        for file_path, resolved_path in model_files.items():
+            if file_path.endswith("safetensors"):
+                save_path = Path(save_directory) / file_path
+                # Wrap as inverse_weights_map: {source_file: []}
+                # means load all tensors
+                inverse_weights_map = {resolved_path: []}
+                jobs.append(
+                    (
+                        process_file_microscale_scheme,
+                        inverse_weights_map,
+                        save_path,
+                        scheme,
+                        ignore,
+                        device,
+                        converter,
+                    )
+                )
+        return jobs
+
+    # Read weight map from safetensors.index.json
+    with open(index_file, "r") as f:
+        weight_map: dict[str, str] = json.load(f)["weight_map"]
+
+    jobs = []
+    for shard_name, resolved_path in model_files.items():
+        if not shard_name.endswith("safetensors"):
+            continue
+
+        save_path = Path(save_directory) / shard_name
+
+        # Precompute exactly which tensors to load from which files for this shard,
+        # including fused partner tensors that live in other shards
+        inverse_weights_map = build_inverse_weights_map(
+            shard_name=shard_name,
+            weight_map=weight_map,
+            model_files=model_files,
+        )
+
+        if len(inverse_weights_map) > 1:
+            partner_shards = [s for s in inverse_weights_map if s != resolved_path]
+            logger.info(
+                f"{shard_name}: will fetch fused partners from "
+                f"{[os.path.basename(s) for s in partner_shards]}"
+            )
+
+        jobs.append(
+            (
+                process_file_microscale_scheme,
+                inverse_weights_map,
+                save_path,
+                scheme,
+                ignore,
+                device,
+                converter,
+            )
+        )
+
+    return jobs
+
+
+def _build_validate_jobs(jobs: list[tuple]) -> list[tuple]:
+    """
+    Build validation jobs from processing jobs.
+
+    Handles both job formats:
+    - Standard/fallback: (proc_fn, file_path_str, save_path, scheme, ignore, device, \
+        converter)
+    - Microscale with index: (proc_fn, inverse_weights_map_dict, save_path, scheme, \
+        ignore, device, converter)
+    """
+    validate_jobs = []
+    for job in jobs:
+        # job[0] is the processing function
+        # Check if second element is a dict (microscale with index)
+        # or string (standard/fallback)
+        second_arg = job[1]
+
+        if isinstance(second_arg, dict):
+            # Microscale job with inverse_weights_map dict
+            _, inverse_weights_map, save_path, scheme, ignore, device, converter = job
+            # Use first source file path from inverse_weights_map for validation
+            source_file = next(iter(inverse_weights_map.keys()))
+            validate_jobs.append(
+                (
+                    validate_file,
+                    source_file,
+                    save_path,
+                    scheme,
+                    ignore,
+                    device,
+                    converter,
+                    inverse_weights_map,
+                )
+            )
+        else:
+            # Standard job or microscale fallback: second_arg is file_path string
+            _, file_path, save_path, scheme, ignore, device, converter = job
+            validate_jobs.append(
+                (
+                    validate_file,
+                    file_path,
+                    save_path,
+                    scheme,
+                    ignore,
+                    device,
+                    converter,
+                    None,
+                )
+            )
+    return validate_jobs
+
+
+def _get_all_tensor_names(file_path: str) -> list[str]:
+    """Get all tensor names from a safetensors file without loading tensors."""
+    from safetensors import safe_open
+
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        return list(f.keys())
```
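The job tuples built above are consumed by a runner that calls `job[0](*job[1:])`. A minimal sketch of that contract, assuming the real `exec_jobs` adds progress reporting and richer error handling on top:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def exec_jobs(jobs, max_workers, desc=""):
    """Sketch of the job-runner contract: each job is a tuple of
    (fn, *args), executed as fn(*args) on a thread pool."""
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(job[0], *job[1:]) for job in jobs]
        for future in as_completed(futures):
            # re-raises any worker exception, failing the whole run
            results.append(future.result())
    return results
```

Because every job — standard or microscale — shares the same tuple layout `(fn, source, save_path, scheme, ignore, device, converter)`, one runner handles both, and `_build_validate_jobs` can rebuild validation work by swapping only the function and appending the optional map.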

src/llmcompressor/entrypoints/model_free/helpers.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -96,3 +96,23 @@ def invert_mapping(
         inverse[value].append(key)
 
     return inverse
+
+
+def build_weights_map(
+    weight_map: dict[str, str],
+    model_files: dict[str, str],
+) -> dict[str, str]:
+    """
+    Build a mapping of tensor name -> resolved file path from the model's
+    weight_map (index.json). This allows any process to locate fused partner
+    tensors from other shards without loading entire files.
+
+    :param weight_map: mapping of tensor name -> shard filename (from index.json)
+    :param model_files: mapping of shard filename -> resolved absolute path
+    :return: mapping of tensor name -> resolved absolute path
+    """
+    return {
+        tensor_name: model_files[shard_name]
+        for tensor_name, shard_name in weight_map.items()
+        if shard_name in model_files
+    }
```
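The helper added above is a plain dict comprehension; here is a small usage example (the function body matches the diff, while the sample tensor and file names are illustrative):

```python
def build_weights_map(weight_map, model_files):
    # As in helpers.py above: tensor name -> resolved absolute path,
    # skipping tensors whose shard is not among the resolved model files
    return {
        tensor_name: model_files[shard_name]
        for tensor_name, shard_name in weight_map.items()
        if shard_name in model_files
    }


weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}
# only the first shard has been resolved locally
model_files = {
    "model-00001-of-00002.safetensors": "/ckpt/model-00001-of-00002.safetensors",
}
resolved = build_weights_map(weight_map, model_files)
# lm_head.weight is dropped because its shard is absent from model_files
```

The `if shard_name in model_files` guard makes the helper tolerant of partially resolved checkpoints rather than raising `KeyError`.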
