diff --git a/examples/model_free_ptq/deepseek_r1_nvfp4_fp8_block.py b/examples/model_free_ptq/deepseek_r1_nvfp4_fp8_block.py new file mode 100644 index 0000000000..594d680c0a --- /dev/null +++ b/examples/model_free_ptq/deepseek_r1_nvfp4_fp8_block.py @@ -0,0 +1,54 @@ +from compressed_tensors.entrypoints.convert import ( + ModelOptNvfp4Converter, +) +from compressed_tensors.quantization import ( + QuantizationScheme, +) +from compressed_tensors.quantization.quant_scheme import FP8_BLOCK + +from llmcompressor import model_free_ptq + +MODEL_ID = "nvidia/DeepSeek-R1-NVFP4" +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-BLOCK" + + +# Convert modelopt NVFP4 format to compressed-tensors format and +# apply FP8-Block to the model's compatible self_attn Linear layers +# Once quantized, the model is saved to SAVE_DIR. +model_free_ptq( + model_stub=MODEL_ID, + save_directory=SAVE_DIR, + scheme=QuantizationScheme( + **FP8_BLOCK, + targets=[ + # Target fused layers, must have the same quant config + # shape 576x7168 is compatible with block size 128x128 + # - self_attn.kv_a_proj_with_mqa + # - self_attn.q_a_proj + "re:.*self_attn.(kv_a_proj_with_mqa|q_a_proj)$", + # Skip self_attn.kv_b_proj, already dequantized by MLA + # Target remaining self_attn layers: + # - self_attn.o_proj + # - self_attn.q_b_proj + "re:.*self_attn.(o_proj|q_b_proj).*", + ], + ), + max_workers=8, + device="cuda:0", + converter=ModelOptNvfp4Converter( + targets=[ + # nvidia/DeepSeek-R1-NVFP4's nvfp4-quantized layers, found by inspection + # - model.layers.0.mlp.down_proj.weight + # - model.layers.0.mlp.gate_proj.weight + # - model.layers.0.mlp.up_proj.weight + # - model.layers.3.mlp.shared_experts.down_proj.weight + # - model.layers.3.mlp.shared_experts.gate_proj.weight + # - model.layers.3.mlp.shared_experts.up_proj.weight + # - model.layers.3.mlp.experts.0.down_proj.weight + # - model.layers.3.mlp.experts.0.gate_proj.weight + # - model.layers.3.mlp.experts.0.up_proj.weight + # NOTE: gate_up_proj also needs to be targeted, gate/up are fused + "re:.*mlp.*(gate_up|gate|up|down)_proj$" + ] + ), +) diff --git a/src/llmcompressor/entrypoints/model_free/__init__.py b/src/llmcompressor/entrypoints/model_free/__init__.py index 745ce86076..0c3aea70c5 100644 --- a/src/llmcompressor/entrypoints/model_free/__init__.py +++ b/src/llmcompressor/entrypoints/model_free/__init__.py @@ -1,11 +1,16 @@ import os import shutil -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Iterable, Optional import torch -import tqdm +from compressed_tensors.entrypoints.convert import ( + Converter, + exec_jobs, + get_checkpoint_files, + is_weights_file, + update_safetensors_index, +) from compressed_tensors.quantization import QuantizationScheme from loguru import logger @@ -13,10 +18,6 @@ from llmcompressor.entrypoints.model_free.microscale import ( is_microscale_scheme, ) -from llmcompressor.entrypoints.model_free.model_utils import ( - get_checkpoint_files, - is_weights_file, -) from llmcompressor.entrypoints.model_free.process import ( process_file, process_file_microscale_scheme, @@ -24,7 +25,6 @@ ) from llmcompressor.entrypoints.model_free.save_utils import ( update_config, - update_safetensors_index, ) from llmcompressor.entrypoints.model_free.validate import ( validate_safetensors_index, @@ -41,6 +41,7 @@ def model_free_ptq( ignore: Iterable[str] = tuple(), max_workers: int = 1, device: Optional[torch.device | str] = None, + converter: Converter | None = None, ): """ Quantize a model without the need for a model definition. This function operates on @@ -52,6 +53,10 @@ def model_free_ptq( ignored :param max_workers: number of worker threads to process files with :param device: gpu device to accelerate quantization with + :param converter: optional converter to apply to the checkpoint to convert it to + compressed-tensors format before running model-free PTQ + e.g. conversion of some layers from modelopt format to compressed-tensors + See compressed-tensors convert_checkpoint entrypoint for more information """ # validate arguments model_files = get_checkpoint_files(model_stub) @@ -70,7 +75,9 @@ def model_free_ptq( save_path = Path(save_directory) / file_path if file_path.endswith("safetensors"): - jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device)) + jobs.append( + (job_fn, resolved_path, save_path, scheme, ignore, device, converter) + ) else: if is_weights_file(file_path): @@ -79,25 +86,19 @@ def model_free_ptq( logger.info(f"Copying {file_path} {save_path}") shutil.copyfile(resolved_path, save_path) - with ThreadPoolExecutor(max_workers) as executor: - # 1. validate quantizable tensors fail fast before long-running quantization - futures = [executor.submit(validate_file, *job[1:]) for job in jobs] - for future in tqdm.tqdm( - as_completed(futures), total=len(futures), desc="Validating" - ): - future.result() + # 1. validate quantizable tensors fail fast before long-running quantization + exec_jobs( + [(validate_file, *job[1:]) for job in jobs], max_workers, desc="Validating" + ) - # 2-5. quantize and compress weights - total_size = 0 - weight_map = dict() - futures = [executor.submit(*job) for job in jobs] - for future in tqdm.tqdm( - as_completed(futures), total=len(futures), desc="Quantizing" - ): - _total_size, _weight_map = future.result() - total_size += _total_size - weight_map.update(_weight_map) + # 2-5. quantize and compress weights + total_size = 0 + weight_map = dict() + quantize_results = exec_jobs(jobs, max_workers, desc="Quantizing") + for _total_size, _weight_map in quantize_results: + total_size += _total_size + weight_map.update(_weight_map) # 5. update config and safetensors index - update_config(save_directory, scheme_name, scheme, ignore) + update_config(save_directory, scheme_name, scheme, ignore, converter) update_safetensors_index(save_directory, total_size, weight_map) diff --git a/src/llmcompressor/entrypoints/model_free/helpers.py b/src/llmcompressor/entrypoints/model_free/helpers.py index 9b1b130880..3a1a651e31 100644 --- a/src/llmcompressor/entrypoints/model_free/helpers.py +++ b/src/llmcompressor/entrypoints/model_free/helpers.py @@ -1,4 +1,3 @@ -import os import re from collections import defaultdict from typing import Mapping, TypeVar @@ -6,12 +5,9 @@ import torch from compressed_tensors.utils.match import match_name from loguru import logger -from transformers.file_utils import CONFIG_NAME __all__ = [ "gpu_if_available", - "find_safetensors_index_path", - "find_config_path", "find_safetensors_index_file", "match_names_set_eager", "MatchedNamesSet", @@ -43,22 +39,6 @@ def gpu_if_available(device: torch.device | str | None) -> torch.device: return torch.device("cpu") -def find_safetensors_index_path(save_directory: str | os.PathLike) -> str | None: - for file_name in os.listdir(save_directory): - if file_name.endswith("safetensors.index.json"): - return os.path.join(save_directory, file_name) - - return None - - -def find_config_path(save_directory: str | os.PathLike) -> str | None: - for file_name in os.listdir(save_directory): - if file_name in (CONFIG_NAME, "params.json"): - return os.path.join(save_directory, file_name) - - return None - - def find_safetensors_index_file(model_files: dict[str, str]) -> str | None: for file_path, resolved_path in model_files.items(): if file_path.endswith("safetensors.index.json"): diff --git a/src/llmcompressor/entrypoints/model_free/model_utils.py b/src/llmcompressor/entrypoints/model_free/model_utils.py deleted file mode 100644 index 5a6b92f716..0000000000 --- a/src/llmcompressor/entrypoints/model_free/model_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -import os - -from huggingface_hub import list_repo_files -from transformers.utils.hub import cached_file - -__all__ = ["get_checkpoint_files", "is_weights_file"] - -weights_files = [ - ".bin", - ".safetensors", - ".pth", - ".msgpack", - ".pt", -] - - -def is_weights_file(file_name: str) -> bool: - return any(file_name.endswith(suffix) for suffix in weights_files) - - -def get_checkpoint_files(model_stub: str | os.PathLike) -> dict[str, str]: - # In the future, this function can accept and pass download kwargs to cached_file - - if os.path.exists(model_stub): - file_paths = walk_file_paths(model_stub, ignore=".cache") - else: - file_paths = list_repo_files(model_stub) - - return {file_path: cached_file(model_stub, file_path) for file_path in file_paths} - - -def walk_file_paths(root_dir: str, ignore: str | None = None) -> list[str]: - """ - Return all file paths relative to the root directory - """ - - all_files = [] - for dirpath, _, filenames in os.walk(root_dir): - for filename in filenames: - rel_path = os.path.relpath(os.path.join(dirpath, filename), root_dir) - if not (ignore and rel_path.startswith(ignore)): - all_files.append(rel_path) - return all_files - - -# distinguish relative file paths from absolute/resolved file paths -# relative file paths are used to find the save path -# resolved file paths are what are used to load data diff --git a/src/llmcompressor/entrypoints/model_free/process.py b/src/llmcompressor/entrypoints/model_free/process.py index 9b4412ee1d..44835c7f81 100644 --- a/src/llmcompressor/entrypoints/model_free/process.py +++ b/src/llmcompressor/entrypoints/model_free/process.py @@ -1,11 +1,11 @@ import os from collections import defaultdict -from collections.abc import Iterator, Mapping from typing import Iterable import torch +from compressed_tensors.entrypoints.convert import Converter from compressed_tensors.quantization import QuantizationScheme -from compressed_tensors.utils.match import match_name +from compressed_tensors.utils import match_quantizable_tensors from safetensors.torch import load_file, save_file from torch.nn import Module @@ -21,21 +21,11 @@ is_microscale_scheme, ) -__all__ = ["validate_file", "process_file", "process_file_microscale_scheme"] - - -def iter_quantizable_tensors( - tensors: Mapping[str, torch.Tensor], - ignore: Iterable[str], -) -> Iterator[tuple[str, str]]: - for name in list(tensors.keys()): - module_name, param_name = name.rsplit(".", 1) - is_linear_weight = param_name == "weight" and not module_name.endswith("norm") - is_ignored = any(match_name(module_name, ign) for ign in ignore) - if not is_linear_weight or is_ignored: - continue - - yield module_name, name +__all__ = [ + "validate_file", + "process_file", + "process_file_microscale_scheme", +] def validate_file( @@ -44,6 +34,7 @@ def validate_file( scheme: QuantizationScheme, ignore: Iterable[str], device: str | torch.device, + converter: Converter | None = None, ): """ Validate that each quantizable tensor in a safetensors file can be quantized. @@ -52,10 +43,15 @@ def validate_file( :param scheme: quantization scheme to apply to tensors :param ignore: modules to ignore. Modules ending with "norm" are automatically ignored + :param converter: optional converter to apply to the checkpoint, + e.g. conversion of some layers from some format to compressed-tensors """ tensors = load_file(file_path) - for _, name in iter_quantizable_tensors(tensors, ignore): + if converter is not None: + converter.validate(tensors) + + for _, name in match_quantizable_tensors(tensors, ignore, scheme.targets): validate_weight_for_quantization(tensors[name], scheme, name) @@ -65,6 +61,7 @@ def process_file( scheme: QuantizationScheme, ignore: Iterable[str], device: str | torch.device, + converter: Converter | None = None, ) -> tuple[int, dict[str, str]]: """ Quantize and compress tensors in a given safetensors file @@ -75,11 +72,16 @@ def process_file( :param ignore: modules to ignore. Modules ending with "norm" are automatically ignored :param device: device used to quantize and compress weights + :param converter: optional converter to apply to the checkpoint, + e.g. conversion of some layers from some format to compressed-tensors """ assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`" tensors = load_file(file_path) - for module_name, name in iter_quantizable_tensors(tensors, ignore): + if converter is not None: + converter.process(tensors) + + for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets): validate_weight_for_quantization(tensors[name], scheme, name) # 1. initialize module with qparams (on device) @@ -109,6 +111,7 @@ def process_file_microscale_scheme( scheme: QuantizationScheme, ignore: Iterable[str], device: str | torch.device, + converter: Converter | None = None, ) -> tuple[int, dict[str, str]]: """ Quantize and compress tensors in a given safetensors file @@ -119,9 +122,15 @@ def process_file_microscale_scheme( :param ignore: modules to ignore. Modules ending with "norm" are automatically ignored :param device: device used to quantize and compress weights + :param converter: optional converter to apply to the checkpoint, + e.g. conversion of some layers from some format to compressed-tensors """ assert is_microscale_scheme(scheme), "Use `_process_file` for non-microscale scheme" tensors = load_file(file_path) + + if converter is not None: + converter.process(tensors) + fused_sets, unmatched_sets = get_fused_names(tensors) assert len(unmatched_sets) <= 0 # should be caught by `validate_safetensors_index` @@ -135,7 +144,7 @@ def process_file_microscale_scheme( } fused_modules = defaultdict(dict) - for module_name, name in iter_quantizable_tensors(tensors, ignore): + for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets): validate_weight_for_quantization(tensors[name], scheme, name) # 1. initialize module with qparams (on device) diff --git a/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py b/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py index a5b5bbd2d5..a3f7508869 100644 --- a/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py +++ b/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py @@ -7,6 +7,11 @@ import torch import tqdm +from compressed_tensors.entrypoints.convert import ( + get_checkpoint_files, + is_weights_file, + update_safetensors_index, +) from loguru import logger from safetensors.torch import load_file, save_file @@ -15,11 +20,6 @@ invert_mapping, ) from llmcompressor.entrypoints.model_free.microscale import get_fused_names -from llmcompressor.entrypoints.model_free.model_utils import ( - get_checkpoint_files, - is_weights_file, -) -from llmcompressor.entrypoints.model_free.save_utils import update_safetensors_index def parse_args(): diff --git a/src/llmcompressor/entrypoints/model_free/save_utils.py b/src/llmcompressor/entrypoints/model_free/save_utils.py index 6d7ad2908b..2b27c0da15 100644 --- a/src/llmcompressor/entrypoints/model_free/save_utils.py +++ b/src/llmcompressor/entrypoints/model_free/save_utils.py @@ -9,16 +9,17 @@ SPARSITY_CONFIG_NAME, TRANSFORM_CONFIG_NAME, ) +from compressed_tensors.config import CompressionFormat +from compressed_tensors.entrypoints.convert import Converter, find_config_path from compressed_tensors.quantization import ( QuantizationConfig, QuantizationScheme, QuantizationStatus, ) from loguru import logger +from pydantic import ValidationError -from .helpers import find_config_path, find_safetensors_index_path - -__all__ = ["update_config", "update_safetensors_index"] +__all__ = ["update_config"] def update_config( @@ -26,29 +27,31 @@ def update_config( scheme_name: str, scheme: QuantizationScheme, ignore: list[str], + converter: Converter | None = None, ): - # construct quantization config - qconfig = QuantizationConfig.model_validate( - { - "config_groups": {scheme_name: scheme}, - "ignore": ignore, - "quantization_status": QuantizationStatus.COMPRESSED, - } + """ + Update Quantization config for model stub in save_directory, + based on the provided scheme and converter. + Quantization config will either be created or updated, see + create_or_update_quant_config docstring for more info. + """ + config_file_path = find_config_path(save_directory) + + qconfig = create_or_update_quant_config( + config_file_path, scheme_name, scheme, ignore, converter ) # construct compression (quantization) config - qconfig_data = qconfig.model_dump(exclude=[QUANTIZATION_METHOD_NAME, "format"]) + qconfig_data = qconfig.model_dump(exclude=[QUANTIZATION_METHOD_NAME]) qconfig_data = { COMPRESSION_VERSION_NAME: ct_version, QUANTIZATION_METHOD_NAME: "compressed-tensors", SPARSITY_CONFIG_NAME: {}, TRANSFORM_CONFIG_NAME: {}, - "format": scheme.format, **qconfig_data, } # write results to config.json file - config_file_path = find_config_path(save_directory) if config_file_path is not None: with open(config_file_path, "r") as file: config_data = json.load(file) @@ -60,29 +63,80 @@ def update_config( else: logger.warning( - f"Could not find config file in {save_directory}. " - f"Please {json.dumps(qconfig_data, indent=2, sort_keys=True)}" + f"Could not find config file in {save_directory}. Please set " + "quantization_config to: \n" + f"{json.dumps(qconfig_data, indent=2, sort_keys=True)}" ) -def update_safetensors_index( - save_directory: str | os.PathLike, - total_size: int, - weight_map: dict[str, str], -): - file_path = find_safetensors_index_path(save_directory) - if file_path is None: - return +def create_or_update_quant_config( + config_file_path: str | None, + scheme_name: str, + scheme: QuantizationScheme, + ignore: list[str], + converter: Converter | None = None, +) -> QuantizationConfig: + """ + Create or update quantization_config in 3 possible ways: + 1) If converting from a format that isn't compressed-tensors, + create new quant config based on converter and append scheme + 2) If checkpoint is in a pre-existing compressed-tensors format, + use its quantization_config as starting point and append scheme + 3) Otherwise, create from scratch based on scheme + """ - with open(file_path, "w") as file: - json.dump( + qconfig = None + if converter is not None: + # original checkpoint is not in compressed-tensors format + # assume quantization_config needs be created from scratch + qconfig = converter.create_config() + elif config_file_path is not None: + # load up quantization_config, if pre-existing compressed-tensors + # format exists, append to it instead of creating from scratch + with open(config_file_path, "r") as file: + config_data = json.load(file) + + if QUANTIZATION_CONFIG_NAME in config_data: + qconfigdata = config_data[QUANTIZATION_CONFIG_NAME] + # version in json but not allowed in QuantizationConfig + qconfigdata.pop(COMPRESSION_VERSION_NAME, None) + try: + qconfig = QuantizationConfig.model_validate(qconfigdata) + except ValidationError as e: + logger.warning( + "Unable to parse original checkpoint quantization_config. " + f"The quantization_config will be created from scratch: {e}" + ) + else: + logger.info( + "No pre-existing quantization_config found. " + "The quantization_config will be created from scratch" + ) + + if qconfig is None: + # construct quantization config from scratch + qconfig = QuantizationConfig.model_validate( { - "metadata": { - "total_size": total_size, - }, - "weight_map": weight_map, - }, - file, - indent=2, - sort_keys=True, + "config_groups": {scheme_name: scheme}, + "ignore": ignore, + "quantization_status": QuantizationStatus.COMPRESSED, + "format": scheme.format, + } ) + else: + # update pre-existing quantization config + scheme_name = ( + f"config_group_{len(qconfig.config_groups)}" + if scheme_name in qconfig.config_groups + else scheme_name + ) + qconfig.config_groups[scheme_name] = scheme + unique_formats = set(scheme.format for scheme in qconfig.config_groups.values()) + qconfig.format = ( + next(iter(unique_formats)) + if len(unique_formats) == 1 + else CompressionFormat.mixed_precision.value + ) + qconfig.quantization_status = QuantizationStatus.COMPRESSED + + return qconfig diff --git a/src/llmcompressor/entrypoints/model_free/validate.py b/src/llmcompressor/entrypoints/model_free/validate.py index c27b782987..390706fad5 100644 --- a/src/llmcompressor/entrypoints/model_free/validate.py +++ b/src/llmcompressor/entrypoints/model_free/validate.py @@ -47,7 +47,8 @@ def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme # target all modules; filter by ignore list # technically this should be "re:.*", but vllm's # ct moe layer has a hard coded check for "Linear" - scheme.targets = ["Linear"] + if len(scheme.targets) == 0: + scheme.targets.append("Linear") return scheme_name, scheme diff --git a/tests/llmcompressor/entrypoints/__init__.py b/tests/llmcompressor/entrypoints/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/llmcompressor/entrypoints/model_free/__init__.py b/tests/llmcompressor/entrypoints/model_free/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/llmcompressor/entrypoints/model_free/test_convert_checkpoint.py b/tests/llmcompressor/entrypoints/model_free/test_convert_checkpoint.py new file mode 100644 index 0000000000..cd230884ec --- /dev/null +++ b/tests/llmcompressor/entrypoints/model_free/test_convert_checkpoint.py @@ -0,0 +1,106 @@ +import json + +import pytest +from compressed_tensors.entrypoints.convert import ( + ModelOptNvfp4Converter, + convert_checkpoint, +) +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationType, +) +from compressed_tensors.quantization.quant_scheme import NVFP4 + +from tests.testing_utils import requires_cadence + +# NOTE: This file contains tests for compressed_tensors.entrypoints.convert +# that are either long-running or involve larger models. They have been placed +# here to leverage llm-compressor's nightly testing CI/CD. + + +@requires_cadence("nightly") +def test_convert_checkpoint(tmp_path): + """ + Test that compressed-tensors convert_checkpoint entrypoint + can be run on a pre-existing modelopt checkpoint + """ + MODEL_ID = "nvidia/Qwen3-8B-NVFP4" + convert_outdir = tmp_path / "convert_out" + + right_targets = [ + r"re:.*mlp.*\.(gate_up|gate|up|down)_proj$", + r"re:.*self_attn.*\.(q|k|v|o)_proj$", + ] + wrong_targets = [ + r"re:.*mlp.*\.(gate_up|gate|up|down)_proj$", + r"re:.*self_attn.*\.(q|k|o)_proj$", + ] + right_kv_cache_scheme = QuantizationArgs( + num_bits=8, dynamic=False, type=QuantizationType.FLOAT + ) + wrong_kv_cache_scheme = None + + with pytest.raises(ValueError): + convert_checkpoint( + model_stub=MODEL_ID, + save_directory=convert_outdir, + converter=ModelOptNvfp4Converter( + targets=right_targets, + kv_cache_scheme=wrong_kv_cache_scheme, + ), + ) + + with pytest.raises(ValueError): + convert_checkpoint( + model_stub=MODEL_ID, + save_directory=convert_outdir, + converter=ModelOptNvfp4Converter( + targets=wrong_targets, + kv_cache_scheme=right_kv_cache_scheme, + ), + ) + + convert_checkpoint( + model_stub=MODEL_ID, + save_directory=convert_outdir, + converter=ModelOptNvfp4Converter( + targets=right_targets, + kv_cache_scheme=right_kv_cache_scheme, + ), + ) + + with open(convert_outdir / "config.json", "r") as f: + config = json.load(f) + + qconfig = QuantizationConfig.model_validate(config["quantization_config"]) + + assert qconfig.format == "nvfp4-pack-quantized" + assert qconfig.quant_method == "compressed-tensors" + assert len(qconfig.config_groups) == 1 + # assert weights and input_activations are a superset of what's in the NVFP4 preset + assert ( + qconfig.config_groups["config_group_0"].weights.model_dump().items() + >= NVFP4["weights"].model_dump().items() + ) + assert ( + qconfig.config_groups["config_group_0"].input_activations.model_dump().items() + >= NVFP4["input_activations"].model_dump().items() + ) + + with open(convert_outdir / "model.safetensors.index.json", "r") as f: + allowed_suffixes = [ + "weight", + "weight_scale", + "weight_packed", + "weight_global_scale", + "input_global_scale", + "k_scale", + "v_scale", + ] + data = json.load(f) + keys = data["weight_map"].keys() + for key in keys: + assert any( + key.endswith(suffix) for suffix in allowed_suffixes + ), f"Unexpected key found: {key}" diff --git a/tests/llmcompressor/pipelines/test_model_free_ptq.py b/tests/llmcompressor/entrypoints/model_free/test_model_free_ptq.py similarity index 56% rename from tests/llmcompressor/pipelines/test_model_free_ptq.py rename to tests/llmcompressor/entrypoints/model_free/test_model_free_ptq.py index 8decda7c05..d51acbc384 100644 --- a/tests/llmcompressor/pipelines/test_model_free_ptq.py +++ b/tests/llmcompressor/entrypoints/model_free/test_model_free_ptq.py @@ -3,7 +3,11 @@ import pytest import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, +) +from compressed_tensors.utils.match import match_name from safetensors.torch import load_file from llmcompressor import model_free_ptq, oneshot @@ -13,7 +17,7 @@ def _get_tiny_w4a16_quant(): return QuantizationScheme( - targets=["Linear"], + targets=["re:.*self_attn.(q|k|o|v)_proj"], weights=QuantizationArgs( num_bits=4, type="int", @@ -27,7 +31,7 @@ def _get_tiny_w4a16_quant(): def _get_tiny_block_quant(): return QuantizationScheme( - targets=["Linear"], + targets=["re:.*mlp.(down|gate|up)_proj"], weights=QuantizationArgs( num_bits=8, type="float", @@ -84,6 +88,57 @@ def test_model_free_ptq_matches_oneshot(scheme, tmp_path): _assert_config_equal(ptq_outdir / "config.json", oneshot_outdir / "config.json") +@requires_gpu +@pytest.mark.parametrize( + "schemes", + [(_get_tiny_w4a16_quant(), _get_tiny_block_quant())], +) +def test_stacked_model_free_ptq_matches_oneshot(schemes, tmp_path): + """ + Test that model_free_ptq can be stacked, also tests that + model_free_ptq can be run on a pre-existing CT checkpoint + """ + + model = "Qwen/Qwen3-0.6B" + ignore = ["model.embed_tokens", "lm_head"] + device = "cuda:0" + + ptq_outdirs = [tmp_path / f"weights_out_{idx}" for idx in range(len(schemes))] + oneshot_outdir = tmp_path / "oneshot_out" + + for idx, scheme in enumerate(schemes): + model_free_ptq( + model if idx == 0 else ptq_outdirs[idx - 1], + ptq_outdirs[idx], + scheme=scheme, + max_workers=2, + device=device, + ignore=ignore, + ) + + config_groups = { + f"config_group_{idx}": scheme for idx, scheme in enumerate(schemes) + } + recipe = QuantizationModifier(config_groups=config_groups, ignore=ignore) + + oneshot( + model=model, + precision="auto", + recipe=recipe, + output_dir=oneshot_outdir, + ) + + ptq_outdir = ptq_outdirs[-1] + ptq_st_files = _get_safetensors_files(ptq_outdir) + oneshot_st_files = _get_safetensors_files(oneshot_outdir) + assert set(ptq_st_files) == set(oneshot_st_files) + + for file_name in ptq_st_files: + _assert_safetensors_equal(ptq_outdir / file_name, oneshot_outdir / file_name) + + _assert_config_equal(ptq_outdir / "config.json", oneshot_outdir / "config.json") + + def _get_safetensors_files(dir_path: str) -> list[str]: return [ file_name @@ -101,7 +156,10 @@ def _assert_safetensors_equal(a_path: str, b_path: str) -> bool: if "lm_head.weight" in a and "lm_head.weight" not in b: del a["lm_head.weight"] - assert a.keys() == b.keys(), (a.keys() - b.keys(), b.keys() - a.keys()) + assert a.keys() == b.keys(), ( + sorted(a.keys() - b.keys()), + sorted(b.keys() - a.keys()), + ) for key in a.keys(): value_equal = torch.equal(a[key].to(torch.bfloat16), b[key].to(torch.bfloat16)) @@ -135,17 +193,24 @@ def _assert_config_equal(a_path: str, b_path: str): a_ignore = a_qconfig.pop("ignore") b_ignore = b_qconfig.pop("ignore") - assert set(b_ignore).issubset(set(a_ignore)) + # QuantizationModifier updates ignore lists with any non-targeted layers + # model_free_ptq does not. Rather than asserting sets are equal, + # confirm none conflict with targets + all_ignores = set(a_ignore).union(set(b_ignore)) + + assert len(a_config_groups) == len(b_config_groups) + a_schemes = list(a_config_groups.values()) + b_schemes = list(b_config_groups.values()) - assert len(a_config_groups) == 1 - assert len(b_config_groups) == 1 - a_scheme = list(a_config_groups.values())[0] - b_scheme = list(b_config_groups.values())[0] + for a_scheme, b_scheme in zip(a_schemes, b_schemes): + # TODO: remove this pop after + # https://github.com/vllm-project/compressed-tensors/pull/489 lands and + # src/llmcompressor/entrypoints/weights_ptq/helpers.py:34 is removed + a_scheme["weights"].pop("observer") + b_scheme["weights"].pop("observer") - # TODO: remove this pop after - # https://github.com/vllm-project/compressed-tensors/pull/489 lands and - # src/llmcompressor/entrypoints/weights_ptq/helpers.py:34 is removed - a_scheme["weights"].pop("observer") - b_scheme["weights"].pop("observer") + assert a_scheme == b_scheme - assert a_scheme == b_scheme + for ignore in all_ignores: + for target in a_scheme["targets"]: + assert not match_name(ignore, target) diff --git a/tests/llmcompressor/pipelines/test_model_free_validation.py b/tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py similarity index 100% rename from tests/llmcompressor/pipelines/test_model_free_validation.py rename to tests/llmcompressor/entrypoints/model_free/test_model_free_validation.py diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index e4e15d300f..55994f208b 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -66,10 +66,13 @@ def test_calib_deepseekv3_module(): config = DeepseekV3Config() with torch.device("cuda"): original = OriginalDeepseekV3MoE(config).eval() + for param in original.parameters(): + param.data.normal_(mean=0.0, std=0.02) # Create dummy input tensor that simulates hidden_states hidden_dim = config.hidden_size batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, device="cuda") with calibration_forward_context(original): @@ -78,9 +81,9 @@ def test_calib_deepseekv3_module(): module = CalibrationDeepseekV3MoE(original, config, calibrate_all_experts=True) with calibration_forward_context(module): output = module(sample) - assert torch.allclose(true_output, output, atol=1e-6) + assert torch.nn.functional.mse_loss(true_output, output) < 1e-10 module = CalibrationDeepseekV3MoE(original, config, calibrate_all_experts=False) with calibration_forward_context(module): output = module(sample) - assert torch.allclose(true_output, output, atol=1e-6) + assert torch.nn.functional.mse_loss(true_output, output) < 1e-10