|
2 | 2 | from typing import Dict, List, Optional, Tuple, Union
|
3 | 3 |
|
4 | 4 | import torch
|
5 |
| -from compressed_tensors.quantization import ( |
6 |
| - disable_quantization, |
7 |
| - find_name_or_class_matches, |
8 |
| -) |
| 5 | +from compressed_tensors.quantization import disable_quantization |
9 | 6 | from compressed_tensors.utils import (
|
10 | 7 | align_modules,
|
11 | 8 | get_execution_device,
|
| 9 | + match_named_modules, |
12 | 10 | update_offload_parameter,
|
13 | 11 | )
|
14 | 12 | from loguru import logger
|
|
29 | 27 | from llmcompressor.pipelines.cache import IntermediatesCache
|
30 | 28 | from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
|
31 | 29 | from llmcompressor.utils.helpers import calibration_forward_context
|
32 |
| -from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers |
| 30 | +from llmcompressor.utils.pytorch.module import get_layer_by_name |
33 | 31 |
|
34 | 32 | __all__ = ["AWQModifier"]
|
35 | 33 |
|
@@ -306,35 +304,27 @@ def _set_resolved_mappings(self, model: Module) -> None:
|
306 | 304 | """
|
307 | 305 | resolved_mappings: list[ResolvedMapping] = []
|
308 | 306 | for mapping_idx, mapping in enumerate(self.mappings):
|
309 |
| - smooth_layers = get_layers( |
310 |
| - mapping.smooth_layer, model, exclude_internal_modules=True |
311 |
| - ) |
312 |
| - smooth_names = [ |
313 |
| - smooth_name |
314 |
| - for smooth_name in smooth_layers |
315 |
| - if not find_name_or_class_matches(smooth_name, model, self.ignore) |
316 |
| - ] |
317 |
| - |
318 | 307 | num_skipped_mappings = 0
|
319 |
| - pbar = tqdm(smooth_names) |
320 |
| - for smooth_name in pbar: |
| 308 | + |
| 309 | + for smooth_name, smooth_layer in ( |
| 310 | + pbar := tqdm( |
| 311 | + match_named_modules(model, [mapping.smooth_layer], self.ignore) |
| 312 | + ) |
| 313 | + ): |
321 | 314 | pbar.set_description(
|
322 | 315 | f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
|
323 | 316 | f" ({num_skipped_mappings} skipped)"
|
324 | 317 | )
|
325 |
| - smooth_layer = smooth_layers[smooth_name] |
326 | 318 |
|
327 | 319 | smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
|
328 | 320 | smooth_parent = get_layer_by_name(smooth_parent_name, model)
|
329 | 321 |
|
330 | 322 | balance_layers, balance_names = [], []
|
331 | 323 | for balance_regex in mapping.balance_layers:
|
332 | 324 | # find the submodules that match the activation layer
|
333 |
| - for balance_suffix, balance_layer in get_layers( |
334 |
| - balance_regex, |
335 |
| - smooth_parent, |
336 |
| - exclude_internal_modules=True, |
337 |
| - ).items(): |
| 325 | + for balance_suffix, balance_layer in match_named_modules( |
| 326 | + smooth_parent, [balance_regex], self.ignore |
| 327 | + ): |
338 | 328 | balance_name = f"{smooth_parent_name}.{balance_suffix}"
|
339 | 329 |
|
340 | 330 | # exclude v_proj->o_proj mappings whose shapes are incompatible
|
@@ -579,6 +569,12 @@ def _compute_best_scale(
|
579 | 569 | best_scales = None
|
580 | 570 | best_error = float("inf")
|
581 | 571 |
|
| 572 | + org_sd = { |
| 573 | + k: v.cpu() |
| 574 | + for k, v in parent_module.state_dict().items() |
| 575 | + if v.device != torch.device("meta") |
| 576 | + } |
| 577 | + |
582 | 578 | device = get_execution_device(parent_module)
|
583 | 579 | x_mean = x_mean.view(-1).to(device)
|
584 | 580 | w_mean = w_mean.view(-1).to(device)
|
@@ -628,6 +624,8 @@ def _compute_best_scale(
|
628 | 624 | best_ratio = ratio
|
629 | 625 | best_scales = scales.clone()
|
630 | 626 |
|
| 627 | + parent_module.load_state_dict(org_sd, strict=False) |
| 628 | + |
631 | 629 | if best_ratio == -1:
|
632 | 630 | logger.debug(history)
|
633 | 631 | raise Exception(
|
|
0 commit comments