Skip to content

Commit 031d912

Browse files
Authored by: brian-dellabetta, dzhengAP, gemini-code-assist[bot]
[model_free_ptq] build job cleanup (#2545)
SUMMARY: Follow-up to #2498 and pre-cursor to landing #2491. This PR cleans up a few things:

- [x] Use the same function signature for building standard jobs, microscale jobs, and validation jobs. These will be needed in #2491.
- [x] Renamed microscale-specific `build_inverse_weights_map` -> `build_microscale_inverse_weights_map` because other reindexing logic will need different functionality when determining fused tensors.
- [x] Prunes unused `_get_all_tensor_names`.
- [x] Breaks out loading logic for `inverse_weights_map` to a helper that can be moved to compressed-tensors in follow-up #2491.

TEST PLAN: No net new functionality; if all tests pass, should be good to go.

---

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Signed-off-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Co-authored-by: David Zheng <dqzheng1996@gmail.com>
Co-authored-by: David Zheng <153074367+dzhengAP@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 83d2001 commit 031d912

File tree

5 files changed

+78
-151
lines changed

5 files changed

+78
-151
lines changed

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 24 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
gpu_if_available,
2222
)
2323
from llmcompressor.entrypoints.model_free.microscale import (
24-
build_inverse_weights_map,
24+
build_microscale_inverse_weights_map,
2525
is_microscale_scheme,
2626
)
2727
from llmcompressor.entrypoints.model_free.process import (
@@ -87,17 +87,10 @@ def model_free_ptq(
8787
shutil.copyfile(resolved_path, save_path)
8888

8989
# build quantization jobs
90-
if is_microscale_scheme(scheme):
91-
jobs = _build_microscale_jobs(
92-
model_files, save_directory, scheme, ignore, device, converter
93-
)
94-
else:
95-
jobs = _build_standard_jobs(
96-
model_files, save_directory, scheme, ignore, device, converter
97-
)
90+
jobs = _build_jobs(model_files, save_directory, scheme, ignore, device, converter)
9891

9992
# 1. validate quantizable tensors — fail fast before long-running quantization
100-
validate_jobs = _build_validate_jobs(jobs)
93+
validate_jobs = [(validate_file, *job[1:]) for job in jobs]
10194
exec_jobs(validate_jobs, max_workers, desc="Validating")
10295

10396
# 2-5. quantize and compress weights
@@ -114,29 +107,7 @@ def model_free_ptq(
114107
update_config(save_directory, scheme_name, scheme, ignore, converter)
115108

116109

117-
def _build_standard_jobs(
118-
model_files: dict[str, str],
119-
save_directory: str | os.PathLike,
120-
scheme: QuantizationScheme,
121-
ignore: Iterable[str],
122-
device: torch.device,
123-
converter: Converter | None,
124-
job_fn=None,
125-
) -> list[tuple]:
126-
"""Build one job per safetensors file using the given processing function."""
127-
if job_fn is None:
128-
job_fn = process_file
129-
jobs = []
130-
for file_path, resolved_path in model_files.items():
131-
if file_path.endswith("safetensors"):
132-
save_path = Path(save_directory) / file_path
133-
jobs.append(
134-
(job_fn, resolved_path, save_path, scheme, ignore, device, converter)
135-
)
136-
return jobs
137-
138-
139-
def _build_microscale_jobs(
110+
def _build_jobs(
140111
model_files: dict[str, str],
141112
save_directory: str | os.PathLike,
142113
scheme: QuantizationScheme,
@@ -152,10 +123,17 @@ def _build_microscale_jobs(
152123
from other shards. This avoids runtime fused-partner discovery inside the
153124
process function and eliminates redundant tensor reads.
154125
155-
Job tuple format:
156-
(process_file_microscale_scheme, inverse_weights_map, save_path,
157-
scheme, ignore, device, converter)
126+
:returns: list of jobs tuples
127+
(job_fn, inverse_weights_map, save_path, scheme, ignore, device, converter)
158128
"""
129+
if is_microscale_scheme(scheme):
130+
job_fn = process_file_microscale_scheme
131+
build_inverse_weights_map = build_microscale_inverse_weights_map
132+
else:
133+
job_fn = process_file
134+
# TODO brian-dellabetta (#2491): update here in follow-up PR based on converter
135+
build_inverse_weights_map = None
136+
159137
index_file = find_safetensors_index_file(model_files)
160138

161139
if index_file is None:
@@ -170,7 +148,7 @@ def _build_microscale_jobs(
170148
inverse_weights_map = {resolved_path: []}
171149
jobs.append(
172150
(
173-
process_file_microscale_scheme,
151+
job_fn,
174152
inverse_weights_map,
175153
save_path,
176154
scheme,
@@ -194,11 +172,14 @@ def _build_microscale_jobs(
194172

195173
# Precompute exactly which tensors to load from which files for this shard,
196174
# including fused partner tensors that live in other shards
197-
inverse_weights_map = build_inverse_weights_map(
198-
shard_name=shard_name,
199-
weight_map=weight_map,
200-
model_files=model_files,
201-
)
175+
if build_inverse_weights_map is None:
176+
inverse_weights_map = {resolved_path: []}
177+
else:
178+
inverse_weights_map = build_inverse_weights_map(
179+
shard_name=shard_name,
180+
weight_map=weight_map,
181+
model_files=model_files,
182+
)
202183

203184
if len(inverse_weights_map) > 1:
204185
partner_shards = [s for s in inverse_weights_map if s != resolved_path]
@@ -209,7 +190,7 @@ def _build_microscale_jobs(
209190

210191
jobs.append(
211192
(
212-
process_file_microscale_scheme,
193+
job_fn,
213194
inverse_weights_map,
214195
save_path,
215196
scheme,
@@ -220,63 +201,3 @@ def _build_microscale_jobs(
220201
)
221202

222203
return jobs
223-
224-
225-
def _build_validate_jobs(jobs: list[tuple]) -> list[tuple]:
226-
"""
227-
Build validation jobs from processing jobs.
228-
229-
Handles both job formats:
230-
- Standard/fallback: (proc_fn, file_path_str, save_path, scheme, ignore, device, \
231-
converter)
232-
- Microscale with index: (proc_fn, inverse_weights_map_dict, save_path, scheme, \
233-
ignore, device, converter)
234-
"""
235-
validate_jobs = []
236-
for job in jobs:
237-
# job[0] is the processing function
238-
# Check if second element is a dict (microscale with index)
239-
# or string (standard/fallback)
240-
second_arg = job[1]
241-
242-
if isinstance(second_arg, dict):
243-
# Microscale job with inverse_weights_map dict
244-
_, inverse_weights_map, save_path, scheme, ignore, device, converter = job
245-
# Use first source file path from inverse_weights_map for validation
246-
source_file = next(iter(inverse_weights_map.keys()))
247-
validate_jobs.append(
248-
(
249-
validate_file,
250-
source_file,
251-
save_path,
252-
scheme,
253-
ignore,
254-
device,
255-
converter,
256-
inverse_weights_map,
257-
)
258-
)
259-
else:
260-
# Standard job or microscale fallback: second_arg is file_path string
261-
_, file_path, save_path, scheme, ignore, device, converter = job
262-
validate_jobs.append(
263-
(
264-
validate_file,
265-
file_path,
266-
save_path,
267-
scheme,
268-
ignore,
269-
device,
270-
converter,
271-
None,
272-
)
273-
)
274-
return validate_jobs
275-
276-
277-
def _get_all_tensor_names(file_path: str) -> list[str]:
278-
"""Get all tensor names from a safetensors file without loading tensors."""
279-
from safetensors import safe_open
280-
281-
with safe_open(file_path, framework="pt", device="cpu") as f:
282-
return list(f.keys())

src/llmcompressor/entrypoints/model_free/microscale.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010

1111
__all__ = [
12-
"build_inverse_weights_map",
12+
"build_microscale_inverse_weights_map",
1313
"is_microscale_scheme",
1414
"get_fused_names",
1515
"DEFAULT_FUSED_MAPPINGS",
@@ -78,7 +78,7 @@ def get_fused_names(
7878
return matched, unmatched
7979

8080

81-
def build_inverse_weights_map(
81+
def build_microscale_inverse_weights_map(
8282
shard_name: str,
8383
weight_map: dict[str, str],
8484
model_files: dict[str, str],

src/llmcompressor/entrypoints/model_free/process.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from compressed_tensors.quantization import QuantizationScheme
99
from compressed_tensors.utils import match_quantizable_tensors
1010
from safetensors import safe_open
11-
from safetensors.torch import load_file, save_file
11+
from safetensors.torch import save_file
1212
from torch.nn import Module
1313

1414
from llmcompressor.entrypoints.model_free.lifecycle import (
@@ -36,7 +36,6 @@ def validate_file(
3636
ignore: Iterable[str],
3737
device: str | torch.device,
3838
converter: Converter | None = None,
39-
weights_map: dict[str, str] | None = None,
4039
):
4140
"""
4241
Validate that each quantizable tensor in a safetensors file can be quantized.
@@ -49,19 +48,8 @@ def validate_file(
4948
:param device: device used to quantize and compress weights
5049
:param converter: optional converter to apply to the checkpoint,
5150
e.g. conversion of some layers from some format to compressed-tensors
52-
:param weights_map: optional mapping of tensor name -> source file path,
53-
built from safetensors.index.json. Reserved for future use by callers
54-
that need cross-shard tensor location lookup during validation.
5551
"""
56-
# Extract file path from inverse_weights_map (standard mode: load all)
57-
# Backward compatibility: handle both dict and Path/string formats
58-
if not isinstance(inverse_weights_map, dict):
59-
# Legacy call with file_path - wrap it as inverse_weights_map
60-
inverse_weights_map = {inverse_weights_map: None}
61-
# Extract source file from inverse_weights_map
62-
source_file = next(iter(inverse_weights_map.keys()))
63-
# Extract source file from inverse_weights_map
64-
tensors = load_file(source_file)
52+
tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device)
6553

6654
if converter is not None:
6755
converter.validate(tensors)
@@ -92,16 +80,8 @@ def process_file(
9280
e.g. conversion of some layers from some format to compressed-tensors
9381
"""
9482
assert not is_microscale_scheme(scheme), "Use `process_file_microscale_scheme`"
95-
# Extract file path from inverse_weights_map (standard mode: load all)
96-
# Backward compatibility: handle both dict and Path/string formats
97-
if not isinstance(inverse_weights_map, dict):
98-
# Legacy call with file_path - wrap it as inverse_weights_map
99-
inverse_weights_map = {inverse_weights_map: None}
100-
# Extract source file from inverse_weights_map
101-
source_file = next(iter(inverse_weights_map.keys()))
102-
# Extract source file from inverse_weights_map
103-
source_file = next(iter(inverse_weights_map.keys()))
104-
tensors = load_file(source_file)
83+
84+
tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device)
10585

10686
if converter is not None:
10787
converter.process(tensors)
@@ -152,7 +132,7 @@ def process_file_microscale_scheme(
152132
153133
:param inverse_weights_map: mapping of resolved source file path ->
154134
list of tensor names to load from that file. Precomputed by
155-
build_inverse_weights_map() in the job-building phase.
135+
build_microscale_inverse_weights_map() in the job-building phase.
156136
Example: {"/path/shard0.safetensors": ["q_proj.weight"],
157137
"/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
158138
:param save_path: output path for this shard's compressed weights
@@ -165,18 +145,7 @@ def process_file_microscale_scheme(
165145
"""
166146
assert is_microscale_scheme(scheme), "Use `process_file` for non-microscale scheme"
167147

168-
# Load all required tensors using true partial reads via safe_open.
169-
# inverse_weights_map tells us exactly which tensors to load from each file —
170-
# no entire-file loads, no runtime discovery.
171-
tensors: dict[str, torch.Tensor] = {}
172-
for source_file, tensor_names in inverse_weights_map.items():
173-
with safe_open(source_file, framework="pt", device="cpu") as f:
174-
available = set(f.keys())
175-
# Load all tensors if tensor_names is None or empty
176-
names_to_load = tensor_names if tensor_names else list(available)
177-
for name in names_to_load:
178-
if name in available:
179-
tensors[name] = f.get_tensor(name)
148+
tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device)
180149

181150
if converter is not None:
182151
converter.process(tensors)
@@ -247,3 +216,40 @@ def process_file_microscale_scheme(
247216
total_size = sum(t.nbytes for t in tensors.values())
248217
weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
249218
return total_size, weight_map
219+
220+
221+
# TODO brian-dellabetta (#2491): move to compressed-tensors.utils.safetensors_load
222+
def _load_tensors_from_inverse_weights_map(
223+
inverse_weights_map: dict[str, list[str] | None],
224+
device: str | torch.device,
225+
) -> dict[str, torch.Tensor]:
226+
"""
227+
Given an inverse_weights_map, which is a dictionary of file name to list of
228+
tensor names, load up all listed tensor names
229+
230+
:param inverse_weights_map: mapping of resolved source file path ->
231+
list of tensor names to load from that file. Precomputed by
232+
build_inverse_weights_map() in the job-building phase.
233+
If list is empty, all tensors are pulled
234+
Example: {"/path/shard0.safetensors": ["q_proj.weight"],
235+
"/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
236+
:param device: tensors will be loaded onto this device.
237+
238+
:returns: mapping of tensor name to actual tensor loaded from safetensors file
239+
Example: {"q_proj.weight": torch.Tensor(...), "k_proj.weight: torch.Tensor(...)}
240+
"""
241+
tensors: dict[str, torch.Tensor] = {}
242+
for source_file, tensor_names in inverse_weights_map.items():
243+
with safe_open(source_file, framework="pt", device=str(device)) as f:
244+
keys = f.keys()
245+
# if tensor_names is empty, pull all tensors
246+
if tensor_names is None or len(tensor_names) == 0:
247+
tensor_names = keys
248+
for tensor_name in tensor_names:
249+
if tensor_name not in keys:
250+
raise ValueError(
251+
f"Expected to find tensor {tensor_name} in "
252+
f"{source_file}, but tensor was not found."
253+
)
254+
tensors[tensor_name] = f.get_tensor(tensor_name)
255+
return tensors

tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_validate_file_raises_for_non_2d_linear_weight(tmp_path):
2525
save_file({"model.layers.0.mlp.down_proj.weight": torch.ones(128)}, str(path))
2626

2727
with pytest.raises(ValueError, match="model.layers.0.mlp.down_proj.weight"):
28-
validate_file(path, None, _get_block_scheme(), [], None)
28+
validate_file({str(path): []}, None, _get_block_scheme(), [], "cpu")
2929

3030

3131
def test_validate_file_does_not_raise_for_block_incompatible_shape(tmp_path):
@@ -35,4 +35,4 @@ def test_validate_file_does_not_raise_for_block_incompatible_shape(tmp_path):
3535
str(path),
3636
)
3737

38-
validate_file(path, None, _get_block_scheme(), [], None)
38+
validate_file({str(path): []}, None, _get_block_scheme(), [], "cpu")

0 commit comments

Comments
 (0)