diff --git a/src/llmcompressor/entrypoints/model_free/__init__.py b/src/llmcompressor/entrypoints/model_free/__init__.py index 6ef1051ba4..616065bafb 100644 --- a/src/llmcompressor/entrypoints/model_free/__init__.py +++ b/src/llmcompressor/entrypoints/model_free/__init__.py @@ -21,7 +21,7 @@ gpu_if_available, ) from llmcompressor.entrypoints.model_free.microscale import ( - build_inverse_weights_map, + build_microscale_inverse_weights_map, is_microscale_scheme, ) from llmcompressor.entrypoints.model_free.process import ( @@ -87,17 +87,10 @@ def model_free_ptq( shutil.copyfile(resolved_path, save_path) # build quantization jobs - if is_microscale_scheme(scheme): - jobs = _build_microscale_jobs( - model_files, save_directory, scheme, ignore, device, converter - ) - else: - jobs = _build_standard_jobs( - model_files, save_directory, scheme, ignore, device, converter - ) + jobs = _build_jobs(model_files, save_directory, scheme, ignore, device, converter) # 1. validate quantizable tensors — fail fast before long-running quantization - validate_jobs = _build_validate_jobs(jobs) + validate_jobs = [(validate_file, *job[1:]) for job in jobs] exec_jobs(validate_jobs, max_workers, desc="Validating") # 2-5. 
quantize and compress weights @@ -114,29 +107,7 @@ def model_free_ptq( update_config(save_directory, scheme_name, scheme, ignore, converter) -def _build_standard_jobs( - model_files: dict[str, str], - save_directory: str | os.PathLike, - scheme: QuantizationScheme, - ignore: Iterable[str], - device: torch.device, - converter: Converter | None, - job_fn=None, -) -> list[tuple]: - """Build one job per safetensors file using the given processing function.""" - if job_fn is None: - job_fn = process_file - jobs = [] - for file_path, resolved_path in model_files.items(): - if file_path.endswith("safetensors"): - save_path = Path(save_directory) / file_path - jobs.append( - (job_fn, resolved_path, save_path, scheme, ignore, device, converter) - ) - return jobs - - -def _build_microscale_jobs( +def _build_jobs( model_files: dict[str, str], save_directory: str | os.PathLike, scheme: QuantizationScheme, @@ -152,10 +123,17 @@ def _build_microscale_jobs( from other shards. This avoids runtime fused-partner discovery inside the process function and eliminates redundant tensor reads. 
- Job tuple format: - (process_file_microscale_scheme, inverse_weights_map, save_path, - scheme, ignore, device, converter) + :returns: list of job tuples + (job_fn, inverse_weights_map, save_path, scheme, ignore, device, converter) """ + if is_microscale_scheme(scheme): + job_fn = process_file_microscale_scheme + build_inverse_weights_map = build_microscale_inverse_weights_map + else: + job_fn = process_file + # TODO brian-dellabetta (#2491): update here in follow-up PR based on converter + build_inverse_weights_map = None + index_file = find_safetensors_index_file(model_files) if index_file is None: @@ -170,7 +148,7 @@ def _build_microscale_jobs( inverse_weights_map = {resolved_path: []} jobs.append( ( - process_file_microscale_scheme, + job_fn, inverse_weights_map, save_path, scheme, @@ -194,11 +172,14 @@ def _build_microscale_jobs( # Precompute exactly which tensors to load from which files for this shard, # including fused partner tensors that live in other shards - inverse_weights_map = build_inverse_weights_map( - shard_name=shard_name, - weight_map=weight_map, - model_files=model_files, - ) + if build_inverse_weights_map is None: + inverse_weights_map = {resolved_path: []} + else: + inverse_weights_map = build_inverse_weights_map( + shard_name=shard_name, + weight_map=weight_map, + model_files=model_files, + ) if len(inverse_weights_map) > 1: partner_shards = [s for s in inverse_weights_map if s != resolved_path] @@ -209,7 +190,7 @@ def _build_microscale_jobs( jobs.append( ( - job_fn, inverse_weights_map, save_path, scheme, @@ -220,63 +201,3 @@ def _build_microscale_jobs( ) return jobs - - -def _build_validate_jobs(jobs: list[tuple]) -> list[tuple]: - """ - Build validation jobs from processing jobs.
- - Handles both job formats: - - Standard/fallback: (proc_fn, file_path_str, save_path, scheme, ignore, device, \ - converter) - - Microscale with index: (proc_fn, inverse_weights_map_dict, save_path, scheme, \ - ignore, device, converter) - """ - validate_jobs = [] - for job in jobs: - # job[0] is the processing function - # Check if second element is a dict (microscale with index) - # or string (standard/fallback) - second_arg = job[1] - - if isinstance(second_arg, dict): - # Microscale job with inverse_weights_map dict - _, inverse_weights_map, save_path, scheme, ignore, device, converter = job - # Use first source file path from inverse_weights_map for validation - source_file = next(iter(inverse_weights_map.keys())) - validate_jobs.append( - ( - validate_file, - source_file, - save_path, - scheme, - ignore, - device, - converter, - inverse_weights_map, - ) - ) - else: - # Standard job or microscale fallback: second_arg is file_path string - _, file_path, save_path, scheme, ignore, device, converter = job - validate_jobs.append( - ( - validate_file, - file_path, - save_path, - scheme, - ignore, - device, - converter, - None, - ) - ) - return validate_jobs - - -def _get_all_tensor_names(file_path: str) -> list[str]: - """Get all tensor names from a safetensors file without loading tensors.""" - from safetensors import safe_open - - with safe_open(file_path, framework="pt", device="cpu") as f: - return list(f.keys()) diff --git a/src/llmcompressor/entrypoints/model_free/microscale.py b/src/llmcompressor/entrypoints/model_free/microscale.py index e48f515dba..eab8782a43 100644 --- a/src/llmcompressor/entrypoints/model_free/microscale.py +++ b/src/llmcompressor/entrypoints/model_free/microscale.py @@ -9,7 +9,7 @@ ) __all__ = [ - "build_inverse_weights_map", + "build_microscale_inverse_weights_map", "is_microscale_scheme", "get_fused_names", "DEFAULT_FUSED_MAPPINGS", @@ -78,7 +78,7 @@ def get_fused_names( return matched, unmatched -def build_inverse_weights_map( 
+def build_microscale_inverse_weights_map( shard_name: str, weight_map: dict[str, str], model_files: dict[str, str], diff --git a/src/llmcompressor/entrypoints/model_free/process.py b/src/llmcompressor/entrypoints/model_free/process.py index 6923d63996..7b58b4b719 100644 --- a/src/llmcompressor/entrypoints/model_free/process.py +++ b/src/llmcompressor/entrypoints/model_free/process.py @@ -8,7 +8,7 @@ from compressed_tensors.quantization import QuantizationScheme from compressed_tensors.utils import match_quantizable_tensors from safetensors import safe_open -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file from torch.nn import Module from llmcompressor.entrypoints.model_free.lifecycle import ( @@ -36,7 +36,6 @@ def validate_file( ignore: Iterable[str], device: str | torch.device, converter: Converter | None = None, - weights_map: dict[str, str] | None = None, ): """ Validate that each quantizable tensor in a safetensors file can be quantized. @@ -49,19 +48,8 @@ def validate_file( :param device: device used to quantize and compress weights :param converter: optional converter to apply to the checkpoint, e.g. conversion of some layers from some format to compressed-tensors - :param weights_map: optional mapping of tensor name -> source file path, - built from safetensors.index.json. Reserved for future use by callers - that need cross-shard tensor location lookup during validation. 
""" - # Extract file path from inverse_weights_map (standard mode: load all) - # Backward compatibility: handle both dict and Path/string formats - if not isinstance(inverse_weights_map, dict): - # Legacy call with file_path - wrap it as inverse_weights_map - inverse_weights_map = {inverse_weights_map: None} - # Extract source file from inverse_weights_map - source_file = next(iter(inverse_weights_map.keys())) - # Extract source file from inverse_weights_map - tensors = load_file(source_file) + tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device) if converter is not None: converter.validate(tensors) @@ -92,16 +80,8 @@ def process_file( e.g. conversion of some layers from some format to compressed-tensors """ assert not is_microscale_scheme(scheme), "Use `process_file_microscale_scheme`" - # Extract file path from inverse_weights_map (standard mode: load all) - # Backward compatibility: handle both dict and Path/string formats - if not isinstance(inverse_weights_map, dict): - # Legacy call with file_path - wrap it as inverse_weights_map - inverse_weights_map = {inverse_weights_map: None} - # Extract source file from inverse_weights_map - source_file = next(iter(inverse_weights_map.keys())) - # Extract source file from inverse_weights_map - source_file = next(iter(inverse_weights_map.keys())) - tensors = load_file(source_file) + + tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device) if converter is not None: converter.process(tensors) @@ -152,7 +132,7 @@ def process_file_microscale_scheme( :param inverse_weights_map: mapping of resolved source file path -> list of tensor names to load from that file. Precomputed by - build_inverse_weights_map() in the job-building phase. + build_microscale_inverse_weights_map() in the job-building phase. 
Example: {"/path/shard0.safetensors": ["q_proj.weight"], "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]} :param save_path: output path for this shard's compressed weights @@ -165,18 +145,7 @@ def process_file_microscale_scheme( """ assert is_microscale_scheme(scheme), "Use `process_file` for non-microscale scheme" - # Load all required tensors using true partial reads via safe_open. - # inverse_weights_map tells us exactly which tensors to load from each file — - # no entire-file loads, no runtime discovery. - tensors: dict[str, torch.Tensor] = {} - for source_file, tensor_names in inverse_weights_map.items(): - with safe_open(source_file, framework="pt", device="cpu") as f: - available = set(f.keys()) - # Load all tensors if tensor_names is None or empty - names_to_load = tensor_names if tensor_names else list(available) - for name in names_to_load: - if name in available: - tensors[name] = f.get_tensor(name) + tensors = _load_tensors_from_inverse_weights_map(inverse_weights_map, device) if converter is not None: converter.process(tensors) @@ -247,3 +216,40 @@ def process_file_microscale_scheme( total_size = sum(t.nbytes for t in tensors.values()) weight_map = {key: os.path.basename(save_path) for key in tensors.keys()} return total_size, weight_map + + +# TODO brian-dellabetta (#2491): move to compressed-tensors.utils.safetensors_load +def _load_tensors_from_inverse_weights_map( + inverse_weights_map: dict[str, list[str] | None], + device: str | torch.device, +) -> dict[str, torch.Tensor]: + """ + Given an inverse_weights_map, which is a dictionary of file name to list of + tensor names, load up all listed tensor names + + :param inverse_weights_map: mapping of resolved source file path -> + list of tensor names to load from that file. Precomputed by + build_inverse_weights_map() in the job-building phase. 
+ If list is empty, all tensors are pulled + Example: {"/path/shard0.safetensors": ["q_proj.weight"], + "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]} + :param device: tensors will be loaded onto this device. + + :returns: mapping of tensor name to actual tensor loaded from safetensors file + Example: {"q_proj.weight": torch.Tensor(...), "k_proj.weight: torch.Tensor(...)} + """ + tensors: dict[str, torch.Tensor] = {} + for source_file, tensor_names in inverse_weights_map.items(): + with safe_open(source_file, framework="pt", device=str(device)) as f: + keys = f.keys() + # if tensor_names is empty, pull all tensors + if tensor_names is None or len(tensor_names) == 0: + tensor_names = keys + for tensor_name in tensor_names: + if tensor_name not in keys: + raise ValueError( + f"Expected to find tensor {tensor_name} in " + f"{source_file}, but tensor was not found." + ) + tensors[tensor_name] = f.get_tensor(tensor_name) + return tensors diff --git a/tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py b/tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py index 6129c0e942..517890a25f 100644 --- a/tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py +++ b/tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py @@ -25,7 +25,7 @@ def test_validate_file_raises_for_non_2d_linear_weight(tmp_path): save_file({"model.layers.0.mlp.down_proj.weight": torch.ones(128)}, str(path)) with pytest.raises(ValueError, match="model.layers.0.mlp.down_proj.weight"): - validate_file(path, None, _get_block_scheme(), [], None) + validate_file({str(path): []}, None, _get_block_scheme(), [], "cpu") def test_validate_file_does_not_raise_for_block_incompatible_shape(tmp_path): @@ -35,4 +35,4 @@ def test_validate_file_does_not_raise_for_block_incompatible_shape(tmp_path): str(path), ) - validate_file(path, None, _get_block_scheme(), [], None) + validate_file({str(path): []}, None, _get_block_scheme(), [], 
"cpu") diff --git a/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py b/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py index 6062974435..21e032ea75 100644 --- a/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py +++ b/tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py @@ -9,7 +9,7 @@ from safetensors.torch import save_file from llmcompressor.entrypoints.model_free.microscale import ( - build_inverse_weights_map, + build_microscale_inverse_weights_map, ) from llmcompressor.entrypoints.model_free.process import ( process_file_microscale_scheme, @@ -45,7 +45,7 @@ def test_basic_mapping(self, tmp_path): "shard-00001.safetensors": str(tmp_path / "shard-00001.safetensors"), "shard-00002.safetensors": str(tmp_path / "shard-00002.safetensors"), } - result = build_inverse_weights_map( + result = build_microscale_inverse_weights_map( "shard-00001.safetensors", weight_map, model_files ) # result is {file_path: [tensor_names]}, check tensor is in the list @@ -66,7 +66,7 @@ def test_missing_shard_skipped(self, tmp_path): model_files = { "shard-00001.safetensors": str(tmp_path / "shard-00001.safetensors"), } - result = build_inverse_weights_map( + result = build_microscale_inverse_weights_map( "shard-00001.safetensors", weight_map, model_files ) # check tensor.a is in the result values @@ -84,7 +84,7 @@ def test_colocated_no_partners_needed(self, tmp_path): "model.layers.0.self_attn.v_proj.weight": shard, } model_files = {shard: str(tmp_path / shard)} - result = build_inverse_weights_map(shard, weight_map, model_files) + result = build_microscale_inverse_weights_map(shard, weight_map, model_files) assert len(result) == 1 assert str(tmp_path / shard) in result @@ -99,7 +99,7 @@ def test_cross_shard_partners_found(self, tmp_path): "shard-00001.safetensors": str(tmp_path / "shard-00001.safetensors"), "shard-00002.safetensors": str(tmp_path / "shard-00002.safetensors"), } - result = 
build_inverse_weights_map( + result = build_microscale_inverse_weights_map( "shard-00001.safetensors", weight_map, model_files ) # Should include both shards @@ -174,10 +174,10 @@ def split_shards(self, tmp_path): "shard-00002.safetensors": str(shard2_path), } # Precompute inverse_weights_map for each shard - iwm1 = build_inverse_weights_map( + iwm1 = build_microscale_inverse_weights_map( "shard-00001.safetensors", weight_map, model_files ) - iwm2 = build_inverse_weights_map( + iwm2 = build_microscale_inverse_weights_map( "shard-00002.safetensors", weight_map, model_files ) return shard1_path, shard2_path, iwm1, iwm2