[Multi-modifier] Support scoped application of quantization config/status #1772
@@ -0,0 +1,101 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithms to run:
# * quantize self_attn layers to W8A8 with GPTQ
# * quantize mlp layers to W4A16 with AWQ
# Only include AWQ mappings that pertain to the targeted mlp layers.
recipe = [
    GPTQModifier(targets=r"re:.*self_attn\.(k|q|o|v)_proj$", scheme="W8A8"),
    AWQModifier(
        targets=r"re:.*mlp\.(down|gate|up)_proj$",
        mappings=[
            AWQMapping(
                "re:.*post_attention_layernorm$",
                ["re:.*gate_proj$", "re:.*up_proj$"],
            ),
            AWQMapping(
                "re:.*up_proj$",
                ["re:.*down_proj$"],
            ),
        ],
        scheme="W4A16",
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # Option 1) run both modifiers in a single calibrated run
    pipeline="sequential",
    # Option 2) run each modifier in its own separate pipeline
    # pipeline="independent",
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-gptq-w8a8-self_attn-awq-w4a16-mlp"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
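
Not part of the example file above, but a quick way to sanity-check the scoping used in the recipe: the sketch below reuses the match_named_modules helper that the mixin changes in this PR rely on. Passing an explicit empty ignore list mirrors the call pattern in the diff; everything else here is an illustrative assumption.

# Hypothetical sanity check (not included in the example): count the modules
# each modifier will touch, using the same matching helper the PR uses internally.
from compressed_tensors.utils import match_named_modules

attn_targets = [r"re:.*self_attn\.(k|q|o|v)_proj$"]
mlp_targets = [r"re:.*mlp\.(down|gate|up)_proj$"]
print(len(list(match_named_modules(model, attn_targets, []))), "attention projections for GPTQ")
print(len(list(match_named_modules(model, mlp_targets, []))), "mlp projections for AWQ")
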
@@ -14,7 +14,8 @@
    is_preset_scheme,
    preset_name_to_scheme,
)
from pydantic import Field, PrivateAttr, field_validator
from compressed_tensors.utils import match_named_modules
from pydantic import Field, PrivateAttr, field_validator, model_validator
from torch.utils.hooks import RemovableHandle

from llmcompressor.modifiers.quantization.calibration import (

@@ -58,8 +59,9 @@ class QuantizationMixin(HooksMixin):

    :param config_groups: dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param targets: list of layer names to quantize if a scheme is provided. Defaults
        to Linear layers
    :param targets: list of layer names to quantize if a scheme is provided. If unset,
        will contain all targets listed in config_groups. If config_groups is also
        unset, will default to ["Linear"] (i.e. all Linear layers will be targeted).
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param scheme: a single quantization scheme to apply to the model. This is a

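
As an aside on the documented targets resolution above, a hedged sketch (not code from this PR; GPTQModifier is used only as a convenient QuantizationMixin subclass, and the QuantizationScheme/QuantizationArgs field values are illustrative assumptions): either pass targets together with a preset scheme, or pass config_groups and let targets be aggregated from the group targets; the validator added further down rejects specifying both.

# Sketch only: the two ways to scope a quantization modifier after this change.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor.modifiers.quantization import GPTQModifier

# (a) explicit targets plus a preset scheme; `targets` is kept as given
by_targets = GPTQModifier(targets=["Linear"], scheme="W8A8")

# (b) config_groups only; `targets` is aggregated from the group targets
by_groups = GPTQModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=[r"re:.*mlp\.(down|gate|up)_proj$"],
            weights=QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=128),
        )
    }
)
assert by_groups.targets == [r"re:.*mlp\.(down|gate|up)_proj$"]
# Specifying both `targets` and `config_groups` now raises a ValueError.
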
@@ -81,7 +83,7 @@ class QuantizationMixin(HooksMixin):
    """

    config_groups: Optional[Dict[str, QuantizationScheme]] = None
    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
    targets: Union[str, List[str]] = Field(default_factory=list)
    ignore: List[str] = Field(default_factory=list)
    scheme: Optional[Union[str, Dict[str, Any]]] = None
    kv_cache_scheme: Optional[QuantizationArgs] = None

@@ -114,43 +116,71 @@ def validate_scheme(

        return value

    @model_validator(mode="after")
    def validate_model_after(model: "QuantizationMixin") -> "QuantizationMixin":
        """
        - If targets have not been set, aggregate targets from config_groups
          into a single unique list
        - If targets have still not been found, default to targets=["Linear"]
        """

        if len(model.targets) > 0 and model.config_groups is not None:
            raise ValueError("Please specify either `targets` or `config_groups`")

        if len(model.targets) == 0 and model.config_groups is not None:
            for config_group in model.config_groups.values():
                for target in config_group.targets:
                    if target not in model.targets:
                        model.targets.append(target)

        if len(model.targets) == 0:
            model.targets.append("Linear")

        return model

    def initialize_quantization(self, model: torch.nn.Module):
        """
        Attach quantization schemes and observers to modules in the model according to
        Attach quantization schemes to modules in the model according to
        the quantization config specified on this modifier

        :param model: model to attach schemes and observers to
        """
        reset_quantization_status(model)  # reset any previously applied qconfigs

        # apply scheme and status to model
        config = self.resolve_quantization_config()

        for _, module in match_named_modules(model, self.targets, self.ignore):
            reset_quantization_status(module)  # reset any previously applied qconfigs

        apply_quantization_config(model, config)

        # apply observers, disable quantization until calibration
        model.apply(self._initialize_observers)
        # disable quantization until calibration
        model.apply(disable_quantization)

    def start_calibration(self, model: torch.nn.Module):
        """
        Register activation calibration hooks (including kv_cache quantization) and
        enable quantization as we calibrate
        Attach observers, register activation calibration hooks (including
        kv_cache quantization) and enable quantization as we calibrate

        :param model: model to prepare for calibration
        """
        self._calibration_hooks = self._initialize_hooks(model)
        model.apply(apply_calibration_status)
        for _, module in match_named_modules(model, self.targets, self.ignore):
            self._initialize_observers(module)

Reviewer comment on this line: Why can't we keep this in initialize_quantization?

Author reply: Observers should be initialized on_start to align with them being removed on_end, so this was moved into on_start instead. Without this change, the lifecycle with multiple quantization modifiers would trigger observer hooks before the modifier starts (before it sees any data), which can now happen during a previous modifier's lifecycle (a rough sketch of this lifecycle follows after this hunk).

            apply_calibration_status(module)

        model.apply(enable_quantization)  # quantize at the same time as calibrate

    def end_calibration(self, model: torch.nn.Module):
        """
        Remove calibration hooks and set the model status to frozen. Keep quantization
        enabled for future operations
        Remove calibration hooks and observers, and set the model status to frozen.
        Keep quantization enabled for future operations

        :param model: model to end calibration for
        """
        self.remove_hooks(self._calibration_hooks)
        model.apply(freeze_module_quantization)  # remove observers
        for _, module in match_named_modules(model, self.targets, self.ignore):
            freeze_module_quantization(module)  # remove observers

        model.apply(enable_quantization)  # keep quantization enabled

    def has_config(self) -> bool:
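
To make the lifecycle discussed in the review thread above concrete, here is a rough sketch of the per-modifier flow these methods imply; the driver loop is hypothetical, only the mixin method names come from this diff.

# Hypothetical driver loop (not from this PR); only the mixin method names are real.
def run_scoped_modifier(modifier, model, calibration_batches):
    modifier.initialize_quantization(model)  # attach schemes to matched modules, quantization disabled
    modifier.start_calibration(model)        # attach observers and hooks to matched modules, enable quantization
    for batch in calibration_batches:
        model(**batch)                       # observers collect calibration statistics
    modifier.end_calibration(model)          # remove hooks and observers, freeze status, keep quantization enabled

# With two scoped modifiers (e.g. GPTQ on self_attn, AWQ on mlp), each call only
# touches the modules matched by that modifier's `targets`/`ignore`, so one
# modifier's calibration no longer fires observers that belong to a modifier
# that has not started yet.
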
@@ -240,7 +270,7 @@ def _initialize_observers(self, module: torch.nn.Module):

    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
        hooks = set()
        for module in model.modules():
        for _, module in match_named_modules(model, self.targets, self.ignore):
            if not hasattr(module, "quantization_scheme"):
                continue

@@ -95,12 +95,11 @@ def test_block_strategy_parsing(block_q_config_kwargs):
def test_actorder_resolution(
    has_actorder, actorder, q_config_kwargs, expected_0, expected_1
):
    if has_actorder:
        modifier = GPTQModifier(**q_config_kwargs, actorder=actorder)
    else:
        modifier = GPTQModifier(**q_config_kwargs)

    with pytest.raises(ValueError) if expected_0 == "error" else nullcontext():
        if has_actorder:
            modifier = GPTQModifier(**q_config_kwargs, actorder=actorder)
        else:
            modifier = GPTQModifier(**q_config_kwargs)
        resolved = modifier.resolve_quantization_config()

    if expected_0 != "error":

@@ -155,8 +154,8 @@ def test_config_resolution(strategies, actorder):
)
def test_serialize_actorder(has_actorder, actorder, exp_actorder):
    if has_actorder:
        modifier = GPTQModifier(targets=["Linear"], actorder=actorder)
        modifier = GPTQModifier(targets=["Linear"], scheme="W8A8", actorder=actorder)

Reviewer comment on this line: How was this targeting anything before you added the scheme?

Author reply: I think it just passed init/validation but was never used. It was never applied to a model, so it would never have worked. I added the scheme to make sure improper configuration wasn't the reason the test was failing (it was ultimately something else causing the failure).

    else:
        modifier = GPTQModifier(targets=["Linear"])
        modifier = GPTQModifier(targets=["Linear"], scheme="W8A8")

    assert modifier.model_dump()["actorder"] == exp_actorder