Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/model_free_ptq/qwen3_fp8_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from llmcompressor import model_free_ptq

MODEL_ID = "Qwen/Qwen3-0.6B"
# Name the output directory after the final path component of the model stub.
SAVE_DIR = MODEL_ID.rstrip("/").rpartition("/")[2] + "-FP8-BLOCK"

# Quantize the model's linear weights to FP8 with block-wise scales.
# The quantized checkpoint is written to SAVE_DIR in the
# compressed-tensors format.
model_free_ptq(
    model_stub=MODEL_ID,
    save_directory=SAVE_DIR,
    scheme="FP8_BLOCK",
    # Embeddings and the output head are left unquantized.
    ignore=["model.embed_tokens", "lm_head"],
    max_workers=15,
    device="cuda:0",
)
15 changes: 11 additions & 4 deletions src/llmcompressor/entrypoints/model_free/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from typing import Iterable, Optional

import torch
import tqdm
Expand All @@ -20,6 +20,7 @@
from llmcompressor.entrypoints.model_free.process import (
process_file,
process_file_microscale_scheme,
validate_file,
)
from llmcompressor.entrypoints.model_free.save_utils import (
update_config,
Expand All @@ -37,7 +38,7 @@ def model_free_ptq(
model_stub: str | os.PathLike,
save_directory: str | os.PathLike,
scheme: QuantizationScheme | str,
ignore: Optional[list[str]] = None,
ignore: Iterable[str] = tuple(),
max_workers: int = 1,
device: Optional[torch.device | str] = None,
):
Expand Down Expand Up @@ -78,12 +79,18 @@ def model_free_ptq(
logger.info(f"Copying {file_path} {save_path}")
shutil.copyfile(resolved_path, save_path)

# 1-4. quantize and compress weights
with ThreadPoolExecutor(max_workers) as executor:
futures = [executor.submit(*job) for job in jobs]
# 1. validate quantizable tensors fail fast before long-running quantization
futures = [executor.submit(validate_file, *job[1:]) for job in jobs]
for future in tqdm.tqdm(
as_completed(futures), total=len(futures), desc="Validating"
):
future.result()

# 2-5. quantize and compress weights
total_size = 0
weight_map = dict()
futures = [executor.submit(*job) for job in jobs]
for future in tqdm.tqdm(
as_completed(futures), total=len(futures), desc="Quantizing"
):
Expand Down
17 changes: 17 additions & 0 deletions src/llmcompressor/entrypoints/model_free/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,32 @@
update_weight_global_scale,
update_weight_zp_scale,
)
from llmcompressor.observers.helpers import flatten_for_calibration

__all__ = [
"initialize_quantized_linear",
"validate_weight_for_quantization",
"calibrate_global_scale",
"calibrate_scale_zp",
"compress_module",
]


def validate_weight_for_quantization(
weight: torch.Tensor, scheme: QuantizationScheme, tensor_name: str
):
if weight.ndim != 2:
raise ValueError(
f"Unable to quantize tensor `{tensor_name}`: expected 2D linear weight, "
f"but got shape {tuple(weight.shape)}"
)

try:
flatten_for_calibration(weight, "weight", scheme.weights)
except Exception as exc:
raise ValueError(f"Unable to quantize tensor `{tensor_name}`: {exc}") from exc


def initialize_quantized_linear(
weight: torch.Tensor, scheme: QuantizationScheme, device: str | torch.device
) -> torch.nn.Module:
Expand Down
62 changes: 46 additions & 16 deletions src/llmcompressor/entrypoints/model_free/process.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from collections import defaultdict
from collections.abc import Iterator, Mapping
from typing import Iterable

import torch
from compressed_tensors.quantization import QuantizationScheme
Expand All @@ -12,20 +14,56 @@
calibrate_scale_zp,
compress_module,
initialize_quantized_linear,
validate_weight_for_quantization,
)
from llmcompressor.entrypoints.model_free.microscale import (
get_fused_names,
is_microscale_scheme,
)

__all__ = ["process_file", "process_file_microscale_scheme"]
__all__ = ["validate_file", "process_file", "process_file_microscale_scheme"]


def iter_quantizable_tensors(
    tensors: Mapping[str, torch.Tensor],
    ignore: Iterable[str],
) -> Iterator[tuple[str, str]]:
    """
    Yield ``(module_name, tensor_name)`` for each linear weight to quantize.

    Skips non-weight parameters, modules whose names end with "norm", and
    modules matching any pattern in ``ignore``. Keys are snapshotted up
    front so callers may mutate ``tensors`` while iterating.
    """
    for tensor_name in list(tensors):
        module_name, param_name = tensor_name.rsplit(".", 1)
        if param_name != "weight" or module_name.endswith("norm"):
            continue
        if any(_match_name(module_name, pattern) for pattern in ignore):
            continue
        yield module_name, tensor_name


def validate_file(
    file_path: str | os.PathLike,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: str | torch.device,
):
    """
    Check that every quantizable tensor in a safetensors file can be
    quantized, raising before any long-running quantization starts.

    :param file_path: safetensors file to validate
    :param save_path: not used by validation; accepted so callers can pass
        the same argument tuple as ``process_file``
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are
        automatically ignored
    :param device: not used by validation; accepted so callers can pass
        the same argument tuple as ``process_file``
    :raises ValueError: if any quantizable tensor is incompatible with
        ``scheme``
    """
    state = load_file(file_path)

    for _module_name, tensor_name in iter_quantizable_tensors(state, ignore):
        validate_weight_for_quantization(state[tensor_name], scheme, tensor_name)


def process_file(
file_path: str | os.PathLike,
save_path: str | os.PathLike,
scheme: QuantizationScheme,
ignore: str | list[str],
ignore: Iterable[str],
device: str | torch.device,
) -> tuple[int, dict[str, str]]:
"""
Expand All @@ -41,12 +79,8 @@ def process_file(
assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
tensors = load_file(file_path)

for name in list(tensors.keys()):
module_name, param_name = name.rsplit(".", 1)
is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
is_ignored = any(_match_name(module_name, ign) for ign in ignore)
if not is_linear_weight or is_ignored:
continue
for module_name, name in iter_quantizable_tensors(tensors, ignore):
validate_weight_for_quantization(tensors[name], scheme, name)

# 1. initialize module with qparams (on device)
module = initialize_quantized_linear(tensors[name], scheme, device)
Expand All @@ -73,7 +107,7 @@ def process_file_microscale_scheme(
file_path: str | os.PathLike,
save_path: str | os.PathLike,
scheme: QuantizationScheme,
ignore: str | list[str],
ignore: Iterable[str],
device: str | torch.device,
) -> tuple[int, dict[str, str]]:
"""
Expand Down Expand Up @@ -101,12 +135,8 @@ def process_file_microscale_scheme(
}
fused_modules = defaultdict(dict)

for name in list(tensors.keys()):
module_name, param_name = name.rsplit(".", 1)
is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
is_ignored = any(_match_name(module_name, ign) for ign in ignore)
if not is_linear_weight or is_ignored:
continue
for module_name, name in iter_quantizable_tensors(tensors, ignore):
validate_weight_for_quantization(tensors[name], scheme, name)

# 1. initialize module with qparams (on device)
module = initialize_quantized_linear(tensors[name], scheme, device)
Expand Down Expand Up @@ -136,7 +166,7 @@ def process_file_microscale_scheme(
fused_global_scale = torch.min(torch.cat(global_scales, dim=0))

for name, module in named_modules.items():
module_name, param_name = name.rsplit(".", 1)
module_name, _ = name.rsplit(".", 1)
module.weight_global_scale.data.copy_(fused_global_scale)

# 2.2. finish calibration with fused global scales
Expand Down
39 changes: 39 additions & 0 deletions tests/llmcompressor/pipelines/test_model_free_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from safetensors.torch import save_file

from llmcompressor.entrypoints.model_free.process import validate_file


def _get_block_scheme() -> QuantizationScheme:
    """Build an FP8 weight scheme with static, symmetric 16x16 block scales."""
    weight_args = QuantizationArgs(
        num_bits=8,
        type="float",
        strategy="block",
        symmetric=True,
        dynamic=False,
        block_structure=[16, 16],
    )
    return QuantizationScheme(targets=["Linear"], weights=weight_args)


def test_validate_file_raises_for_non_2d_linear_weight(tmp_path):
    """A 1D tensor named like a linear weight must be rejected by name."""
    weight_name = "model.layers.0.mlp.down_proj.weight"
    file_path = tmp_path / "bad_shape.safetensors"
    save_file({weight_name: torch.ones(128)}, str(file_path))

    with pytest.raises(ValueError, match=weight_name):
        validate_file(file_path, None, _get_block_scheme(), [], None)


def test_validate_file_raises_for_block_incompatible_shape(tmp_path):
    """A 17x16 weight does not divide evenly into 16x16 blocks; expect failure."""
    file_path = tmp_path / "bad_block.safetensors"
    tensors = {"model.layers.0.mlp.down_proj.weight": torch.ones(17, 16)}
    save_file(tensors, str(file_path))

    with pytest.raises(ValueError, match="strict division"):
        validate_file(file_path, None, _get_block_scheme(), [], None)