diff --git a/examples/convert_checkpoint/deepseek32_fpblock_example.py b/examples/convert_checkpoint/deepseek32_fpblock_example.py
new file mode 100644
index 000000000..e5a409e77
--- /dev/null
+++ b/examples/convert_checkpoint/deepseek32_fpblock_example.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from compressed_tensors.entrypoints.convert import (
+    convert_checkpoint,
+    FP8BlockDequantizer,
+)
+
+# deepseek-ai/DeepSeek-V3.2 checkpoint has layers that are quantized in the FP8
+# quant method's FP8_BLOCK scheme. This script will upconvert to bfloat16 so that
+# the model can be compressed in another configuration.
+MODEL_ID = "deepseek-ai/DeepSeek-V3.2"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-bf16"
+
+# Convert DeepSeek-V3.2 back to dense bfloat16 format
+convert_checkpoint(
+    model_stub=MODEL_ID,
+    save_directory=SAVE_DIR,
+    converter=FP8BlockDequantizer(
+        # `deepseek-ai/DeepSeek-V3.2` fp8-block-quantized layers, found by inspection
+        targets=[
+            r"re:.*mlp.*\.(gate_up|gate|up|down)_proj$",
+            r"re:.*self_attn.*\.(kv_b|o|q_a|q_b)_proj$",
+            r"re:.*self_attn.kv_a_proj_with_mqa$",
+            r"re:.*self_attn.indexer.(wk|wq_b)$",
+        ],
+    ),
+    max_workers=4,
+)
diff --git a/examples/convert_checkpoint/qwen3_fpblock_example.py b/examples/convert_checkpoint/qwen3_fpblock_example.py
new file mode 100644
index 000000000..dece24a83
--- /dev/null
+++ b/examples/convert_checkpoint/qwen3_fpblock_example.py
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from compressed_tensors.entrypoints.convert import (
+    convert_checkpoint,
+    FP8BlockDequantizer,
+)
+
+MODEL_ID = "qwen-community/Qwen3-4B-FP8"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1].removesuffix("-FP8")
+
+# Convert Qwen3-4B-FP8 back to dense bfloat16 format
+convert_checkpoint(
+    model_stub=MODEL_ID,
+    save_directory=SAVE_DIR,
+
converter=FP8BlockDequantizer( + # qwen-community/Qwen3-4B-FP8's fp8-block-quantized layers, found by inspection + targets=[ + r"re:.*mlp.*\.(gate_up|gate|up|down)_proj$", + r"re:.*self_attn.*\.(q|k|v|o)_proj$", + ], + weight_block_size=[128, 128], + ), + max_workers=8, +) diff --git a/src/compressed_tensors/entrypoints/convert/convert_checkpoint.py b/src/compressed_tensors/entrypoints/convert/convert_checkpoint.py index 732311102..b902300b0 100644 --- a/src/compressed_tensors/entrypoints/convert/convert_checkpoint.py +++ b/src/compressed_tensors/entrypoints/convert/convert_checkpoint.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json import os import shutil from collections.abc import Callable @@ -13,8 +14,12 @@ validate_file, write_checkpoint_quantization_config, ) -from compressed_tensors.entrypoints.convert.converters import Converter +from compressed_tensors.entrypoints.convert.converters import ( + Converter, + build_inverse_weight_maps, +) from compressed_tensors.utils.safetensors_load import ( + get_weight_map, get_checkpoint_files, is_weights_file, update_safetensors_index, @@ -32,11 +37,14 @@ def convert_checkpoint( max_workers: int = 1, ): """ - Convert a model checkpoint to compressed-tensors format without loading it up - in memory, instead operating directly on the model safetensors files. This - entrypoint operates on a model stub or folder containing weights saved in - safetensors files, and updates the corresponding quantization_config field in - the config.json. All additional files will be copied to new checkpoint. + Convert a model checkpoint to either: + - its equivalent quantized format in compressed-tensors + - the unquantized format + without loading it up in memory, instead operating directly on the model + safetensors files. 
This entrypoint operates on a model stub or folder containing + weights saved in safetensors files, and updates the corresponding + quantization_config field in the config.json. All additional files will be + copied to new checkpoint. :param model_stub: huggingface model hub or path to local weights files :param save_directory: new checkpoint will be saved in this directory. @@ -45,30 +53,49 @@ def convert_checkpoint( :param converters: converter we wish to apply to the checkpoint, e.g. conversion of some layers from some format to compressed-tensors """ - # validate arguments + # get all model_files for checkpoint model_files = get_checkpoint_files(model_stub) - # 0. collect safetensors files, copy files + # Read weight map from safetensors.index.json + weight_map = get_weight_map(model_files) + + # Build inverse_weight_maps, so that each job knows how to load up every necessary + # weight and its dependencies + inverse_weight_maps = build_inverse_weight_maps( + weight_map=weight_map, + model_files=model_files, + converters=[converter], + ) + + # Build validation/conversion jobs, copy over any other file validate_jobs = [] convert_jobs = [] for file_path, resolved_path in model_files.items(): save_path = Path(save_directory) / file_path if file_path.endswith("safetensors"): - validate_jobs.append((validate_file, resolved_path, converter)) - convert_jobs.append((convert_file, resolved_path, save_path, converter)) + assert ( + file_path in inverse_weight_maps + ), f"Could not find inverse_weight_map for file {file_path}" + validate_jobs.append( + (validate_file, inverse_weight_maps[file_path], converter) + ) + convert_jobs.append( + (convert_file, inverse_weight_maps[file_path], save_path, converter) + ) else: if is_weights_file(file_path): logger.warning(f"Skip processing for weights file {file_path}") - save_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Copying {file_path} {save_path}") - shutil.copyfile(resolved_path, save_path) + if 
str(resolved_path) != str(save_path): + save_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(f"Copying {file_path} {save_path}") + shutil.copyfile(resolved_path, save_path) - # 1. validate quantizable tensors fail fast before long-running quantization + # Validate before long-running procssing job exec_jobs(validate_jobs, max_workers, desc="Validating") - # 2-5. quantize and compress weights + # Process weights, accumulating total bytes used and the new weight_map total_size = 0 weight_map = dict() convert_results = exec_jobs(convert_jobs, max_workers, desc="Converting") @@ -76,7 +103,7 @@ def convert_checkpoint( total_size += _total_size weight_map.update(_weight_map) - # 5. update config and safetensors index + # Update config and safetensors index write_checkpoint_quantization_config(save_directory, converter) update_safetensors_index(save_directory, total_size, weight_map) @@ -93,6 +120,13 @@ def exec_jobs( :param desc: tqdm description """ results = [] + + # For easier debugging, don't run single-threaded jobs via ThreadPoolExecutor + if max_workers == 1: + for job in tqdm.tqdm(jobs, desc=desc): + results.append(job[0](*job[1:])) + return results + with ThreadPoolExecutor(max_workers) as executor: futures = [executor.submit(*job) for job in jobs] for future in tqdm.tqdm(as_completed(futures), total=len(futures), desc=desc): diff --git a/src/compressed_tensors/entrypoints/convert/convert_file.py b/src/compressed_tensors/entrypoints/convert/convert_file.py index 517e5316d..e933b8f36 100644 --- a/src/compressed_tensors/entrypoints/convert/convert_file.py +++ b/src/compressed_tensors/entrypoints/convert/convert_file.py @@ -7,9 +7,13 @@ from compressed_tensors import __version__ as ct_version from compressed_tensors.base import COMPRESSION_VERSION_NAME, QUANTIZATION_CONFIG_NAME from compressed_tensors.entrypoints.convert import Converter -from compressed_tensors.utils.safetensors_load import find_config_path +from 
compressed_tensors.utils.safetensors_load import ( + InverseWeightMap, + find_config_path, + load_tensors_from_inverse_weight_map, +) from loguru import logger -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file __all__ = [ @@ -34,17 +38,23 @@ def write_checkpoint_quantization_config( :param converter: Converter instance whose create_config() produces the updated quantization config """ - quant_config = converter.create_config() - - quant_config_data = quant_config.model_dump() - quant_config_data[COMPRESSION_VERSION_NAME] = ct_version + quant_config_data = None + if (quant_config := converter.create_config()) is not None: + quant_config_data = quant_config.model_dump() + quant_config_data[COMPRESSION_VERSION_NAME] = ct_version config_file_path = find_config_path(save_directory) if config_file_path is not None: with open(config_file_path, "r") as file: config_data = json.load(file) - config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data + if quant_config_data is None: + # if no new quant config, make sure checkpoint quant config is empty + if QUANTIZATION_CONFIG_NAME in config_data: + del config_data[QUANTIZATION_CONFIG_NAME] + else: + # if new quant config, overwrite checkpoint quant config + config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data with open(config_file_path, "w") as file: json.dump(config_data, file, indent=2, sort_keys=True) @@ -57,35 +67,45 @@ def write_checkpoint_quantization_config( def validate_file( - file_path: str | os.PathLike, + inverse_weight_map: InverseWeightMap, converter: Converter, ): """ Validate that each quantizable tensor in a safetensors file can be quantized. - :param file_path: safetensors file to validate + :param inverse_weight_map: mapping of resolved source file path -> + list of tensor names to load from that file. Precomputed by + build_inverse_weight_map() in the job-building phase. 
+ Example: {"/path/shard0.safetensors": ["q_proj.weight"], + "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]} :param converter: converter we wish to apply to the checkpoint, e.g. conversion of some layers from some format to compressed-tensors """ - tensors = load_file(file_path) + tensors = load_tensors_from_inverse_weight_map(inverse_weight_map) converter.validate(tensors) def convert_file( - file_path: str | os.PathLike, + inverse_weight_map: InverseWeightMap, save_path: str | os.PathLike, converter: Converter, ) -> tuple[int, dict[str, str]]: """ Convert tensors in a given safetensors file - :param file_path: safetensors file to process + :param inverse_weight_map: mapping of resolved source file path -> + list of tensor names to load from that file. Precomputed by + build_inverse_weight_map() in the job-building phase. + Example: {"/path/shard0.safetensors": ["q_proj.weight"], + "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]} :param save_path: save path of file with quantized weights :param converter: converter we wish to apply to the checkpoint, e.g. 
conversion of some layers from some format to compressed-tensors + :returns: tuple of (total_size, weight_map), respectively the total size in bytes + of the saved file and dictionary of weight name -> save path """ - tensors = load_file(file_path) + tensors = load_tensors_from_inverse_weight_map(inverse_weight_map) converter.process(tensors) diff --git a/src/compressed_tensors/entrypoints/convert/converters/__init__.py b/src/compressed_tensors/entrypoints/convert/converters/__init__.py index 0fea2a6bf..cb60d3824 100644 --- a/src/compressed_tensors/entrypoints/convert/converters/__init__.py +++ b/src/compressed_tensors/entrypoints/convert/converters/__init__.py @@ -6,3 +6,4 @@ from .base import * from .modelopt_nvfp4 import * +from .fp8block_dequantizer import * diff --git a/src/compressed_tensors/entrypoints/convert/converters/base.py b/src/compressed_tensors/entrypoints/convert/converters/base.py index 58480009c..7bd7e0d1d 100644 --- a/src/compressed_tensors/entrypoints/convert/converters/base.py +++ b/src/compressed_tensors/entrypoints/convert/converters/base.py @@ -3,11 +3,15 @@ from __future__ import annotations +from collections import defaultdict from typing import TYPE_CHECKING, Protocol import torch +from compressed_tensors.utils.safetensors_load import InverseWeightMap +__all__ = ["Converter", "build_inverse_weight_maps"] + if TYPE_CHECKING: from compressed_tensors.quantization import QuantizationConfig @@ -30,7 +34,7 @@ def process(self, tensors: dict[str, torch.Tensor]): - `model.layers.0.self_attn.q_proj.weight` - `model.layers.0.mlp.up_proj.weight_packed` """ - pass + raise NotImplementedError() def validate(self, tensors: dict[str, torch.Tensor]): """ @@ -40,11 +44,96 @@ def validate(self, tensors: dict[str, torch.Tensor]): :param tensors: dictionary of tensor name to tensor, as loaded from safetensors file. 
""" - pass + raise NotImplementedError() - def create_config(self) -> QuantizationConfig: + def create_config(self) -> QuantizationConfig | None: """ Create compressed-tensors QuantizationConfig so that it can be set in the new model checkpoint's config.json. + If the converter is moving checkpoint to full-precision, have this function + return None, and quantization_config will be removed from config.json + """ + raise NotImplementedError() + + def get_dependencies(self, weight_name: str) -> set[str]: + """ + Given a weight name, return a set of all dependency weight names, so that + weights can be processed correctly and in a parallelized fashion. + If there are no dependencies, an empty dict should be returned. + + :returns: set[str] of dependency weight names """ - pass + raise NotImplementedError() + + +def build_inverse_weight_maps( + weight_map: dict[str, str], + model_files: dict[str, str], + converters: list[Converter], +) -> dict[str, InverseWeightMap]: + """ + For a given output shard, precompute exactly which tensors to load from + which source files — including required partner tensors from other shards. + + This is necessary because some converters require that a set of tensors are + accessible in order for them to be processed correctly. 
+
+    :param converters: converters whose get_dependencies() define tensor grouping
+    :param weight_map: tensor name -> shard filename (from safetensors.index.json)
+    :param model_files: shard filename -> resolved absolute path
+    :return: {resolved_file_path: [tensor_names_to_load]}
+    """
+
+    def get_dependencies_recursive(
+        weight_name: str, converters: list[Converter], current_deps: dict[str, bool]
+    ) -> dict[str, bool]:
+        for converter in converters:
+            for dep, is_required in converter.get_dependencies(weight_name).items():
+                if dep not in current_deps:
+                    current_deps[dep] = is_required
+                    get_dependencies_recursive(dep, converters, current_deps)
+
+        return current_deps
+
+    # map of weight name -> ( map of dependency name -> is_required )
+    weight_deps_dict: dict[str, dict[str, bool]] = dict()
+    for weight_name in weight_map:
+        weight_deps_dict[weight_name] = get_dependencies_recursive(
+            weight_name, converters, {}
+        )
+        assert (
+            weight_name not in weight_deps_dict[weight_name]
+        ), f"{weight_name} found in dependencies {weight_deps_dict[weight_name]}"
+
+    # set of all dependencies (i.e.
all weight names required by another) + all_dependencies: set[str] = set().union(*weight_deps_dict.values()) + + inverse_weight_maps: dict[str, InverseWeightMap] = defaultdict( + lambda: defaultdict(list) + ) + for weight_name, weight_shard_name in weight_map.items(): + if weight_name in all_dependencies: + # weight is a partner to some other primary tensor, skip it + continue + + # weight is purely a primary weight, is not a dependency of anything + # add it and all its dependencies + inverse_weight_map: InverseWeightMap = inverse_weight_maps[weight_shard_name] + dependency_weights = weight_deps_dict[weight_name] + for weight_to_add_name, is_required in [ + (weight_name, True), + *dependency_weights.items(), + ]: + if weight_to_add_name not in weight_map: + if is_required: + raise ValueError( + f"Required weight {weight_to_add_name} not found in weight map" + ) + else: + continue + weight_to_add_shard_name = weight_map[weight_to_add_name] + resolved_path = model_files[weight_to_add_shard_name] + inverse_weight_map[resolved_path].append(weight_to_add_name) + + # return dicts, not defaultdicts, to avoid silent errors + return {k: dict(v) for k, v in inverse_weight_maps.items()} diff --git a/src/compressed_tensors/entrypoints/convert/converters/fp8block_dequantizer.py b/src/compressed_tensors/entrypoints/convert/converters/fp8block_dequantizer.py new file mode 100644 index 000000000..d5aee6bc2 --- /dev/null +++ b/src/compressed_tensors/entrypoints/convert/converters/fp8block_dequantizer.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Iterable + +import torch +from compressed_tensors.entrypoints.convert.converters import Converter +from compressed_tensors.quantization import QuantizationConfig +from compressed_tensors.quantization.utils.helpers import ( + maybe_pad_tensor_for_block_quant, +) +from compressed_tensors.utils.match import match_name, 
match_quantizable_tensors
+from loguru import logger
+
+
+class FP8BlockDequantizer(Converter):
+    """
+    Dequantize a checkpoint that has been block-quantized with FP8 quant_method
+    The resultant weights will be stored in user-provided dtype
+    """
+
+    def __init__(
+        self,
+        ignore: Iterable[str] = tuple(),
+        targets: Iterable[str] = tuple(),
+        weight_block_size: tuple[int, int] = (128, 128),
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        self.ignore = ignore
+        self.targets = targets
+        self.weight_block_size = weight_block_size
+        self.dtype = dtype
+
+    def process(self, tensors: dict[str, torch.Tensor]):
+        """
+        Dequantize fp8 block-quantized tensors in place.
+        For every targeted module holding a quantized weight:
+        - weight = weight_fp8 * weight_scale_inv (per weight_block_size block),
+          cast to self.dtype
+        - weight_scale_inv is removed from the tensor dict
+        Non-targeted tensors are left untouched.
+        """
+        for module_name, name in match_quantizable_tensors(
+            tensors, self.ignore, self.targets, allow_nonquantizable=True
+        ):
+            param_name = name.rsplit(".", 1)[-1]
+
+            if param_name == "weight":
+                # weight * weight_scale_inv -> dequantized weight
+                tensors[f"{module_name}.weight"] = self._create_dequantized_weight(
+                    tensors[f"{module_name}.weight"],
+                    tensors[f"{module_name}.weight_scale_inv"],
+                )
+                del tensors[f"{module_name}.weight_scale_inv"]
+
+    def validate(self, tensors: dict[str, torch.Tensor]):
+        """
+        Ensure all tensor names of targeted layers are expected and no
+        untargeted layers have unexpected tensor names
+        """
+        allowed_names = ["weight", "weight_scale_inv"]
+
+        targeted_names = [
+            name
+            for _, name in match_quantizable_tensors(
+                tensors, self.ignore, self.targets, allow_nonquantizable=True
+            )
+        ]
+        for name in targeted_names:
+            module_name, param_name = name.rsplit(".", 1)
+
+            if param_name == "weight":
+                if f"{module_name}.weight_scale_inv" not in tensors:
+                    raise ValueError(
+                        f"Found weight without corresponding weight_scale_inv {name}"
+                    )
+            elif param_name == "weight_scale_inv":
+                if
f"{module_name}.weight" not in tensors: + raise ValueError( + f"Found weight_scale_inv without corresponding weight {name}" + ) + elif param_name not in allowed_names: + raise ValueError(f"Found unexpected targeted tensor {name}") + + disallowed_names = ["weight_scale_inv"] + untargeted_names = [ + name for name in tensors.keys() if name not in targeted_names + ] + for name in untargeted_names: + param_name = name.rsplit(".", 1)[-1] + + if param_name in disallowed_names: + raise ValueError(f"Found unexpected non-targeted tensor {name}") + + def create_config(self) -> QuantizationConfig | None: + return None + + def get_dependencies(self, weight_name: str) -> set[str]: + module_name, suffix = weight_name.rsplit(".", 1) + if ( + any([match_name(module_name, target) for target in self.targets]) + and not any([match_name(module_name, ignore) for ignore in self.ignore]) + and suffix == "weight" + ): + return set(f"{module_name}.weight_scale_inv") + return set() + + def _create_dequantized_weight( + self, weight: torch.Tensor, weight_scale_inv: torch.Tensor + ) -> torch.Tensor: + """ + Convert fp8 weight and fp32 weight_scale_inv tensors into + corresponding dequantized weight tensor. 
+ Tensors are upscaled to fp32 before scaling + + :return: dequantized tensor in self.dtype and same shape as input weight tensor + """ + original_shape = weight.shape + block_height, block_width = self.weight_block_size + + # Pad tensor if dimensions are not evenly divisible by block size + weight = maybe_pad_tensor_for_block_quant(weight, tuple(self.weight_block_size)) + padded_shape = weight.shape + + # Reshape into blocks of shape: + # (num_rows_blocks, block_height, num_cols_blocks, block_width) + num_rows_blocks = padded_shape[0] // block_height + num_cols_blocks = padded_shape[1] // block_width + weight_blocks = weight.reshape( + num_rows_blocks, + block_height, + num_cols_blocks, + block_width, + ).transpose( + 1, 2 + ) # (num_rows_blocks, num_cols_blocks, block_height, block_width) + + # Expand scale_inv for broadcasting over block dimensions + # weight_scale_inv shape: (num_rows_blocks, num_cols_blocks) + # Expand to: (num_rows_blocks, num_cols_blocks, 1, 1) + scale_inv_expanded = weight_scale_inv.unsqueeze(-1).unsqueeze(-1) + + # Dequantize: weight_bf16 = weight_fp8 * weight_scale_inv + dequantized_blocks = ( + weight_blocks.to(torch.float32) * scale_inv_expanded.to(torch.float32) + ).to(self.dtype) + + # Restore padded shape + dequantized = dequantized_blocks.transpose(1, 2).reshape(padded_shape) + + # Truncate to original dimensions if padding was applied + if original_shape != padded_shape: + dequantized = dequantized[tuple([slice(v) for v in original_shape])] + + return dequantized diff --git a/src/compressed_tensors/entrypoints/convert/converters/modelopt_nvfp4.py b/src/compressed_tensors/entrypoints/convert/converters/modelopt_nvfp4.py index a36a388a5..f49183a3f 100644 --- a/src/compressed_tensors/entrypoints/convert/converters/modelopt_nvfp4.py +++ b/src/compressed_tensors/entrypoints/convert/converters/modelopt_nvfp4.py @@ -13,7 +13,7 @@ QuantizationStatus, ) from compressed_tensors.quantization.quant_scheme import NVFP4 -from 
compressed_tensors.utils.match import match_quantizable_tensors
+from compressed_tensors.utils.match import match_name, match_quantizable_tensors
 
 
 class ModelOptNvfp4Converter(Converter):
@@ -107,6 +107,30 @@ def validate(self, tensors: dict[str, torch.Tensor]):
             if param_name in disallowed_names:
                 raise ValueError(f"Hit unexpected non-targeted tensor {name}")
 
+    def get_dependencies(self, weight_name: str) -> dict[str, bool]:
+        module_name, suffix = weight_name.rsplit(".", 1)
+        if (
+            any(match_name(module_name, target) for target in self.targets)
+            and not any(match_name(module_name, ignore) for ignore in self.ignore)
+            and suffix == "weight"
+        ):
+            # NOTE: must be a dict, not set(a, b, c) -- set() takes a single
+            # iterable, and callers iterate the result via .items()
+            deps = {
+                f"{module_name}.input_scale": True,
+                f"{module_name}.weight_scale": True,
+                f"{module_name}.weight_scale_2": True,
+            }
+
+            if self.kv_cache_scheme:
+                if module_name.endswith("k_proj"):
+                    deps[f"{module_name}.k_scale"] = True
+                if module_name.endswith("v_proj"):
+                    deps[f"{module_name}.v_scale"] = True
+
+            return deps
+
+        return {}
+
     def create_config(self) -> QuantizationConfig:
         return QuantizationConfig(
             config_groups={
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index ef9e78a72..437ac1465 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -5,8 +5,10 @@
 import os
 import re
 import struct
+from collections import defaultdict
 from collections.abc import Iterable
 
+import torch
 from huggingface_hub import list_repo_files
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -21,9 +23,14 @@
     "get_weight_mappings",
     "get_nested_weight_mappings",
     "get_quantization_parameter_to_path_mapping",
+    "get_file_map",
+    "InverseWeightMap",
+    "load_tensors_from_inverse_weight_map",
     "is_quantization_param",
     "find_config_path",
     "find_safetensors_index_path",
+    "find_safetensors_index_file",
+    "get_weight_map",
     "update_safetensors_index",
     "is_weights_file",
     "get_checkpoint_files",
@@
-65,14 +72,16 @@ def get_checkpoint_files(model_stub: str | os.PathLike) -> dict[str, str]: # In the future, this function can accept and pass download kwargs to cached_file if os.path.exists(model_stub): - file_paths = _walk_directory_files(model_stub, ignore=".cache") + file_paths = _walk_directory_files( + model_stub, ignore=[".cache", ".gitattributes"] + ) else: file_paths = list_repo_files(model_stub) return {file_path: cached_file(model_stub, file_path) for file_path in file_paths} -def _walk_directory_files(root_dir: str, ignore: str | None = None) -> list[str]: +def _walk_directory_files(root_dir: str, ignore: Iterable[str]) -> list[str]: """ Return all file paths relative to root_dir, optionally skipping entries whose relative path starts with `ignore`. @@ -81,11 +90,14 @@ def _walk_directory_files(root_dir: str, ignore: str | None = None) -> list[str] :param ignore: optional path prefix to exclude from results :return: list of relative file paths """ + if isinstance(ignore, str): + ignore = [ignore] + all_files = [] for dirpath, _, filenames in os.walk(root_dir): for filename in filenames: rel_path = os.path.relpath(os.path.join(dirpath, filename), root_dir) - if not (ignore and rel_path.startswith(ignore)): + if not any([rel_path.startswith(i) for i in ignore]): all_files.append(rel_path) return all_files @@ -116,6 +128,45 @@ def find_config_path(save_directory: str | os.PathLike) -> str | None: return None +def find_safetensors_index_file(model_files: dict[str, str]) -> str | None: + """ + Find safetensors index file from full list of model_files + + :param model_files: mapping of file relative path to absolute path, usually the + result of `get_checkpoint_files` + :return: absolute path to the safetensors index file, or None if not found + """ + for file_path, resolved_path in model_files.items(): + if file_path.endswith(SAFE_WEIGHTS_INDEX_NAME): + return resolved_path + + return None + + +def get_weight_map(model_files: dict[str, str]) -> dict[str, 
str]:
+    """
+    Get weight map from full list of model_files.
+    If safetensors index.json file is found, weight_map can be pulled from there.
+    Otherwise, it is created from the single safetensors weights file.
+
+    :param model_files: mapping of file relative path to resolved absolute path
+    :returns: weight map of form {weight name -> safetensor file name}
+    """
+    index_file = find_safetensors_index_file(model_files)
+    if index_file is not None:
+        with open(index_file, "r") as f:
+            return json.load(f)["weight_map"]
+
+    # if no index_file, use model.safetensors instead.
+    if SAFE_WEIGHTS_NAME not in model_files:
+        raise ValueError(
+            f"File {SAFE_WEIGHTS_NAME} expected but not found in {model_files.keys()}"
+        )
+
+    # create from model.safetensors; safe_open's second argument is a tensor
+    # framework name ("pt"), not a file mode
+    with safe_open(model_files[SAFE_WEIGHTS_NAME], framework="pt") as file:
+        return {tensor: SAFE_WEIGHTS_NAME for tensor in file.keys()}
+
+
 def update_safetensors_index(
     save_directory: str | os.PathLike,
     total_size: int,
@@ -358,6 +409,67 @@ def get_quantization_parameter_to_path_mapping(model_path: str) -> dict[str, str
     return mapping
 
 
+def get_file_map(weight_map: WeightMappingType) -> dict[str, list[str]]:
+    """
+    Given a safetensors index file's weight_map, which maps weight name to safetensors
+    file name, return a mapping of safetensors file name to list of weight names
+
+    :param weight_map: mapping of weight name to safetensors file name.
+        result of `get_weight_mappings`
+    :returns: file_map
+    """
+
+    file_map = defaultdict(list)
+    for k, v in weight_map.items():
+        file_map[v].append(k)
+
+    return dict(file_map)
+
+
+InverseWeightMap = dict[str, list[str] | None]
+"""
+Mapping of absolute path -> list of tensors. Used to pull tensors across different
+safetensors files that must be loaded/processed together.
Used in conjunction +with `load_tensors_from_inverse_weight_map` +""" + + +def load_tensors_from_inverse_weight_map( + inverse_weight_map: InverseWeightMap, + device: str | torch.device = torch.device("cpu"), +) -> dict[str, torch.Tensor]: + """ + Given an inverse_weight_map, which is a dictionary of file name to list of + tensor names, load up all listed tensor names + + :param inverse_weight_map: mapping of resolved source file path -> + list of tensor names to load from that file. Precomputed by + build_inverse_weight_map() in the job-building phase. + If list is empty, all tensors are pulled + Example: {"/path/shard0.safetensors": ["q_proj.weight"], + "/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]} + :param device: tensors will be loaded onto this device. Defaults to CPU + + :returns: mapping of tensor name to actual tensor loaded from safetensors file + Example: {"q_proj.weight": torch.Tensor(...), "k_proj.weight: torch.Tensor(...)} + """ + tensors: dict[str, torch.Tensor] = {} + for source_file, tensor_names in inverse_weight_map.items(): + with safe_open(source_file, framework="pt", device=str(device)) as f: + keys = f.keys() + # if tensor_names is empty, pull all tensors + if tensor_names is None or len(tensor_names) == 0: + tensor_names = keys + for tensor_name in tensor_names: + if tensor_name not in keys: + raise ValueError( + f"Expected to find tensor {tensor_name} in " + f"{source_file}, but tensor was not found." 
@pytest.mark.unit
def test_build_inverse_weight_maps(tmp_path):
    """
    Test that build_inverse_weight_maps produces one inverse map per checkpoint
    file and that, across all files, every weight in the index appears exactly
    once (no duplicates, no omissions) — even when a weight and its
    ``weight_scale_inv`` start out in different safetensors files.
    """
    # Create dummy checkpoint with weights split across files
    model_dir = tmp_path / "model"
    model_dir.mkdir()

    # File 1: has layer0.weight but NOT layer0.weight_scale_inv
    file1_tensors = {
        "embed_tokens.weight": torch.randn(128, 128, dtype=torch.float32),
        "layer0.weight": torch.randn(128, 128, dtype=torch.float32).to(
            torch.float8_e4m3fn
        ),
        "layer1.weight_scale_inv": torch.randn(1, 1, dtype=torch.float32),
    }
    file1_path = model_dir / "model-00001-of-00002.safetensors"
    save_file(file1_tensors, str(file1_path))

    # File 2: has layer0.weight_scale_inv and layer1.weight_scale_inv
    file2_tensors = {
        "layer0.weight_scale_inv": torch.randn(1, 1, dtype=torch.float32),
        "layer1.weight": torch.randn(128, 128, dtype=torch.float32).to(
            torch.float8_e4m3fn
        ),
        "layer2.weight": torch.randn(128, 128, dtype=torch.float32).to(
            torch.float8_e4m3fn
        ),
        "layer2.weight_scale_inv": torch.randn(1, 1, dtype=torch.float32),
        "lm_head.weight": torch.randn(128, 128, dtype=torch.float32),
    }
    file2_path = model_dir / "model-00002-of-00002.safetensors"
    save_file(file2_tensors, str(file2_path))

    # Index deliberately places each weight's scale_inv in the *other* file
    weight_map = {
        "embed_tokens.weight": "model-00001-of-00002.safetensors",
        "layer0.weight": "model-00001-of-00002.safetensors",
        "layer1.weight": "model-00002-of-00002.safetensors",
        "layer0.weight_scale_inv": "model-00002-of-00002.safetensors",
        "layer1.weight_scale_inv": "model-00001-of-00002.safetensors",
        "layer2.weight": "model-00002-of-00002.safetensors",
        "layer2.weight_scale_inv": "model-00002-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }

    index_data = {
        "metadata": {
            "total_size": sum(
                t.numel() * t.element_size()
                for tensors in [file1_tensors, file2_tensors]
                for t in tensors.values()
            )
        },
        "weight_map": weight_map,
    }
    index_path = model_dir / "model.safetensors.index.json"
    with open(index_path, "w") as f:
        json.dump(index_data, f)

    # Create config.json (required by get_checkpoint_files)
    config_path = model_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump({"model_type": "test"}, f)

    converter = FP8BlockDequantizer(targets=[r"re:.*layer\d.*"])

    inverse_weight_maps = build_inverse_weight_maps(
        weight_map=weight_map,
        model_files=get_checkpoint_files(model_dir),
        converters=[converter],
    )

    for file_name in (
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ):
        assert (
            file_name in inverse_weight_maps
        ), f"File {file_name} missing in inverse_weight_maps"

    # Each weight must appear exactly once across all inverse maps
    seen_weight_names = set()
    for inverse_weight_map in inverse_weight_maps.values():
        for weight_names in inverse_weight_map.values():
            for weight_name in weight_names:
                assert (
                    weight_name not in seen_weight_names
                ), f"duplicate weight {weight_name} found"
                seen_weight_names.add(weight_name)

    # The union over all files must be exactly the set of indexed weights
    all_weight_names = set(weight_map.keys())
    assert seen_weight_names == all_weight_names, (
        f"missing: {all_weight_names - seen_weight_names}, "
        f"extraneous: {seen_weight_names - all_weight_names}"
    )
+ """ + converter = FP8BlockDequantizer(weight_block_size=(128, 128)) + + # Create a weight tensor divisible by block size (256x256 = 2x2 blocks of 128x128) + original_weight = torch.randn(256, 256, dtype=torch.bfloat16) + + # Simulate block quantization: divide into blocks and create per-block scales + num_row_blocks = 2 + num_col_blocks = 2 + + # Create per-block scale_inv (2x2 for 2x2 blocks) + weight_scale_inv = torch.randn(num_row_blocks, num_col_blocks, dtype=torch.float32) + + # Convert original to fp8 (simulate quantization by just converting dtype) + weight_fp8 = original_weight.to(torch.float32).to(torch.float8_e4m3fn) + + # Test conversion + result = converter._create_bfloat16_weight(weight_fp8, weight_scale_inv) + + # Verify using helper + _verify_block_conversion(result, weight_fp8, weight_scale_inv, (128, 128)) + + +@pytest.mark.unit +def test_fp8_block_to_bfloat16_conversion_with_padding(): + """ + Test that _create_bfloat16_weight correctly handles tensors that need padding + (dimensions not evenly divisible by block size). 
+ """ + converter = FP8BlockDequantizer(weight_block_size=(128, 128)) + + # Create a weight tensor NOT divisible by block size (200x300) + # Should be padded to 256x384 (2x3 blocks) + weight_fp8 = torch.randn(200, 300, dtype=torch.float32).to(torch.float8_e4m3fn) + + # Scale_inv for padded size: 2 row blocks x 3 col blocks + num_row_blocks = 2 # ceil(200/128) = 2 + num_col_blocks = 3 # ceil(300/128) = 3 + weight_scale_inv = torch.ones(num_row_blocks, num_col_blocks, dtype=torch.float32) + + # Test conversion + result = converter._create_bfloat16_weight(weight_fp8, weight_scale_inv) + + # Verify output shape matches original (not padded) + assert result.shape == (200, 300), "Output shape should match original, not padded" + assert result.dtype == torch.bfloat16, "Output dtype should be bfloat16" + + +@pytest.mark.unit +def test_fp8_block_converter_process(): + """ + Test that the converter's process method correctly converts FP8 block-quantized + tensors in a dict to bfloat16, removing weight_scale_inv tensors. 
+ """ + converter = FP8BlockDequantizer( + targets=[r"re:.*layer\d+\.mlp\..*proj$"], weight_block_size=(128, 128) + ) + + # Create mock tensors dict with FP8 weights and scale_inv tensors + num_row_blocks = 2 + num_col_blocks = 2 + + # Non-targeted tensor (should not be modified) + non_targeted_weight = torch.randn(128, 128, dtype=torch.bfloat16) + + tensors = { + "model.layer0.mlp.up_proj.weight": torch.randn( + 256, 256, dtype=torch.float32 + ).to(torch.float8_e4m3fn), + "model.layer0.mlp.up_proj.weight_scale_inv": torch.randn( + num_row_blocks, num_col_blocks, dtype=torch.float32 + ), + "model.layer1.mlp.down_proj.weight": torch.randn( + 256, 256, dtype=torch.float32 + ).to(torch.float8_e4m3fn), + "model.layer1.mlp.down_proj.weight_scale_inv": torch.randn( + num_row_blocks, num_col_blocks, dtype=torch.float32 + ), + "model.embed_tokens.weight": non_targeted_weight, + } + + # Save references to original tensors before processing + weight_fp8_layer0 = tensors["model.layer0.mlp.up_proj.weight"].clone() + scale_inv_layer0 = tensors["model.layer0.mlp.up_proj.weight_scale_inv"].clone() + weight_fp8_layer1 = tensors["model.layer1.mlp.down_proj.weight"].clone() + scale_inv_layer1 = tensors["model.layer1.mlp.down_proj.weight_scale_inv"].clone() + + # Process the tensors + converter.process(tensors) + + # Verify that weight_scale_inv tensors were removed + assert ( + "model.layer0.mlp.up_proj.weight_scale_inv" not in tensors + ), "weight_scale_inv should be removed" + assert ( + "model.layer1.mlp.down_proj.weight_scale_inv" not in tensors + ), "weight_scale_inv should be removed" + + # Verify that weights were converted to bfloat16 + assert "model.layer0.mlp.up_proj.weight" in tensors, "weight should still exist" + assert "model.layer1.mlp.down_proj.weight" in tensors, "weight should still exist" + + # Verify the conversion is correct using helper + _verify_block_conversion( + tensors["model.layer0.mlp.up_proj.weight"], + weight_fp8_layer0, + scale_inv_layer0, + (128, 
128), + ) + _verify_block_conversion( + tensors["model.layer1.mlp.down_proj.weight"], + weight_fp8_layer1, + scale_inv_layer1, + (128, 128), + ) + + # Verify non-targeted tensor was not modified + assert torch.equal( + tensors["model.embed_tokens.weight"], non_targeted_weight + ), "Non-targeted tensor should not be modified" + + +def _verify_block_conversion( + result: torch.Tensor, + weight_fp8: torch.Tensor, + weight_scale_inv: torch.Tensor, + block_size: tuple[int, int], +): + """ + Helper method to verify that FP8 block conversion to bfloat16 is correct. + Checks that each block is correctly scaled by its corresponding scale_inv value. + """ + block_height, block_width = block_size + num_row_blocks = weight_scale_inv.shape[0] + num_col_blocks = weight_scale_inv.shape[1] + + # Verify output properties + assert result.shape == weight_fp8.shape, "Output shape should match input shape" + assert result.dtype == torch.bfloat16, "Output dtype should be bfloat16" + + # Verify the conversion logic: each block should be multiplied by its scale_inv + for row_block in range(num_row_blocks): + for col_block in range(num_col_blocks): + row_start = row_block * block_height + row_end = min((row_block + 1) * block_height, result.shape[0]) + col_start = col_block * block_width + col_end = min((col_block + 1) * block_width, result.shape[1]) + + # Get the block from result + result_block = result[row_start:row_end, col_start:col_end] + + # Get expected: weight_fp8 block * scale_inv + expected_block = ( + weight_fp8[row_start:row_end, col_start:col_end].to(torch.float32) + * weight_scale_inv[row_block, col_block].to(torch.float32) + ).to(torch.bfloat16) + + # They should be equal (within floating point precision) + assert torch.allclose( + result_block.to(torch.float32), + expected_block.to(torch.float32), + rtol=1e-2, + atol=1e-3, + ), f"Block ({row_block}, {col_block}) conversion mismatch"