Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
3606891
feat: eliminate reindexing step via fine-grained parallelized partial…
dzhengAP Mar 20, 2026
193b87d
Merge branch 'main' into model-free-ptq-runtime-optimization
dzhengAP Mar 25, 2026
21931c9
feat: unify job building with inverse_weights_map dict
dzhengAP Mar 25, 2026
aa57deb
Merge branch 'main' into model-free-ptq-runtime-optimization
dzhengAP Mar 27, 2026
b9914c9
fix: remove stale exports from helpers __all__, fix long line in test
dzhengAP Mar 27, 2026
35cc30b
Merge branch 'main' into model-free-ptq-runtime-optimization
brian-dellabetta Mar 27, 2026
4fb1784
fix changed imports; style fix
brian-dellabetta Mar 27, 2026
498d38e
make process/validate function signatures uniform
brian-dellabetta Mar 27, 2026
2ed8b25
Merge branch 'main' into model-free-ptq-runtime-optimization
brian-dellabetta Mar 27, 2026
b92af83
style fixes
brian-dellabetta Mar 27, 2026
177efba
Merge branch 'main' into bdellabe/model-free-ptq-cleanup
brian-dellabetta Mar 30, 2026
b029136
cleanup p2
brian-dellabetta Mar 30, 2026
627713d
make signatures the same
brian-dellabetta Mar 30, 2026
8bdf31d
Apply suggestion from @gemini-code-assist[bot]
brian-dellabetta Mar 30, 2026
ccecc3d
gemini suggestion
brian-dellabetta Mar 30, 2026
e3e4586
Merge branch 'main' into bdellabe/model-free-ptq-cleanup
brian-dellabetta Mar 30, 2026
8073749
stylefixes
brian-dellabetta Mar 30, 2026
d7b7c67
Merge branch 'main' into bdellabe/model-free-ptq-cleanup
brian-dellabetta Mar 31, 2026
5c7e9f0
Merge branch 'main' into bdellabe/model-free-ptq-cleanup
brian-dellabetta Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 24 additions & 103 deletions src/llmcompressor/entrypoints/model_free/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
gpu_if_available,
)
from llmcompressor.entrypoints.model_free.microscale import (
build_inverse_weights_map,
build_microscale_inverse_weights_map,
is_microscale_scheme,
)
from llmcompressor.entrypoints.model_free.process import (
Expand Down Expand Up @@ -87,17 +87,10 @@ def model_free_ptq(
shutil.copyfile(resolved_path, save_path)

# build quantization jobs
if is_microscale_scheme(scheme):
jobs = _build_microscale_jobs(
model_files, save_directory, scheme, ignore, device, converter
)
else:
jobs = _build_standard_jobs(
model_files, save_directory, scheme, ignore, device, converter
)
jobs = _build_jobs(model_files, save_directory, scheme, ignore, device, converter)

# 1. validate quantizable tensors — fail fast before long-running quantization
validate_jobs = _build_validate_jobs(jobs)
validate_jobs = [(validate_file, *job[1:]) for job in jobs]
exec_jobs(validate_jobs, max_workers, desc="Validating")

# 2-5. quantize and compress weights
Expand All @@ -114,29 +107,7 @@ def model_free_ptq(
update_config(save_directory, scheme_name, scheme, ignore, converter)


def _build_standard_jobs(
model_files: dict[str, str],
save_directory: str | os.PathLike,
scheme: QuantizationScheme,
ignore: Iterable[str],
device: torch.device,
converter: Converter | None,
job_fn=None,
) -> list[tuple]:
"""Build one job per safetensors file using the given processing function."""
if job_fn is None:
job_fn = process_file
jobs = []
for file_path, resolved_path in model_files.items():
if file_path.endswith("safetensors"):
save_path = Path(save_directory) / file_path
jobs.append(
(job_fn, resolved_path, save_path, scheme, ignore, device, converter)
)
return jobs


def _build_microscale_jobs(
def _build_jobs(
model_files: dict[str, str],
save_directory: str | os.PathLike,
scheme: QuantizationScheme,
Expand All @@ -152,10 +123,17 @@ def _build_microscale_jobs(
from other shards. This avoids runtime fused-partner discovery inside the
process function and eliminates redundant tensor reads.

Job tuple format:
(process_file_microscale_scheme, inverse_weights_map, save_path,
scheme, ignore, device, converter)
:returns: list of jobs tuples
(job_fn, inverse_weights_map, save_path, scheme, ignore, device, converter)
"""
if is_microscale_scheme(scheme):
job_fn = process_file_microscale_scheme
build_inverse_weights_map = build_microscale_inverse_weights_map
else:
job_fn = process_file
# TODO brian-dellabetta (#2491): update here in follow-up PR based on converter
build_inverse_weights_map = None

index_file = find_safetensors_index_file(model_files)

if index_file is None:
Expand All @@ -170,7 +148,7 @@ def _build_microscale_jobs(
inverse_weights_map = {resolved_path: []}
jobs.append(
(
process_file_microscale_scheme,
job_fn,
inverse_weights_map,
save_path,
scheme,
Expand All @@ -194,11 +172,14 @@ def _build_microscale_jobs(

# Precompute exactly which tensors to load from which files for this shard,
# including fused partner tensors that live in other shards
inverse_weights_map = build_inverse_weights_map(
shard_name=shard_name,
weight_map=weight_map,
model_files=model_files,
)
if build_inverse_weights_map is None:
inverse_weights_map = {resolved_path: []}
else:
inverse_weights_map = build_inverse_weights_map(
shard_name=shard_name,
weight_map=weight_map,
model_files=model_files,
)

if len(inverse_weights_map) > 1:
partner_shards = [s for s in inverse_weights_map if s != resolved_path]
Expand All @@ -209,7 +190,7 @@ def _build_microscale_jobs(

jobs.append(
(
process_file_microscale_scheme,
job_fn,
inverse_weights_map,
save_path,
scheme,
Expand All @@ -220,63 +201,3 @@ def _build_microscale_jobs(
)

return jobs


def _build_validate_jobs(jobs: list[tuple]) -> list[tuple]:
    """
    Translate processing jobs into validation jobs.

    Two job layouts are supported, distinguished by the type of the second
    tuple element:
    - Standard/fallback: (proc_fn, file_path_str, save_path, scheme, ignore, \
device, converter)
    - Microscale with index: (proc_fn, inverse_weights_map_dict, save_path, \
scheme, ignore, device, converter)
    """
    validate_jobs = []
    for _proc_fn, source, *rest in jobs:
        # rest == (save_path, scheme, ignore, device, converter)
        if isinstance(source, dict):
            # Microscale job: validate against the first source shard in the
            # inverse_weights_map and forward the full map as the last arg.
            first_source = next(iter(source))
            validate_jobs.append((validate_file, first_source, *rest, source))
        else:
            # Standard job or microscale fallback: source is a plain file
            # path string and there is no weights map to forward.
            validate_jobs.append((validate_file, source, *rest, None))
    return validate_jobs


def _get_all_tensor_names(file_path: str) -> list[str]:
    """Return every tensor name in a safetensors file without loading tensor data."""
    from safetensors import safe_open

    # safe_open only reads the file header, so no tensors are materialized
    with safe_open(file_path, framework="pt", device="cpu") as reader:
        return [name for name in reader.keys()]
4 changes: 2 additions & 2 deletions src/llmcompressor/entrypoints/model_free/microscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)

__all__ = [
"build_inverse_weights_map",
"build_microscale_inverse_weights_map",
"is_microscale_scheme",
"get_fused_names",
"DEFAULT_FUSED_MAPPINGS",
Expand Down Expand Up @@ -78,7 +78,7 @@ def get_fused_names(
return matched, unmatched


def build_inverse_weights_map(
def build_microscale_inverse_weights_map(
shard_name: str,
weight_map: dict[str, str],
model_files: dict[str, str],
Expand Down
80 changes: 43 additions & 37 deletions src/llmcompressor/entrypoints/model_free/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import match_quantizable_tensors
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file
from torch.nn import Module

from llmcompressor.entrypoints.model_free.lifecycle import (
Expand Down Expand Up @@ -36,7 +36,6 @@ def validate_file(
ignore: Iterable[str],
device: str | torch.device,
converter: Converter | None = None,
weights_map: dict[str, str] | None = None,
):
"""
Validate that each quantizable tensor in a safetensors file can be quantized.
Expand All @@ -49,19 +48,8 @@ def validate_file(
:param device: device used to quantize and compress weights
:param converter: optional converter to apply to the checkpoint,
e.g. conversion of some layers from some format to compressed-tensors
:param weights_map: optional mapping of tensor name -> source file path,
built from safetensors.index.json. Reserved for future use by callers
that need cross-shard tensor location lookup during validation.
"""
# Extract file path from inverse_weights_map (standard mode: load all)
# Backward compatibility: handle both dict and Path/string formats
if not isinstance(inverse_weights_map, dict):
# Legacy call with file_path - wrap it as inverse_weights_map
inverse_weights_map = {inverse_weights_map: None}
# Extract source file from inverse_weights_map
source_file = next(iter(inverse_weights_map.keys()))
# Extract source file from inverse_weights_map
tensors = load_file(source_file)
tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device)

if converter is not None:
converter.validate(tensors)
Expand Down Expand Up @@ -92,16 +80,8 @@ def process_file(
e.g. conversion of some layers from some format to compressed-tensors
"""
assert not is_microscale_scheme(scheme), "Use `process_file_microscale_scheme`"
# Extract file path from inverse_weights_map (standard mode: load all)
# Backward compatibility: handle both dict and Path/string formats
if not isinstance(inverse_weights_map, dict):
# Legacy call with file_path - wrap it as inverse_weights_map
inverse_weights_map = {inverse_weights_map: None}
# Extract source file from inverse_weights_map
source_file = next(iter(inverse_weights_map.keys()))
# Extract source file from inverse_weights_map
source_file = next(iter(inverse_weights_map.keys()))
tensors = load_file(source_file)

tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device)

if converter is not None:
converter.process(tensors)
Expand Down Expand Up @@ -152,7 +132,7 @@ def process_file_microscale_scheme(

:param inverse_weights_map: mapping of resolved source file path ->
list of tensor names to load from that file. Precomputed by
build_inverse_weights_map() in the job-building phase.
build_microscale_inverse_weights_map() in the job-building phase.
Example: {"/path/shard0.safetensors": ["q_proj.weight"],
"/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
:param save_path: output path for this shard's compressed weights
Expand All @@ -165,18 +145,7 @@ def process_file_microscale_scheme(
"""
assert is_microscale_scheme(scheme), "Use `process_file` for non-microscale scheme"

# Load all required tensors using true partial reads via safe_open.
# inverse_weights_map tells us exactly which tensors to load from each file —
# no entire-file loads, no runtime discovery.
tensors: dict[str, torch.Tensor] = {}
for source_file, tensor_names in inverse_weights_map.items():
with safe_open(source_file, framework="pt", device="cpu") as f:
available = set(f.keys())
# Load all tensors if tensor_names is None or empty
names_to_load = tensor_names if tensor_names else list(available)
for name in names_to_load:
if name in available:
tensors[name] = f.get_tensor(name)
tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device)

if converter is not None:
converter.process(tensors)
Expand Down Expand Up @@ -247,3 +216,40 @@ def process_file_microscale_scheme(
total_size = sum(t.nbytes for t in tensors.values())
weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
return total_size, weight_map


# TODO brian-dellabetta (#2491): move to compressed-tensors.utils.safetensors_load
def _load_tensors_from_inverse_weights_map(
    inverse_weights_map: dict[str, list[str] | None],
    device: str | torch.device,
) -> dict[str, torch.Tensor]:
    """
    Given an inverse_weights_map, which is a dictionary of file name to list of
    tensor names, load up all listed tensor names using partial reads, so
    unlisted tensors are never materialized.

    :param inverse_weights_map: mapping of resolved source file path ->
        list of tensor names to load from that file. Precomputed by
        build_microscale_inverse_weights_map() in the job-building phase.
        If the list is None or empty, all tensors in that file are loaded.
        Example: {"/path/shard0.safetensors": ["q_proj.weight"],
        "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
    :param device: tensors will be loaded onto this device.
    :raises ValueError: if a requested tensor name is not present in its file

    :returns: mapping of tensor name to actual tensor loaded from safetensors file
        Example: {"q_proj.weight": torch.Tensor(...), "k_proj.weight": torch.Tensor(...)}
    """
    tensors: dict[str, torch.Tensor] = {}
    for source_file, tensor_names in inverse_weights_map.items():
        with safe_open(source_file, framework="pt", device=str(device)) as f:
            keys = f.keys()
            # set for O(1) membership checks instead of repeated list scans
            available = set(keys)
            # None/empty means "load every tensor in this file"; keep `keys`
            # (not the set) so insertion order matches the file's key order
            names_to_load = tensor_names or keys
            for tensor_name in names_to_load:
                if tensor_name not in available:
                    raise ValueError(
                        f"Expected to find tensor {tensor_name} in "
                        f"{source_file}, but tensor was not found."
                    )
                tensors[tensor_name] = f.get_tensor(tensor_name)
    return tensors
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_validate_file_raises_for_non_2d_linear_weight(tmp_path):
save_file({"model.layers.0.mlp.down_proj.weight": torch.ones(128)}, str(path))

with pytest.raises(ValueError, match="model.layers.0.mlp.down_proj.weight"):
validate_file(path, None, _get_block_scheme(), [], None)
validate_file({str(path): []}, None, _get_block_scheme(), [], "cpu")


def test_validate_file_does_not_raise_for_block_incompatible_shape(tmp_path):
Expand All @@ -35,4 +35,4 @@ def test_validate_file_does_not_raise_for_block_incompatible_shape(tmp_path):
str(path),
)

validate_file(path, None, _get_block_scheme(), [], None)
validate_file({str(path): []}, None, _get_block_scheme(), [], "cpu")
Loading
Loading