
Commit 27303c4

[Multi-modifier] Support scoped application of quantization config/status (#1772)
SUMMARY:

Prerequisites:
* neuralmagic/compressed-tensors#432

This allows for multi-modifier support by scoping the application of quantization config/status to only the modules in the model that match the given targets/ignore configuration, rather than all modules. Initialization of observers is moved to on_start (instead of on_initialize) to match their removal in on_end (rather than on_finalize). This prevents collisions during the multi-modifier lifecycle.

- [x] Update AWQ
- [x] Update QuantizationModifier
- [x] Update QuantizationMixin
- [x] Update GPTQ
- [x] No other quantization modifiers exist

TEST PLAN:
- Tests were added to neuralmagic/compressed-tensors#432 to confirm correct application of multiple modifiers.
- An example was added in this PR to show how AWQ and GPTQ can be applied heterogeneously to a model, along with a small README. Logs show alternating AWQ and GPTQ messages for the `"sequential"` pipeline, and correct behavior for the `"independent"` pipeline.
- The [model checkpoint](https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-selfattn-w8a8-mlp-w4a16-sequential/tree/main) for the sequential pipeline shows correct application of W8A8 to self_attn layers and W4A16 to mlp layers. config.json and the safetensors weights all look as expected.

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent f19deaa commit 27303c4
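In essence, each modifier now applies quantization config/status only to the `(name, module)` pairs returned by compressed-tensors' `match_named_modules` for its own `targets`/`ignore`, instead of touching every module via `model.modules()` or `model.apply(...)`. A minimal sketch of that pattern (the toy model and target regex below are illustrative, not taken from the diff):

```python
# Minimal sketch of the scoping pattern used throughout this commit.
# The toy model and target regex are illustrative; the real modifiers run
# loops like this inside their lifecycle hooks (on_initialize/on_start/on_end).
import torch
from compressed_tensors.utils import match_named_modules

model = torch.nn.ModuleDict(
    {
        "self_attn": torch.nn.ModuleDict({"q_proj": torch.nn.Linear(8, 8)}),
        "mlp": torch.nn.ModuleDict({"gate_proj": torch.nn.Linear(8, 8)}),
    }
)

targets = [r"re:.*self_attn\.q_proj$"]  # this modifier's targets only
ignore = []

# Previously, `for module in model.modules()` visited every module, so two
# modifiers applied to the same model could collide. Now each modifier only
# sees the modules matched by its own targets/ignore configuration.
for name, module in match_named_modules(model, targets, ignore):
    print(f"would calibrate {name}")  # e.g. update_weight_zp_scale(module)
```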

File tree

12 files changed: +189 -38 lines changed

examples/quantization_non_uniform/README.md

Lines changed: 9 additions & 0 deletions
@@ -9,3 +9,12 @@ We demonstrate mixed precision by quantizing models to both int8 and int4, and i
 ## Multiple Strategies
 
 It may also be interesting to quantize a model with two different [quantization strategies](https://github.com/neuralmagic/compressed-tensors/blob/a2bfc03e9d52824ba5d6d2a50c8741dd9bccd5d3/src/compressed_tensors/quantization/quant_args.py#L93) such as group, channel, or per-tensor. [Here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_non_uniform/quantization_fp8_multiple_strategies.py) we apply fp8 quantization where all the attention weights are quantized using the per-channel strategy, and all the mlp weights are quantized using per-tensor. This is accomplished through defining multiple config groups in the recipe. The produced model is compressed using the `float-quantized` compressor and can be directly run in vllm.
+
+## Quantization with Multiple Quantization Modifiers
+
+This section outlines how multiple quantization modifiers can be applied to the same model for mixed-precision quantization, for example applying GPTQ W8A8 to a model's `self_attn` layers and AWQ W4A16 to its `mlp` layers. This heterogeneous application of multiple modifiers comes in two flavors:
+
+1. Run every modifier in a single, sequential pipeline, performing one calibrated run. See `./quantization_multiple_modifiers.py` for an example.
+2. Run each modifier in its own, independent pipeline, performing a calibrated run for each modifier. To do so, run `./quantization_multiple_modifiers.py` with `oneshot(..., pipeline="independent")` instead of `pipeline="sequential"`.
+
+This is an advanced usage of `llm-compressor` and an active area of research. Best practices will be provided in a future release, after further research and sensitivity analysis.
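For quick reference, here is a condensed sketch of the example script described above (the full file added by this commit appears below). The `"ultrachat_200k"` dataset alias passed to `oneshot` is an assumed shortcut; the recipe and `pipeline` options are taken directly from the example:

```python
# Condensed sketch of quantization_multiple_modifiers.py (full script below).
# The "ultrachat_200k" dataset alias is an assumed shortcut; the full example
# loads and preprocesses HuggingFaceH4/ultrachat_200k explicitly.
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = [
    # GPTQ W8A8 on the self-attention projections
    GPTQModifier(targets=r"re:.*self_attn\.(k|q|o|v)_proj$", scheme="W8A8"),
    # AWQ W4A16 on the MLP projections, with mappings restricted to those layers
    AWQModifier(
        targets=r"re:.*mlp\.(down|gate|up)_proj$",
        mappings=[
            AWQMapping(
                "re:.*post_attention_layernorm$",
                ["re:.*gate_proj$", "re:.*up_proj$"],
            ),
            AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
        ],
        scheme="W4A16",
    ),
]

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    dataset="ultrachat_200k",  # assumed dataset alias; see the full script below
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    pipeline="sequential",  # flavor 1: one calibrated run for both modifiers
    # pipeline="independent",  # flavor 2: a calibrated run per modifier
)
```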
examples/quantization_non_uniform/quantization_multiple_modifiers.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * quantize self_attn layers to W8A8 with GPTQ
# * quantize mlp layers to W4A16 with AWQ
# only include mappings pertaining to target layers
recipe = [
    GPTQModifier(targets=r"re:.*self_attn\.(k|q|o|v)_proj$", scheme="W8A8"),
    AWQModifier(
        targets=r"re:.*mlp\.(down|gate|up)_proj$",
        mappings=[
            AWQMapping(
                "re:.*post_attention_layernorm$",
                ["re:.*gate_proj$", "re:.*up_proj$"],
            ),
            AWQMapping(
                "re:.*up_proj$",
                ["re:.*down_proj$"],
            ),
        ],
        scheme="W4A16",
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # Option 1) run both modifiers in a single calibrated run
    pipeline="sequential",
    # Option 2) run each modifier in its own separate pipeline
    # pipeline="independent",
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-gptq-w8a8-self_attn-awq-w4a16-mlp"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 2 deletions
@@ -265,8 +265,10 @@ def on_end(self, state: State, event: Event, **kwargs):
 
         self.ended_ = True
 
-        modules = list(state.model.modules())
-        for module in tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm(
+            match_named_modules(state.model, self.targets, self.ignore),
+            desc="Calibrating weights",
+        ):
             update_weight_zp_scale(module)
 
         QuantizationMixin.end_calibration(self, state.model)

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ class AWQMapping:
     "Phi3ForCausalLM": _phi_mappings,
     "Phi3VForCausalLM": _phi_mappings,
     "Qwen2ForCausalLM": _default_mappings,
+    "Qwen2_5OmniThinkerForConditionalGeneration": _default_mappings,
     "Qwen2MoeForCausalLM": _moe_default_mappings,
     "Qwen3ForCausalLM": _default_mappings,
     "Qwen3MoeForCausalLM": _moe_default_mappings,

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@
 
 from typing import Any, Dict, List, Optional, Tuple
 
-from compressed_tensors.quantization import KVCacheScaleType
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization import KVCacheScaleType, QuantizationArgs
 from torch import Tensor
 from transformers import DynamicCache
 

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 6 additions & 2 deletions
@@ -13,6 +13,7 @@
     align_module_device,
     get_execution_device,
     getattr_chain,
+    match_named_modules,
     update_offload_parameter,
 )
 from loguru import logger
@@ -161,7 +162,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         QuantizationMixin.initialize_quantization(self, state.model)
 
         # prepare module names
-        self._module_names = {m: name for name, m in state.model.named_modules()}
+        self._module_names = {
+            m: name
+            for name, m in match_named_modules(state.model, self.targets, self.ignore)
+        }
 
         return True
 
@@ -174,7 +178,7 @@ def on_start(self, state: State, event: Event, **kwargs):
 
         # register gptq hooks
         added_hook = False
-        for module in state.model.modules():
+        for _, module in match_named_modules(state.model, self.targets, self.ignore):
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                 # HACK: previously, embeddings were not quantized because they were not
                 # accessible by the layer compressor. For now, we manually ignore it,

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,5 @@
 import tqdm
+from compressed_tensors.utils import match_named_modules
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
@@ -69,14 +70,16 @@ def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
         QuantizationMixin.start_calibration(self, state.model)
 
-        modules = list(state.model.modules())
+        named_modules = list(
+            match_named_modules(state.model, self.targets, self.ignore)
+        )
         # TODO: this step can be combined with update_weight_zp_scale
         # once update_fused_layer_weight_global_scales is removed
         # and not required by vLLM
-        for module in tqdm.tqdm(modules):
+        for _, module in tqdm.tqdm(named_modules):
             update_weight_global_scale(module)
 
-        for module in tqdm.tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
             update_fused_layer_weight_global_scales(module)
             update_weight_zp_scale(module)
 

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 46 additions & 16 deletions
@@ -14,7 +14,8 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
-from pydantic import Field, PrivateAttr, field_validator
+from compressed_tensors.utils import match_named_modules
+from pydantic import Field, PrivateAttr, field_validator, model_validator
 from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.calibration import (
@@ -58,8 +59,9 @@ class QuantizationMixin(HooksMixin):
 
     :param config_groups: dictionary specifying quantization schemes to apply to target
         modules. Modules not matching a scheme target will NOT be quantized.
-    :param targets: list of layer names to quantize if a scheme is provided. Defaults
-        to Linear layers
+    :param targets: list of layer names to quantize if a scheme is provided. If unset,
+        will contain all targets listed in config_groups. If config_groups is also
+        unset, will default to ["Linear"] (i.e. all Linear layers will be targeted).
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target in config_groups. Defaults to empty list.
     :param scheme: a single quantization scheme to apply to the model. This is a
@@ -81,7 +83,7 @@
     """
 
     config_groups: Optional[Dict[str, QuantizationScheme]] = None
-    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
+    targets: Union[str, List[str]] = Field(default_factory=list)
     ignore: List[str] = Field(default_factory=list)
     scheme: Optional[Union[str, Dict[str, Any]]] = None
     kv_cache_scheme: Optional[QuantizationArgs] = None
@@ -114,43 +116,71 @@ def validate_scheme(
 
         return value
 
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationMixin") -> "QuantizationMixin":
+        """
+        - If targets have not been set, aggregate targets from config_groups
+          into a single unique list
+        - If targets have still not been found, default to targets=["Linear"]
+        """
+
+        if len(model.targets) > 0 and model.config_groups is not None:
+            raise ValueError("Please specify either `targets` or `config_groups`")
+
+        if len(model.targets) == 0 and model.config_groups is not None:
+            for config_group in model.config_groups.values():
+                for target in config_group.targets:
+                    if target not in model.targets:
+                        model.targets.append(target)
+
+        if len(model.targets) == 0:
+            model.targets.append("Linear")
+
+        return model
+
     def initialize_quantization(self, model: torch.nn.Module):
         """
-        Attach quantization schemes and observers to modules in the model according to
+        Attach quantization schemes to modules in the model according to
         the quantization config specified on this modifier
 
         :param model: model to attach schemes and observers to
         """
-        reset_quantization_status(model)  # reset any previously applied qconfigs
-
         # apply scheme and status to model
         config = self.resolve_quantization_config()
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            reset_quantization_status(module)  # reset any previously applied qconfigs
+
         apply_quantization_config(model, config)
 
-        # apply observers, disable quantization until calibration
-        model.apply(self._initialize_observers)
+        # disable quantization until calibration
         model.apply(disable_quantization)
 
     def start_calibration(self, model: torch.nn.Module):
         """
-        Register activation calibration hooks (including kv_cache quantization) and
-        enable quantization as we calibrate
+        Attach observers, register activation calibration hooks (including
+        kv_cache quantization) and enable quantization as we calibrate
 
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        model.apply(apply_calibration_status)
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            self._initialize_observers(module)
+            apply_calibration_status(module)
+
         model.apply(enable_quantization)  # quantize at the same time as calibrate
 
     def end_calibration(self, model: torch.nn.Module):
         """
-        Remove calibration hooks and set the model status to frozen. Keep quantization
-        enabled for future operations
+        Remove calibration hooks and observers, and set the model status to frozen.
+        Keep quantization enabled for future operations
 
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        model.apply(freeze_module_quantization)  # remove observers
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            freeze_module_quantization(module)  # remove observers
+
         model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
@@ -240,7 +270,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for module in model.modules():
+        for _, module in match_named_modules(model, self.targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
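For reference, the new `validate_model_after` behavior can be exercised directly. A minimal sketch follows (the scheme arguments are illustrative assumptions; `QuantizationModifier` is just one user of this mixin):

```python
# Sketch of the new targets/config_groups resolution (illustrative values).
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from llmcompressor.modifiers.quantization import QuantizationModifier

# 1) config_groups only: targets are aggregated from the groups
modifier = QuantizationModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=[r"re:.*self_attn\.(k|q|o|v)_proj$"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)
assert modifier.targets == [r"re:.*self_attn\.(k|q|o|v)_proj$"]

# 2) neither targets nor config_groups: defaults to all Linear layers
assert QuantizationModifier(scheme="W8A8").targets == ["Linear"]

# 3) both targets and config_groups: rejected
# (pydantic surfaces the ValueError as a validation error)
try:
    QuantizationModifier(targets=["Linear"], config_groups=modifier.config_groups)
except Exception as err:
    print(err)  # "Please specify either `targets` or `config_groups`"
```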

tests/llmcompressor/modifiers/calibration/test_frozen.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ def test_set_module_for_calibration():
     layer = Linear(4, 4)
 
     initialize_module_for_quantization(layer, quantization_scheme)
-    layer.quantization_status = QuantizationStatus("calibration")
+    layer.quantization_status = QuantizationStatus.CALIBRATION
     initialize_observer(layer, "weight")
 
     # should have both input and weight observer after initalizing
@@ -48,4 +48,4 @@
     assert not hasattr(layer, "input_observer")
     assert not hasattr(layer, "weight_observer")
 
-    assert layer.quantization_status == QuantizationStatus("frozen")
+    assert layer.quantization_status == QuantizationStatus.FROZEN

tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 6 additions & 7 deletions
@@ -95,12 +95,11 @@ def test_block_strategy_parsing(block_q_config_kwargs):
 def test_actorder_resolution(
     has_actorder, actorder, q_config_kwargs, expected_0, expected_1
 ):
-    if has_actorder:
-        modifier = GPTQModifier(**q_config_kwargs, actorder=actorder)
-    else:
-        modifier = GPTQModifier(**q_config_kwargs)
-
     with pytest.raises(ValueError) if expected_0 == "error" else nullcontext():
+        if has_actorder:
+            modifier = GPTQModifier(**q_config_kwargs, actorder=actorder)
+        else:
+            modifier = GPTQModifier(**q_config_kwargs)
         resolved = modifier.resolve_quantization_config()
 
     if expected_0 != "error":
@@ -155,8 +154,8 @@ def test_config_resolution(strategies, actorder):
 )
 def test_serialize_actorder(has_actorder, actorder, exp_actorder):
     if has_actorder:
-        modifier = GPTQModifier(targets=["Linear"], actorder=actorder)
+        modifier = GPTQModifier(targets=["Linear"], scheme="W8A8", actorder=actorder)
     else:
-        modifier = GPTQModifier(targets=["Linear"])
+        modifier = GPTQModifier(targets=["Linear"], scheme="W8A8")
 
     assert modifier.model_dump()["actorder"] == exp_actorder
