feat: early group-size divisibility check with layer FQNs #2353
Merged: HDCharles merged 11 commits into vllm-project:main from GOavi101:feat/group-size-divisibility-check on Feb 17, 2026.
Changes shown from 3 of 11 commits.

Commits:
- bd0060f: feat(quantization): early group-size divisibility check with layer FQNs
- c7b6436: test: remove redundant assert True and comment per review
- cee116a: Merge branch 'main' into feat/group-size-divisibility-check
- ce10711 (HDCharles): Merge branch 'main' into feat/group-size-divisibility-check
- b9a6f88 (HDCharles): Merge branch 'main' into feat/group-size-divisibility-check
- efda147 (dsikka): test: remove redundant assert True and comment per review
- 43594f0: Merge branch 'main' into feat/group-size-divisibility-check
- 61ada1f: Revert unrelated files to upstream main (PR shows only group-size div…
- 04be3ac: Merge branch 'main' into feat/group-size-divisibility-check
- 970890a (GOavi101): Merge branch 'main' into feat/group-size-divisibility-check
- 627410c (HDCharles): Merge branch 'main' into feat/group-size-divisibility-check
src/llmcompressor/modifiers/quantization/group_size_validation.py (144 additions, 0 deletions)
```python
"""
Early validation for divisibility requirements by quantization strategy.

Different kernels support different divisibility rules. This module encodes
which strategies require strict divisibility (and thus error early with layer
names) vs which do not.

Policy (single source of truth for "error vs warn vs skip"):

- GROUP, TENSOR_GROUP: Runtime/save kernels require columns % group_size == 0.
  We ERROR at initialize with the list of affected layer FQNs so users can add
  them to `ignore` before long calibration (e.g. GPTQ). No kernel support for
  non-divisible today.

- BLOCK: Block kernels support non-divisible dimensions (e.g. strategy_cdiv
  with strict=False). We do NOT check or warn for block.

- CHANNEL, TENSOR, TOKEN, ATTN_HEAD: No group_size divisibility requirement;
  we do not run this validation.

See: compressed-tensors forward.py (GROUP/TENSOR_GROUP ValueError),
strategy_cdiv in compressed_tensors.quantization.utils.helpers.
"""

from __future__ import annotations

from typing import Set, Tuple

import torch
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
from compressed_tensors.utils import match_named_modules

__all__ = [
    "_layer_indivisible",
    "get_layers_indivisible_by_group_size",
    "validate_group_size_divisibility",
]

# Strategies for which we error on indivisible columns (no kernel support).
# BLOCK is intentionally excluded: block kernels support non-divisible.
_GROUP_STRATEGY_STRINGS = ("group", "tensor_group")


def _is_group_or_tensor_group_strategy(strategy) -> bool:
    """True if strategy is GROUP or TENSOR_GROUP (enum or string)."""
    if strategy is None:
        return False
    if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        return True
    for attr in ("value", "name"):
        s = getattr(strategy, attr, None)
        if s is not None and str(s).lower() in _GROUP_STRATEGY_STRINGS:
            return True
    s = str(strategy).lower()
    if s in _GROUP_STRATEGY_STRINGS:
        return True
    # Enum repr e.g. "quantizationstrategy.group"
    if s.split(".")[-1] in _GROUP_STRATEGY_STRINGS:
        return True
    return False


def _layer_indivisible(module: torch.nn.Module, weight_args) -> Tuple[int, int] | None:
    """
    If module has group/tensor_group weight and columns % group_size != 0,
    return (columns, group_size); else return None.
    """
    if not _is_group_or_tensor_group_strategy(getattr(weight_args, "strategy", None)):
        return None
    group_size = getattr(weight_args, "group_size", None)
    if group_size is None:
        return None
    if not hasattr(module, "weight"):
        return None
    columns = int(module.weight.shape[-1])
    group_size = int(group_size)
    if columns >= group_size and columns % group_size != 0:
        return (columns, group_size)
    return None


def get_layers_indivisible_by_group_size(
    model: torch.nn.Module,
    resolved_targets: Set[str],
    ignore: list[str],
) -> list[Tuple[str, int, int]]:
    """
    Find targeted layers whose weight columns are not divisible by group_size.

    Only considers layers whose weight scheme is GROUP or TENSOR_GROUP
    (by value; enum or string). BLOCK and other strategies are not checked.
    Matches the condition that triggers ValueError in compressed_tensors
    forward.py (columns >= group_size and columns % group_size != 0).

    :param model: Model with quantization schemes already applied (e.g. after
        apply_quantization_config).
    :param resolved_targets: Target module name patterns (e.g. from
        QuantizationMixin.resolved_targets).
    :param ignore: Module name patterns to exclude (e.g. QuantizationMixin.ignore).
    :return: List of (fqn, columns, group_size) for each layer that would
        fail at save/forward due to indivisibility.
    """
    indivisible: list[Tuple[str, int, int]] = []
    for name, module in match_named_modules(model, resolved_targets, ignore):
        scheme: QuantizationScheme | None = getattr(module, "quantization_scheme", None)
        if scheme is None or scheme.weights is None:
            continue
        result = _layer_indivisible(module, scheme.weights)
        if result is not None:
            columns, group_size = result
            indivisible.append((name, columns, group_size))
    return indivisible


def validate_group_size_divisibility(
    model: torch.nn.Module,
    resolved_targets: Set[str],
    ignore: list[str],
    *,
    bypass: bool = False,
) -> None:
    """
    Ensure targeted group/tensor_group layers have columns divisible by group_size.

    If any such layer has columns % group_size != 0, raises ValueError with layer FQNs.
    When bypass is True, skips the check (e.g. for runtimes that support non-divisible).
    """
    if bypass:
        return
    indivisible = get_layers_indivisible_by_group_size(model, resolved_targets, ignore)
    if not indivisible:
        return
    lines = [
        f"  - {fqn} (columns={cols}, group_size={gs})" for fqn, cols, gs in indivisible
    ]
    raise ValueError(
        "The following layers have weight column counts not divisible by "
        "group_size. Group and tensor-group quantization require "
        "columns % group_size == 0; compressed-tensors will error when saving "
        "or running forward. Add these layer names to the modifier's `ignore` "
        "list and re-run, or set bypass_divisibility_checks=True if your "
        "runtime (e.g. vLLM) supports non-divisible dimensions.\n\n" + "\n".join(lines)
    )
```
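The divisibility policy above reduces to one arithmetic condition per layer. A standalone sketch, pure Python with no compressed-tensors dependency; the `Layer` record and example FQNs are illustrative stand-ins for real `torch.nn.Module` objects carrying a quantization scheme:

```python
from typing import NamedTuple, Optional, Tuple

# Strategies that require strict divisibility, mirroring _GROUP_STRATEGY_STRINGS.
GROUP_STRATEGIES = ("group", "tensor_group")


class Layer(NamedTuple):
    fqn: str                    # fully qualified module name
    columns: int                # weight.shape[-1]
    strategy: str               # weight quantization strategy
    group_size: Optional[int]


def layer_indivisible(layer: Layer) -> Optional[Tuple[int, int]]:
    """Return (columns, group_size) if the layer would fail the kernel check."""
    if layer.strategy not in GROUP_STRATEGIES or layer.group_size is None:
        return None
    # Same condition that triggers the ValueError in compressed-tensors forward.py.
    if layer.columns >= layer.group_size and layer.columns % layer.group_size != 0:
        return (layer.columns, layer.group_size)
    return None


layers = [
    Layer("model.layers.0.mlp.down_proj", 11008, "group", 128),  # 11008 % 128 == 0
    Layer("model.layers.0.mlp.gate_proj", 200, "group", 128),    # 200 % 128 != 0
    Layer("model.embed_tokens", 4096, "channel", None),          # not checked
]
bad = [(l.fqn, *layer_indivisible(l)) for l in layers if layer_indivisible(l)]
print(bad)  # [('model.layers.0.mlp.gate_proj', 200, 128)]
```

Note the `columns >= group_size` guard: a layer narrower than one group is left alone rather than flagged, matching the upstream kernel condition.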
tests/llmcompressor/modifiers/quantization/test_group_size_validation.py (152 additions, 0 deletions)
```python
"""Tests for early group-size divisibility validation."""

import types

import pytest
import torch

from llmcompressor.core import State
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.quantization.group_size_validation import (
    _layer_indivisible,
    get_layers_indivisible_by_group_size,
)


def _make_tiny_model(columns: int, divisible_columns: int | None = None):
    """Model with one Linear with columns, optionally another with divisible_columns."""
    linears = {"indiv": torch.nn.Linear(64, columns)}
    if divisible_columns is not None:
        linears["div"] = torch.nn.Linear(64, divisible_columns)
    return torch.nn.ModuleDict(linears)


class _FlatModel(torch.nn.Module):
    """Single top-level Linear so match_named_modules and scheme attach reliably."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)


def test_get_layers_indivisible_by_group_size_empty():
    """When all layers are divisible, helper returns empty list."""
    from compressed_tensors.quantization import (
        QuantizationConfig,
        QuantizationScheme,
        QuantizationStatus,
        apply_quantization_config,
    )
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    model = _make_tiny_model(128)  # 128 % 128 == 0
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(strategy="group", group_size=128),
    )
    config = QuantizationConfig(
        config_groups={"g": scheme},
        kv_cache_scheme=None,
        quantization_status=QuantizationStatus.INITIALIZED,
        ignore=[],
    )
    apply_quantization_config(model, config)
    out = get_layers_indivisible_by_group_size(model, {"Linear"}, [])
    assert out == []


def test_get_layers_indivisible_by_group_size_finds_layer():
    """_layer_indivisible and get_layers_indivisible_by_group_size find indivisible."""
    from compressed_tensors.quantization import QuantizationScheme
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    # 1) Unit test: _layer_indivisible with a simple args object (no CT
    # QuantizationArgs attribute quirks; tests our logic in isolation).
    # Linear(in_features, out_features) has weight.shape = (out_features, in_features);
    # we use shape[-1] (columns) for group divisibility, so use in_features=200.
    linear = torch.nn.Linear(200, 64)  # weight.shape=(64, 200) -> columns=200, 200 % 128 != 0
    weight_args_mock = types.SimpleNamespace(strategy="group", group_size=128)
    result = _layer_indivisible(linear, weight_args_mock)
    assert result is not None
    cols, gs = result
    assert cols == 200
    assert gs == 128

    # 2) Integration: full helper (requires match_named_modules to yield the layer).
    # Same column count: linear with in_features=200 so weight.shape[-1]=200.
    weight_args = QuantizationArgs(strategy="group", group_size=128)
    model = _FlatModel(200, 64)
    scheme = QuantizationScheme(targets=["Linear"], weights=weight_args)
    model.linear.quantization_scheme = scheme
    out = get_layers_indivisible_by_group_size(model, {"Linear"}, [])
    if len(out) == 0:
        # CT may not yield for simple models; unit test above covers logic
        pytest.skip(
            "match_named_modules yielded no modules; use full model for integration"
        )
    fqn, cols, gs = out[0]
    assert "linear" in fqn
    assert cols == 200
    assert gs == 128


def test_initialize_quantization_raises_early_for_indivisible():
    """Modifier raises at on_initialize with clear message and layer names."""
    model = _FlatModel(200, 64)  # weight.shape[-1]=200, 200 % 128 != 0
    state = State()
    state.update(model=model, device="cpu")
    modifier = QuantizationModifier(scheme="W4A16", targets=["Linear"])

    with torch.no_grad():
        try:
            modifier.on_initialize(state)
            pytest.skip(
                "no indivisible layers targeted (CT may not attach to simple models)"
            )
        except ValueError as exc:
            msg = str(exc)
            assert "columns" in msg.lower() and "group_size" in msg.lower()
            assert "ignore" in msg.lower()
            assert "bypass_divisibility_checks" in msg
            assert "200" in msg and "128" in msg


def test_initialize_quantization_succeeds_when_indivisible_ignored():
    """When indivisible layer is in ignore list, on_initialize does not raise."""
    model = _FlatModel(200, 64)  # columns=200 indivisible by 128, but we ignore the layer
    state = State()
    state.update(model=model, device="cpu")
    modifier = QuantizationModifier(scheme="W4A16", targets=["Linear"], ignore=["linear"])

    with torch.no_grad():
        modifier.on_initialize(state)


def test_initialize_quantization_succeeds_when_bypass_divisibility_checks():
    """bypass_divisibility_checks=True: on_initialize does not raise for indivisible."""
    model = _FlatModel(200, 64)  # columns=200 indivisible by 128
    state = State()
    state.update(model=model, device="cpu")
    modifier = QuantizationModifier(
        scheme="W4A16", targets=["Linear"], bypass_divisibility_checks=True
    )

    with torch.no_grad():
        modifier.on_initialize(state)


def test_initialize_quantization_succeeds_when_all_divisible():
    """When all targeted layers have columns % group_size == 0, no error."""
    model = _make_tiny_model(256)  # 256 % 128 == 0
    state = State()
    state.update(model=model, device="cpu")
    modifier = QuantizationModifier(scheme="W4A16", targets=["Linear"])

    with torch.no_grad():
        modifier.on_initialize(state)
```
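When the check fires, an alternative to growing the `ignore` list is picking a group_size that divides every targeted column count. A hypothetical helper, not part of this PR, sketches the idea with `math.gcd`; the power-of-two restriction reflects the group sizes typical kernels accept:

```python
import math
from functools import reduce


def largest_common_group_size(column_counts, max_group_size=128):
    """Largest power-of-two group_size <= max_group_size dividing every column count.

    Returns 1 when no larger power of two works (i.e. some column count is odd),
    in which case group quantization is effectively unusable for these layers.
    """
    g = reduce(math.gcd, column_counts)  # any valid group_size must divide the gcd
    size = 1
    while size * 2 <= min(g, max_group_size) and g % (size * 2) == 0:
        size *= 2
    return size


# gcd(11008, 200) == 8, so group_size=8 is the best power-of-two choice here.
print(largest_common_group_size([11008, 200]))
```

Whether a smaller group_size is acceptable is a quality/throughput trade-off; the sketch only shows the arithmetic, not a recommendation.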