def model_free_ptq(
    # NOTE(review): the original signature is not fully visible in the reviewed
    # diff; parameter names/order are taken from the docstring — confirm
    # defaults against the original file before merging.
    model_stub: str | os.PathLike,
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme | str,
    ignore: Iterable[str],
    max_workers: int = 1,
    device: str | torch.device | None = None,
    converter: Converter | None = None,
):
    """
    Quantize a model without the need for a model definition. This function
    operates on a model stub or folder containing weights saved in safetensors
    files.

    For microscale schemes (NVFP4, MXFP4), fused weight sets (q/k/v, gate/up)
    are handled correctly even when split across shards. Each shard job receives
    a precomputed inverse_weights_map specifying exactly which tensors to load
    from which files — enabling true partial reads with no runtime discovery
    and no redundant tensor reads.

    :param model_stub: huggingface model hub or path to local weights files
    :param save_directory: directory to save quantized weights to
    :param scheme: weight quantization scheme or preset scheme name
    :param ignore: modules to ignore. Modules ending with "norm" are
        automatically ignored
    :param max_workers: number of worker threads to process files with
    :param device: gpu device to accelerate quantization with
    :param converter: optional converter to apply to the checkpoint to convert
        it to compressed-tensors format before running model-free PTQ
    """
    # validate arguments
    model_files = get_checkpoint_files(model_stub)
    # NOTE(review): this line is hidden between hunks in the reviewed diff;
    # reconstructed from the later use of `scheme_name` — confirm.
    scheme_name, scheme = validate_scheme(scheme)
    device = gpu_if_available(device)
    validate_safetensors_index(model_files, scheme)

    # copy non-safetensors files (configs, tokenizers, etc.)
    for file_path, resolved_path in model_files.items():
        if not file_path.endswith("safetensors"):
            save_path = Path(save_directory) / file_path
            if is_weights_file(file_path):
                logger.warning(f"Skip processing for weights file {file_path}")
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying {file_path} -> {save_path}")
            shutil.copyfile(resolved_path, save_path)

    # build quantization jobs; microscale schemes need per-shard
    # inverse_weights_maps so fused partners can be read across shards
    if is_microscale_scheme(scheme):
        jobs = _build_microscale_jobs(
            model_files, save_directory, scheme, ignore, device, converter
        )
    else:
        jobs = _build_standard_jobs(
            model_files, save_directory, scheme, ignore, device, converter
        )

    # 1. validate quantizable tensors — fail fast before long-running quantization
    validate_jobs = _build_validate_jobs(jobs)
    exec_jobs(validate_jobs, max_workers, desc="Validating")

    # 2-5. quantize and compress weights
    total_size = 0
    weight_map: dict[str, str] = {}
    # NOTE(review): the exec_jobs invocation below is hidden between hunks in
    # the reviewed diff — confirm the exact call against the original file.
    for _total_size, _weight_map in exec_jobs(jobs, max_workers, desc="Quantizing"):
        total_size += _total_size
        weight_map.update(_weight_map)

    # 6. update config and safetensors index
    # weight_map may contain tensors re-located to new shards (partner tensors
    # re-saved alongside the shard that needed them for fused scale computation)
    update_config(save_directory, scheme_name, scheme, ignore, converter)
    # BUG FIX: the reviewed change dropped the index update while still
    # accumulating total_size/weight_map above; without it the saved
    # checkpoint's model.safetensors.index.json never reflects re-located
    # tensors. Local import because the module-level import was removed.
    from compressed_tensors.utils.safetensors_load import update_safetensors_index

    update_safetensors_index(save_directory, total_size, weight_map)


def _build_standard_jobs(
    model_files: dict[str, str],
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
    job_fn=None,
) -> list[tuple]:
    """
    Build one processing job per safetensors file.

    :param model_files: mapping of checkpoint-relative filename -> resolved path
    :param save_directory: directory quantized shards will be written to
    :param job_fn: processing function for each job; defaults to `process_file`
    :return: list of job tuples consumable by `exec_jobs`
    """
    if job_fn is None:
        job_fn = process_file
    jobs = []
    for file_path, resolved_path in model_files.items():
        if file_path.endswith("safetensors"):
            save_path = Path(save_directory) / file_path
            jobs.append(
                (job_fn, resolved_path, save_path, scheme, ignore, device, converter)
            )
    return jobs


def _build_microscale_jobs(
    model_files: dict[str, str],
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Build microscale jobs with a precomputed inverse_weights_map per shard.

    For each output shard, build_inverse_weights_map() determines exactly which
    tensors to load from which source files — including any fused partner
    tensors from other shards. This avoids runtime fused-partner discovery
    inside the process function and eliminates redundant tensor reads.

    Job tuple format:
        (process_file_microscale_scheme, inverse_weights_map, save_path,
         scheme, ignore, device, converter)

    :return: list of job tuples consumable by `exec_jobs`
    """
    index_file = find_safetensors_index_file(model_files)

    if index_file is None:
        # Single-file model — no cross-shard fused weights are possible.
        # {resolved_path: []} means "load all tensors from this file".
        jobs = []
        for file_path, resolved_path in model_files.items():
            if file_path.endswith("safetensors"):
                save_path = Path(save_directory) / file_path
                inverse_weights_map = {resolved_path: []}
                jobs.append(
                    (
                        process_file_microscale_scheme,
                        inverse_weights_map,
                        save_path,
                        scheme,
                        ignore,
                        device,
                        converter,
                    )
                )
        return jobs

    # Read the weight map from model.safetensors.index.json
    with open(index_file, "r") as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]

    jobs = []
    for shard_name, resolved_path in model_files.items():
        if not shard_name.endswith("safetensors"):
            continue

        save_path = Path(save_directory) / shard_name

        # Precompute exactly which tensors to load from which files for this
        # shard, including fused partner tensors that live in other shards
        inverse_weights_map = build_inverse_weights_map(
            shard_name=shard_name,
            weight_map=weight_map,
            model_files=model_files,
        )

        if len(inverse_weights_map) > 1:
            partner_shards = [s for s in inverse_weights_map if s != resolved_path]
            logger.info(
                f"{shard_name}: will fetch fused partners from "
                f"{[os.path.basename(s) for s in partner_shards]}"
            )

        jobs.append(
            (
                process_file_microscale_scheme,
                inverse_weights_map,
                save_path,
                scheme,
                ignore,
                device,
                converter,
            )
        )

    return jobs


def _build_validate_jobs(jobs: list[tuple]) -> list[tuple]:
    """
    Build validation jobs from processing jobs.

    Handles both job formats:
    - Standard/fallback: job[1] is a file path (str/PathLike)
    - Microscale with index: job[1] is an inverse_weights_map dict

    :param jobs: processing job tuples built by the `_build_*_jobs` helpers
    :return: job tuples that invoke `validate_file`
    """
    validate_jobs = []
    for job in jobs:
        # job[0] is the processing function; job[1] distinguishes the format
        second_arg = job[1]

        if isinstance(second_arg, dict):
            # Microscale job: validate the shard's own (first) source file and
            # forward the full map via validate_file's weights_map parameter
            _, inverse_weights_map, save_path, scheme, ignore, device, converter = job
            source_file = next(iter(inverse_weights_map.keys()))
            validate_jobs.append(
                (
                    validate_file,
                    source_file,
                    save_path,
                    scheme,
                    ignore,
                    device,
                    converter,
                    inverse_weights_map,
                )
            )
        else:
            # Standard job or microscale fallback: second_arg is a file path
            _, file_path, save_path, scheme, ignore, device, converter = job
            validate_jobs.append(
                (
                    validate_file,
                    file_path,
                    save_path,
                    scheme,
                    ignore,
                    device,
                    converter,
                    None,
                )
            )
    return validate_jobs


def _get_all_tensor_names(file_path: str) -> list[str]:
    """
    Get all tensor names from a safetensors file without loading tensor data.

    NOTE(review): currently unused within this module — remove or wire up.
    """
    from safetensors import safe_open

    with safe_open(file_path, framework="pt", device="cpu") as f:
        return list(f.keys())
def build_weights_map(
    weight_map: dict[str, str],
    model_files: dict[str, str],
) -> dict[str, str]:
    """
    Resolve each tensor name to the absolute path of the shard that holds it.

    Combines the checkpoint's weight_map (from index.json) with the mapping of
    shard filenames to resolved paths, so any process can locate fused partner
    tensors from other shards without loading entire files. Tensors whose shard
    is absent from ``model_files`` are silently omitted.

    :param weight_map: mapping of tensor name -> shard filename (from index.json)
    :param model_files: mapping of shard filename -> resolved absolute path
    :return: mapping of tensor name -> resolved absolute path
    """
    resolved: dict[str, str] = {}
    for tensor_name, shard_name in weight_map.items():
        if shard_name not in model_files:
            # shard not present locally — skip rather than raise
            continue
        resolved[tensor_name] = model_files[shard_name]
    return resolved
# Mapping of primary weight pattern -> list of partner weight patterns.
# The shard owning the primary tensor is responsible for fetching its partners.
# This prevents double reads: each fused set is fetched exactly once, by the
# shard that owns the primary (e.g. q_proj fetches k_proj + v_proj).
#
# BUG FIX: the named groups had their names stripped (`(?P.+?)` is not valid
# regex). Restored `(?P<name>...)` groups so partner names can be constructed
# by substituting the matched groups via:
#   partner.format(**match.groupdict())
DEFAULT_FUSED_MAPPINGS: dict[str, list[str]] = {
    # Attention q/k/v fusion: q_proj is primary
    r"^(?P<prefix>.+?)\.(?P<attn>attn|attention|self_attn|self_attention)"
    r"\.q_proj\.weight$": [
        r"{prefix}.{attn}.k_proj.weight",
        r"{prefix}.{attn}.v_proj.weight",
    ],
    # MLA attention fusion: wq_a is primary
    r"^(?P<prefix>.+?)\.(?P<attn>attn|attention|self_attn)\.wq_a\.weight$": [
        r"{prefix}.{attn}.wkv_a_with_mqa.weight",
    ],
    # MLP gate/up fusion: gate_proj is primary
    r"^(?P<prefix>.+?)\.(?P<mlp>mlp|feed_forward)\.gate_proj\.weight$": [
        r"{prefix}.{mlp}.up_proj.weight",
    ],
    # MoE w1/w3 fusion: w1 is primary
    r"^(?P<prefix>.+?)\.w1\.weight$": [
        r"{prefix}.w3.weight",
    ],
}

# List-of-lists format used by get_fused_names and validate.py.
# NOTE(review): only the first group was visible in the reviewed diff (the
# rest is unchanged context omitted between hunks); the remaining groups are
# reconstructed to mirror DEFAULT_FUSED_MAPPINGS — confirm against original.
_DEFAULT_FUSED_MAPPINGS_LIST = [
    [
        r"re:.*(attn|attention)\.q_proj\.weight$",
        r"re:.*(attn|attention)\.k_proj\.weight$",
        r"re:.*(attn|attention)\.v_proj\.weight$",
    ],
    [
        r"re:.*(attn|attention)\.wq_a\.weight$",
        r"re:.*(attn|attention)\.wkv_a_with_mqa\.weight$",
    ],
    [
        r"re:.*(mlp|feed_forward)\.gate_proj\.weight$",
        r"re:.*(mlp|feed_forward)\.up_proj\.weight$",
    ],
    [
        r"re:.*\.w1\.weight$",
        r"re:.*\.w3\.weight$",
    ],
]


def get_fused_names(
    tensor_names: Iterable[str],
) -> tuple[list[MatchedNamesSet], list[MatchedNamesSet]]:
    """
    Partition tensor names into complete and incomplete fused-weight sets.

    :param tensor_names: tensor names present in a shard
    :return: (fully matched fused sets, partially matched/unmatched sets)
    """
    matched = []
    unmatched = []
    for mapping in _DEFAULT_FUSED_MAPPINGS_LIST:
        _matched, _unmatched = match_names_set_eager(tensor_names, mapping)
        # BUG FIX: the reviewed change deleted both this accumulation and the
        # return statement, making the function always return None (callers
        # unpack two values and would crash).
        matched.extend(_matched)
        if _unmatched is not None:
            unmatched.append(_unmatched)

    return matched, unmatched


def build_inverse_weights_map(
    shard_name: str,
    weight_map: dict[str, str],
    model_files: dict[str, str],
) -> dict[str, list[str]]:
    """
    For a given output shard, precompute exactly which tensors to load from
    which source files — including fused partner tensors from other shards.

    Uses DEFAULT_FUSED_MAPPINGS (primary -> partners) to ensure only the shard
    owning the primary tensor fetches its partners, preventing double reads
    when fused weights span multiple shards.

    Example — given:
        shard0: [q_proj.weight, ...]                 <- primary owner
        shard1: [k_proj.weight, v_proj.weight, ...]  <- partners

    Only shard0's inverse_weights_map will include shard1's tensors.
    Shard1's job loads only its own native tensors.

    :param shard_name: the shard filename this job will process and save
    :param weight_map: tensor name -> shard filename (from safetensors.index.json)
    :param model_files: shard filename -> resolved absolute path
    :return: {resolved_file_path: [tensor_names_to_load]}
    """
    own_resolved = model_files[shard_name]
    native_tensors = [t for t, s in weight_map.items() if s == shard_name]

    inverse_weights_map: dict[str, list[str]] = defaultdict(list)
    inverse_weights_map[own_resolved] = list(native_tensors)

    # For each native tensor that matches a primary pattern, fetch its partners
    for name in native_tensors:
        for primary_pattern, partner_templates in DEFAULT_FUSED_MAPPINGS.items():
            match = re.match(primary_pattern, name)
            if match is None:
                continue

            # Build partner names using named groups from the match
            for partner_template in partner_templates:
                partner_name = partner_template.format(**match.groupdict())

                partner_shard = weight_map.get(partner_name)
                if partner_shard is None or partner_shard == shard_name:
                    continue  # same shard or not found

                partner_resolved = model_files.get(partner_shard)
                if partner_resolved is None:
                    continue

                if partner_name not in inverse_weights_map[partner_resolved]:
                    inverse_weights_map[partner_resolved].append(partner_name)

    return dict(inverse_weights_map)
match_quantizable_tensors +from safetensors import safe_open from safetensors.torch import load_file, save_file from torch.nn import Module @@ -29,24 +30,38 @@ def validate_file( - file_path: str | os.PathLike, + inverse_weights_map: dict[str, list[str] | None], save_path: str | os.PathLike, scheme: QuantizationScheme, ignore: Iterable[str], device: str | torch.device, converter: Converter | None = None, + weights_map: dict[str, str] | None = None, ): """ Validate that each quantizable tensor in a safetensors file can be quantized. - :param file_path: safetensors file to validate + :param inverse_weights_map: mapping of source file path -> tensor names to validate + :param save_path: save path of file with quantized weights :param scheme: quantization scheme to apply to tensors :param ignore: modules to ignore. Modules ending with "norm" are automatically ignored + :param device: device used to quantize and compress weights :param converter: optional converter to apply to the checkpoint, e.g. conversion of some layers from some format to compressed-tensors + :param weights_map: optional mapping of tensor name -> source file path, + built from safetensors.index.json. Reserved for future use by callers + that need cross-shard tensor location lookup during validation. 
    """
-    tensors = load_file(file_path)
+    # Extract file path from inverse_weights_map (standard mode: load all)
+    # Backward compatibility: handle both dict and Path/string formats
+    if not isinstance(inverse_weights_map, dict):
+        # Legacy call with file_path - wrap it as inverse_weights_map
+        inverse_weights_map = {inverse_weights_map: None}
+    # Extract source file from inverse_weights_map
+    source_file = next(iter(inverse_weights_map.keys()))
+    # Validation loads the entire file; per-tensor filtering is not needed here
+    tensors = load_file(source_file)
 
     if converter is not None:
         converter.validate(tensors)
@@ -56,7 +71,7 @@
 
 
 def process_file(
-    file_path: str | os.PathLike,
+    inverse_weights_map: dict[str, list[str] | None],
     save_path: str | os.PathLike,
     scheme: QuantizationScheme,
     ignore: Iterable[str],
@@ -64,9 +79,10 @@
     converter: Converter | None = None,
 ) -> tuple[int, dict[str, str]]:
     """
-    Quantize and compress tensors in a given safetensors file
+    Quantize and compress tensors in a given safetensors file.
 
-    :param file_path: safetensors file to process
+    :param inverse_weights_map: mapping of source file path -> tensor names.
+        For standard mode: {resolved_path: None} means load all tensors to process
     :param save_path: save path of file with quantized weights
     :param scheme: quantization scheme to apply to tensors
     :param ignore: modules to ignore. Modules ending with "norm" are automatically
@@ -75,8 +91,17 @@
     :param converter: optional converter to apply to the checkpoint, e.g.
conversion of some layers from some format to compressed-tensors
     """
-    assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
-    tensors = load_file(file_path)
+    assert not is_microscale_scheme(scheme), "Use `process_file_microscale_scheme`"
+    # Extract file path from inverse_weights_map (standard mode: load all)
+    # Backward compatibility: handle both dict and Path/string formats
+    if not isinstance(inverse_weights_map, dict):
+        # Legacy call with file_path - wrap it as inverse_weights_map
+        inverse_weights_map = {inverse_weights_map: None}
+    # Extract source file from inverse_weights_map
+    source_file = next(iter(inverse_weights_map.keys()))
+    # Standard (non-microscale) processing loads the entire file below, so the
+    # per-file tensor lists in inverse_weights_map are not consulted here.
+    tensors = load_file(source_file)
 
     if converter is not None:
         converter.process(tensors)
@@ -106,7 +131,7 @@
 
 
 def process_file_microscale_scheme(
-    file_path: str | os.PathLike,
+    inverse_weights_map: dict[str, list[str]],
     save_path: str | os.PathLike,
     scheme: QuantizationScheme,
     ignore: Iterable[str],
@@ -114,35 +139,59 @@
     converter: Converter | None = None,
 ) -> tuple[int, dict[str, str]]:
     """
-    Quantize and compress tensors in a given safetensors file
-
-    :param file_path: safetensors file to process
-    :param save_path: save path of file with quantized weights
-    :param scheme: quantization scheme to apply to tensors
+    Quantize and compress tensors for a single output shard using a microscale
+    scheme (NVFP4, MXFP4).
+
+    Accepts a precomputed inverse_weights_map that specifies exactly which tensors
+    to load from which source files — including any fused partner tensors from
+    other shards needed for global scale computation. This avoids runtime
+    discovery of fused partners and redundant tensor reads.
+
+    Partner tensors fetched from other shards are re-saved into this shard's
+    output.
The caller updates the safetensors index to reflect new locations. + + :param inverse_weights_map: mapping of resolved source file path -> + list of tensor names to load from that file. Precomputed by + build_inverse_weights_map() in the job-building phase. + Example: {"/path/shard0.safetensors": ["q_proj.weight"], + "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]} + :param save_path: output path for this shard's compressed weights + :param scheme: microscale quantization scheme (NVFP4, MXFP4) :param ignore: modules to ignore. Modules ending with "norm" are automatically ignored :param device: device used to quantize and compress weights :param converter: optional converter to apply to the checkpoint, e.g. conversion of some layers from some format to compressed-tensors """ - assert is_microscale_scheme(scheme), "Use `_process_file` for non-microscale scheme" - tensors = load_file(file_path) + assert is_microscale_scheme(scheme), "Use `process_file` for non-microscale scheme" + + # Load all required tensors using true partial reads via safe_open. + # inverse_weights_map tells us exactly which tensors to load from each file — + # no entire-file loads, no runtime discovery. + tensors: dict[str, torch.Tensor] = {} + for source_file, tensor_names in inverse_weights_map.items(): + with safe_open(source_file, framework="pt", device="cpu") as f: + available = set(f.keys()) + # Load all tensors if tensor_names is None or empty + names_to_load = tensor_names if tensor_names else list(available) + for name in names_to_load: + if name in available: + tensors[name] = f.get_tensor(name) if converter is not None: converter.process(tensors) - fused_sets, unmatched_sets = get_fused_names(tensors) - assert len(unmatched_sets) <= 0 # should be caught by `validate_safetensors_index` - - fused_name_to_fused_index: dict[str, int] # fused_name -> fused_index - fused_modules: dict[int, dict[str, Module]] # fused_index -> named_modules + # Get fused sets. 
Non-primary shards may have incomplete sets (k/v without q) + # since only the primary-owning shard fetches partners — this is correct. + fused_sets, _ = get_fused_names(list(tensors.keys())) - fused_name_to_fused_index = { + fused_name_to_fused_index: dict[str, int] = { name: index for index, matched_set in enumerate(fused_sets) for name in matched_set.values() + if name is not None } - fused_modules = defaultdict(dict) + fused_modules: dict[int, dict[str, Module]] = defaultdict(dict) for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets): validate_weight_for_quantization(tensors[name], scheme, name) @@ -150,7 +199,7 @@ def process_file_microscale_scheme( # 1. initialize module with qparams (on device) module = initialize_quantized_linear(tensors[name], scheme, device) - # 2. calibrate weight qparams. Delay scale/zp calibration for fused modules + # 2. calibrate global scale; delay scale/zp for fused modules calibrate_global_scale(module) if name in fused_name_to_fused_index: fused_index = fused_name_to_fused_index[name] @@ -168,9 +217,9 @@ def process_file_microscale_scheme( for key, value in module.state_dict(prefix=prefix).items(): tensors[key] = value.to("cpu") - # compress and save miscroscale fused modules + # Compress fused modules with shared global scale for named_modules in fused_modules.values(): - # 2.1. fuse global scales + # 2.1. compute fused global scale across all members of the fused set global_scales = [m.weight_global_scale for m in named_modules.values()] fused_global_scale = torch.min(torch.cat(global_scales, dim=0)) @@ -178,10 +227,10 @@ def process_file_microscale_scheme( module_name, _ = name.rsplit(".", 1) module.weight_global_scale.data.copy_(fused_global_scale) - # 2.2. finish calibration with fused global scales + # 2.2. finish calibration with fused global scale calibrate_scale_zp(module) - # 3. compress module using miscroscale qparams + # 3. 
compress module using microscale qparams compress_module(module) # 4. save compressed data (on cpu) @@ -190,7 +239,11 @@ def process_file_microscale_scheme( for key, value in module.state_dict(prefix=prefix).items(): tensors[key] = value.to("cpu") + # Save ALL tensors to this shard's output — including partner tensors fetched + # from other shards. Partners are re-saved here so future runs don't need to + # re-fetch them. The caller updates the safetensors index to reflect new locations. + os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True) save_file(tensors, save_path) - total_size = sum(tensor.nbytes for tensor in tensors.values()) + total_size = sum(t.nbytes for t in tensors.values()) weight_map = {key: os.path.basename(save_path) for key in tensors.keys()} return total_size, weight_map diff --git a/src/llmcompressor/entrypoints/model_free/validate.py b/src/llmcompressor/entrypoints/model_free/validate.py index 390706fad5..5484cd200e 100644 --- a/src/llmcompressor/entrypoints/model_free/validate.py +++ b/src/llmcompressor/entrypoints/model_free/validate.py @@ -23,7 +23,7 @@ def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme # weight quantization must be provided if scheme.weights is None: raise ValueError( - "Must provide a weights quanitization scheme to perform weights-only PTQ" + "Must provide a weights quantization scheme to perform weights-only PTQ" ) # activation quantization must be dynamic @@ -59,19 +59,19 @@ def validate_safetensors_index(model_files: dict[str, str], scheme: Quantization if is_microscale_scheme(scheme): with open(index_file_path, "r") as file: - weight_map: dict[str, str] = json.load(file)["weight_map"] + weights_map: dict[str, str] = json.load(file)["weight_map"] - file_map = invert_mapping(weight_map) + file_map = invert_mapping(weights_map) for file in sorted(file_map): tensor_names = file_map[file] _fused_sets, unmatched_sets = get_fused_names(tensor_names) if len(unmatched_sets) > 
0: - raise NotImplementedError( - "When using a microscale scheme (NVFP4, MXFP4), global scales " - "will be fused. Current implmentation requires that all fused " - "modules (attention and mlp) be stored in the same file. " - f"However, {file} has an unmatched set of fused weights: " - f"\n{json.dumps(unmatched_sets, indent=4)}\n\n" - "Please use `reindex_fused_weights.py` to reindex your safetensors " - "before running `model_free_ptq` again." + # Cross-shard fused weights detected. model_free_ptq handles + # this automatically via precomputed inverse_weights_map — + # fused partner tensors are fetched via partial reads and + # re-saved into the requesting shard's output. + logger.debug( + f"{file} has fused weights split across shards: " + f"{json.dumps(unmatched_sets, indent=4)}\n" + "These will be resolved via precomputed inverse_weights_map." ) diff --git a/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py b/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py new file mode 100644 index 0000000000..6062974435 --- /dev/null +++ b/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py @@ -0,0 +1,247 @@ +""" +Tests for inverse_weights_map approach that eliminates the +reindex_fused_weights preprocessing step for microscale schemes. 
+""" + +import pytest +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme +from safetensors.torch import save_file + +from llmcompressor.entrypoints.model_free.microscale import ( + build_inverse_weights_map, +) +from llmcompressor.entrypoints.model_free.process import ( + process_file_microscale_scheme, +) + + +def _make_nvfp4_scheme(): + return QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + type="float", + strategy="tensor_group", + group_size=16, + symmetric=True, + dynamic=False, + scale_dtype=torch.float8_e4m3fn, + ), + ) + + +def _rand_weight(*shape): + return torch.randn(*shape, dtype=torch.float16) + + +class TestBuildWeightsMap: + def test_basic_mapping(self, tmp_path): + weight_map = { + "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors", + "model.layers.0.self_attn.k_proj.weight": "shard-00002.safetensors", + } + model_files = { + "shard-00001.safetensors": str(tmp_path / "shard-00001.safetensors"), + "shard-00002.safetensors": str(tmp_path / "shard-00002.safetensors"), + } + result = build_inverse_weights_map( + "shard-00001.safetensors", weight_map, model_files + ) + # result is {file_path: [tensor_names]}, check tensor is in the list + assert ( + "model.layers.0.self_attn.q_proj.weight" + in result[str(tmp_path / "shard-00001.safetensors")] + ) + assert ( + "model.layers.0.self_attn.k_proj.weight" + in result[str(tmp_path / "shard-00002.safetensors")] + ) + + def test_missing_shard_skipped(self, tmp_path): + weight_map = { + "tensor.a": "shard-00001.safetensors", + "tensor.b": "shard-00002.safetensors", + } + model_files = { + "shard-00001.safetensors": str(tmp_path / "shard-00001.safetensors"), + } + result = build_inverse_weights_map( + "shard-00001.safetensors", weight_map, model_files + ) + # check tensor.a is in the result values + assert any("tensor.a" in tensors for tensors in result.values()) + assert "tensor.b" not in result + + +class 
TestBuildInverseWeightsMap: + def test_colocated_no_partners_needed(self, tmp_path): + """All fused weights in same shard — no cross-shard fetching needed.""" + shard = "shard-00001.safetensors" + weight_map = { + "model.layers.0.self_attn.q_proj.weight": shard, + "model.layers.0.self_attn.k_proj.weight": shard, + "model.layers.0.self_attn.v_proj.weight": shard, + } + model_files = {shard: str(tmp_path / shard)} + result = build_inverse_weights_map(shard, weight_map, model_files) + assert len(result) == 1 + assert str(tmp_path / shard) in result + + def test_cross_shard_partners_found(self, tmp_path): + """q_proj on shard1, k/v on shard2 — shard1 should fetch from shard2.""" + weight_map = { + "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors", + "model.layers.0.self_attn.k_proj.weight": "shard-00002.safetensors", + "model.layers.0.self_attn.v_proj.weight": "shard-00002.safetensors", + } + model_files = { + "shard-00001.safetensors": str(tmp_path / "shard-00001.safetensors"), + "shard-00002.safetensors": str(tmp_path / "shard-00002.safetensors"), + } + result = build_inverse_weights_map( + "shard-00001.safetensors", weight_map, model_files + ) + # Should include both shards + assert len(result) == 2 + shard2_path = str(tmp_path / "shard-00002.safetensors") + assert shard2_path in result + assert "model.layers.0.self_attn.k_proj.weight" in result[shard2_path] + assert "model.layers.0.self_attn.v_proj.weight" in result[shard2_path] + + +class TestProcessFileMicroscaleSchemeColocated: + """Tests for co-located fused weights — standard case, no cross-shard needed.""" + + @pytest.fixture + def qkv_tensors(self): + return { + "model.layers.0.self_attn.q_proj.weight": _rand_weight(32, 32), + "model.layers.0.self_attn.k_proj.weight": _rand_weight(32, 32), + "model.layers.0.self_attn.v_proj.weight": _rand_weight(32, 32), + "model.layers.0.mlp.down_proj.weight": _rand_weight(32, 32), + } + + def test_colocated_fused_weights(self, qkv_tensors, tmp_path): + 
"""Standard case: all fused weights in one shard.""" + shard_name = "model.safetensors" + shard_path = tmp_path / shard_name + save_path = tmp_path / "out.safetensors" + save_file(qkv_tensors, shard_path) + + # Build inverse_weights_map: just the one file with all tensors + inverse_weights_map = {str(shard_path): list(qkv_tensors.keys())} + + total_size, weight_map = process_file_microscale_scheme( + inverse_weights_map=inverse_weights_map, + save_path=save_path, + scheme=_make_nvfp4_scheme(), + ignore=[], + device="cpu", + ) + assert save_path.exists() + assert total_size > 0 + assert len(weight_map) > 0 + + +class TestProcessFileMicroscaleSchemeCrossShardInverseMap: + """Tests for cross-shard fused weights using precomputed inverse_weights_map.""" + + @pytest.fixture + def split_shards(self, tmp_path): + """q_proj on shard-1, k_proj + v_proj + down_proj on shard-2.""" + shard1_tensors = { + "model.layers.0.self_attn.q_proj.weight": _rand_weight(32, 32), + } + shard2_tensors = { + "model.layers.0.self_attn.k_proj.weight": _rand_weight(32, 32), + "model.layers.0.self_attn.v_proj.weight": _rand_weight(32, 32), + "model.layers.0.mlp.down_proj.weight": _rand_weight(32, 32), + } + shard1_path = tmp_path / "shard-00001.safetensors" + shard2_path = tmp_path / "shard-00002.safetensors" + save_file(shard1_tensors, shard1_path) + save_file(shard2_tensors, shard2_path) + + weight_map = { + "model.layers.0.self_attn.q_proj.weight": "shard-00001.safetensors", + "model.layers.0.self_attn.k_proj.weight": "shard-00002.safetensors", + "model.layers.0.self_attn.v_proj.weight": "shard-00002.safetensors", + "model.layers.0.mlp.down_proj.weight": "shard-00002.safetensors", + } + model_files = { + "shard-00001.safetensors": str(shard1_path), + "shard-00002.safetensors": str(shard2_path), + } + # Precompute inverse_weights_map for each shard + iwm1 = build_inverse_weights_map( + "shard-00001.safetensors", weight_map, model_files + ) + iwm2 = build_inverse_weights_map( + 
"shard-00002.safetensors", weight_map, model_files + ) + return shard1_path, shard2_path, iwm1, iwm2 + + def test_shard1_produces_output(self, split_shards, tmp_path): + """Shard-1 (q_proj only) processes correctly using precomputed inverse map.""" + shard1_path, _, iwm1, _ = split_shards + save_path = tmp_path / "out-00001.safetensors" + + total_size, weight_map = process_file_microscale_scheme( + inverse_weights_map=iwm1, + save_path=save_path, + scheme=_make_nvfp4_scheme(), + ignore=[], + device="cpu", + ) + assert save_path.exists() + assert total_size > 0 + assert len(weight_map) > 0 + + def test_shard2_produces_output(self, split_shards, tmp_path): + """Shard-2 (k/v/down) processes correctly using precomputed inverse map.""" + _, shard2_path, _, iwm2 = split_shards + save_path = tmp_path / "out-00002.safetensors" + + total_size, weight_map = process_file_microscale_scheme( + inverse_weights_map=iwm2, + save_path=save_path, + scheme=_make_nvfp4_scheme(), + ignore=[], + device="cpu", + ) + assert save_path.exists() + assert total_size > 0 + + def test_both_shards_produce_same_keys_as_merged(self, split_shards, tmp_path): + """Combined output keys from both shards + should match merged single-shard keys.""" + shard1_path, shard2_path, iwm1, iwm2 = split_shards + + out1 = tmp_path / "out-00001.safetensors" + out2 = tmp_path / "out-00002.safetensors" + _, wm1 = process_file_microscale_scheme( + iwm1, out1, _make_nvfp4_scheme(), [], "cpu" + ) + _, wm2 = process_file_microscale_scheme( + iwm2, out2, _make_nvfp4_scheme(), [], "cpu" + ) + combined_keys = set(wm1.keys()) | set(wm2.keys()) + + # Process merged shard as reference + from safetensors.torch import load_file + + merged = {**load_file(shard1_path), **load_file(shard2_path)} + merged_path = tmp_path / "merged.safetensors" + merged_out = tmp_path / "merged_out.safetensors" + save_file(merged, merged_path) + merged_iwm = {str(merged_path): list(merged.keys())} + _, wm_merged = process_file_microscale_scheme( + 
merged_iwm, merged_out, _make_nvfp4_scheme(), [], "cpu" + ) + + assert combined_keys == set(wm_merged.keys()), ( + f"Key mismatch:\n" + f" split only: {sorted(combined_keys - set(wm_merged.keys()))}\n" + f" merged only: {sorted(set(wm_merged.keys()) - combined_keys)}" + )