
Commit b78b052

dsikka and kylesayrs authored

[Actorder] Fix GPTQ actorder logic, only apply actorder to weight group args (#1815)

SUMMARY:
- Don't set actorder to "static" when running channel quantization.
- The modifier-level actorder value is also getting serialized incorrectly: we set it to None if the strategy is not GROUP, but are still missing a step to serialize it correctly when it is None.
- Fix a test case that was using an incorrect activation quantization strategy.
- Update the compress / decompress test case.

TESTING:
- Fixes failing tests.
- There is still a bug where `Sentinel` values are not serialized correctly.

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>

1 parent 0a1c9be commit b78b052
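For context on the remaining `Sentinel` bug noted above: the modifier keeps `Sentinel("static")` as the field default so that "unset" can be distinguished from an explicit None, but a raw sentinel object is not JSON-serializable. A minimal, self-contained sketch of that failure mode, assuming a pydantic-style model as GPTQModifier is (class and field names here are illustrative, not the library's exact implementation):

from typing import Any, Optional

from pydantic import BaseModel


class Sentinel:
    """Stand-in for the library's Sentinel marker ("unset" vs. explicit None)."""

    def __init__(self, name: str):
        self.name = name


class ModifierSketch(BaseModel):
    # default is a Sentinel object rather than a plain string or None,
    # mirroring GPTQModifier.actorder below
    actorder: Optional[Any] = Sentinel("static")


m = ModifierSketch()
m.model_dump()       # actorder comes back as the raw Sentinel object
m.model_dump_json()  # raises: Sentinel is not JSON serializable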

File tree

3 files changed: +63 -24 lines changed


src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 10 additions & 3 deletions

@@ -3,7 +3,11 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme
+from compressed_tensors.quantization import (
+    QuantizationConfig,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_args import ActivationOrdering
 from compressed_tensors.utils import (
     align_module_device,
@@ -107,6 +111,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
     sequential_targets: Union[str, List[str], None] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
+    # TODO: this does not serialize / will be incorrectly written
     actorder: Optional[Union[ActivationOrdering, Sentinel]] = Sentinel("static")
     offload_hessians: bool = False
 
@@ -149,9 +154,11 @@ def resolve_actorder(existing):
 
         for scheme in config.config_groups.values():
             assert isinstance(scheme, QuantizationScheme)
-            if scheme.weights is not None:
+            if (
+                getattr_chain(scheme, "weights.strategy", None)
+                == QuantizationStrategy.GROUP
+            ):
                 scheme.weights.actorder = resolve_actorder(scheme.weights.actorder)
-
         return config
 
     def on_initialize(self, state: State, **kwargs) -> bool:
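The new guard relies on `getattr_chain` from `compressed_tensors.utils`, which walks a dotted attribute path and falls back to a default when any link is missing. A rough sketch of the behavior being relied on (the real implementation may differ):

def getattr_chain(obj, chain, default):
    """Sketch: walk a dotted attribute path, returning `default` if any
    link along the way is missing or None."""
    for attr in chain.split("."):
        if obj is None or not hasattr(obj, attr):
            return default
        obj = getattr(obj, attr)
    return obj

# getattr_chain(scheme, "weights.strategy", None) therefore yields None for
# schemes without weight args, so the GROUP comparison is safe for them.

This replaces the old `scheme.weights is not None` check, which resolved actorder for every scheme that had weights, including channel-quantized ones where activation ordering does not apply.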

tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 35 additions & 0 deletions

@@ -1,6 +1,7 @@
 from contextlib import nullcontext
 
 import pytest
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 
@@ -107,3 +108,37 @@ def test_actorder_resolution(
     assert resolved.config_groups["group_0"].weights.actorder == expected_0
     assert resolved.config_groups["group_1"].input_activations.actorder is None
     assert resolved.config_groups["group_1"].weights.actorder == expected_1
+
+
+@pytest.mark.parametrize(
+    "strategies,actorder",
+    [
+        (["group"], None),
+        (["group"], "static"),
+        (["group"], "group"),
+        (["channel", "group"], None),
+        (["channel", "group"], "static"),
+        (["channel", "group"], "group"),
+        (["group", "channel"], None),
+        (["group", "channel"], "static"),
+        (["group", "channel"], "group"),
+    ],
+)
+def test_config_resolution(strategies, actorder):
+    config_groups = {
+        str(index): QuantizationScheme(
+            targets=[],
+            weights=QuantizationArgs(
+                strategy=strategy, group_size=(128 if strategy == "group" else None)
+            ),
+        )
+        for index, strategy in enumerate(strategies)
+    }
+
+    modifier = GPTQModifier(config_groups=config_groups, actorder=actorder)
+    modifier.resolve_quantization_config()
+
+    # validate that actorder was applied
+    for config_group in modifier.config_groups.values():
+        if config_group.weights.strategy == "group":
+            assert config_group.weights.actorder == actorder
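Distilled from the new test, the intended behavior after this fix looks like the sketch below. The final assertion on the channel scheme is an inference from the fix's intent rather than something the test asserts directly, and the string comparison assumes `ActivationOrdering` compares equal to its string value, as the test's own assertions do:

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from llmcompressor.modifiers.quantization import GPTQModifier

config_groups = {
    "0": QuantizationScheme(
        targets=[], weights=QuantizationArgs(strategy="group", group_size=128)
    ),
    "1": QuantizationScheme(targets=[], weights=QuantizationArgs(strategy="channel")),
}

modifier = GPTQModifier(config_groups=config_groups, actorder="static")
modifier.resolve_quantization_config()

# only the GROUP-strategy weights pick up the modifier-level actorder;
# the channel-quantized scheme is left untouched
assert modifier.config_groups["0"].weights.actorder == "static"
assert modifier.config_groups["1"].weights.actorder is None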

tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py

Lines changed: 18 additions & 21 deletions

@@ -348,7 +348,6 @@ def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
     concatenate_data = False
     num_calibration_samples = 64
     splits = {"calibration": "train[:10%]"}
-    empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto")
 
     oneshot(
         model=model_stub,
@@ -357,29 +356,18 @@ def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
         recipe=recipe,
         concatenate_data=concatenate_data,
         splits=splits,
-        clear_sparse_session=False,
     )
 
     # Fetch the oneshot model
     model = get_session_model()
     og_state_dict = model.state_dict()
     path = tmp_path / "compressed"
 
-    # Compress and save
-    model.save_pretrained(
-        path,
-        quantization_format=quant_format,
-        save_compressed=True,
-    )
-
-    # Verify config on disk
-    config = AutoConfig.from_pretrained(path)
-    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-    quant_config = ModelCompressor.parse_quantization_config(compression_config)
-
     # As HFQuantizer doesn't decompress the model, use the compressor to decompress
     # the model instead
-    compressor = ModelCompressor.from_compression_config(compression_config)
+    compressor = ModelCompressor.from_pretrained_model(
+        model, sparsity_config=sparse_format, quantization_format=quant_format
+    )
 
     assert (
         compressor.sparsity_compressor is not None
@@ -389,16 +377,15 @@ def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
     assert (
         compressor.quantization_compressor is not None
     ), "Quantization compressor not initialized"
-    assert quant_config["format"] == quant_format
 
+    compressor.compress_model(model)
+    compressor.decompress_model(model)
     compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
-    compressor.decompress(model_path=path, model=empty_model)
 
     # Verify the abs difference between the decompressed model
     # and the original model
-    reconstructed_state_dict = empty_model.state_dict()
-    assert len(og_state_dict) == len(reconstructed_state_dict)
-    for key in og_state_dict.keys():
+    reconstructed_state_dict = model.state_dict()
+    for key in reconstructed_state_dict.keys():
         dense_tensor = og_state_dict[key].to(device)
         reconstructed_tensor = reconstructed_state_dict[key].to(device)
         assert dense_tensor.dtype == reconstructed_tensor.dtype
@@ -409,6 +396,16 @@ def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
         assert not torch.any(diff > 0.025), f"Max diff: {torch.max(diff)}"
     else:
         assert torch.equal(dense_tensor, reconstructed_tensor)
+
+    # Recompress and save; validate correct formats used
+    model.save_pretrained(path)
+    config = AutoConfig.from_pretrained(path)
+    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+    quant_config = ModelCompressor.parse_quantization_config(compression_config)
+    sparsity_config = ModelCompressor.parse_sparsity_config(compression_config)
+    assert quant_config["format"] == quant_format
+    assert sparsity_config["format"] == sparse_format
+
     if os.path.isdir(tmp_path):
         shutil.rmtree(tmp_path)
 
@@ -588,7 +585,7 @@ def _quantization_config_from_string(config_str, q_type):
         quantize_activations=quantize_activations,
         a_bits=a_bits,
         a_type=q_type,
-        a_strategy="channel",
+        a_strategy="tensor",
     )
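The updated test now round-trips in memory rather than saving compressed weights to disk and reloading them into an empty model. The flow it exercises, condensed into a sketch (import paths assumed from the test's context; model and recipe setup elided):

from compressed_tensors import QUANTIZATION_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from transformers import AutoConfig


def roundtrip_and_check(model, sparse_format, quant_format, path):
    """Sketch of the in-memory round trip the updated test exercises."""
    compressor = ModelCompressor.from_pretrained_model(
        model, sparsity_config=sparse_format, quantization_format=quant_format
    )
    compressor.compress_model(model)    # swap dense params for compressed ones, in place
    compressor.decompress_model(model)  # reconstruct dense params, in place

    # formats are only validated on disk, after recompressing via save_pretrained
    model.save_pretrained(path)
    config = AutoConfig.from_pretrained(path)
    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
    quant_config = ModelCompressor.parse_quantization_config(compression_config)
    assert quant_config["format"] == quant_format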
