Skip to content

Commit 4dec2c3

Browse files
[AWQ] fixes to matching logic and #1742 bugfix (#1759)

SUMMARY:
- [x] I introduced a bug in #1742 that caused the `lm_eval` AWQ test to fail. This reverts that change, re-setting the original state dict of the parent module in the grid search for best scales.
- [x] This also updates to the new module-matching API, excluding from the resolved mappings any modules that match the list in `ignore`. This should resolve a user issue with command-a-vision, which has k_proj etc. layers in the vision_encoder that we want to exclude from our resolved mappings.

TEST PLAN: The AWQ `lm_eval` test is passing now. A command-a-vision check is also running.

---------
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent f4b78b7 commit 4dec2c3

File tree

1 file changed

+20
-22
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+20
-22
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@

 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.quantization import (
-    disable_quantization,
-    find_name_or_class_matches,
-)
+from compressed_tensors.quantization import disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    match_named_modules,
     update_offload_parameter,
 )
 from loguru import logger
@@ -29,7 +27,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
+from llmcompressor.utils.pytorch.module import get_layer_by_name

 __all__ = ["AWQModifier"]

@@ -306,35 +304,27 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(
-                mapping.smooth_layer, model, exclude_internal_modules=True
-            )
-            smooth_names = [
-                smooth_name
-                for smooth_name in smooth_layers
-                if not find_name_or_class_matches(smooth_name, model, self.ignore)
-            ]
-
             num_skipped_mappings = 0
-            pbar = tqdm(smooth_names)
-            for smooth_name in pbar:
+
+            for smooth_name, smooth_layer in (
+                pbar := tqdm(
+                    match_named_modules(model, [mapping.smooth_layer], self.ignore)
+                )
+            ):
                 pbar.set_description(
                     f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
                     f" ({num_skipped_mappings} skipped)"
                 )
-                smooth_layer = smooth_layers[smooth_name]

                 smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
                 smooth_parent = get_layer_by_name(smooth_parent_name, model)

                 balance_layers, balance_names = [], []
                 for balance_regex in mapping.balance_layers:
                     # find the submodules that match the activation layer
-                    for balance_suffix, balance_layer in get_layers(
-                        balance_regex,
-                        smooth_parent,
-                        exclude_internal_modules=True,
-                    ).items():
+                    for balance_suffix, balance_layer in match_named_modules(
+                        smooth_parent, [balance_regex], self.ignore
+                    ):
                         balance_name = f"{smooth_parent_name}.{balance_suffix}"

                         # exclude v_proj->o_proj mappings whose shapes are incompatible
@@ -579,6 +569,12 @@ def _compute_best_scale(
         best_scales = None
         best_error = float("inf")

+        org_sd = {
+            k: v.cpu()
+            for k, v in parent_module.state_dict().items()
+            if v.device != torch.device("meta")
+        }
+
         device = get_execution_device(parent_module)
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
@@ -628,6 +624,8 @@ def _compute_best_scale(
             best_ratio = ratio
             best_scales = scales.clone()

+        parent_module.load_state_dict(org_sd, strict=False)
+
         if best_ratio == -1:
             logger.debug(history)
             raise Exception(

0 commit comments

Comments
 (0)