|
2 | 2 | from typing import Callable, Dict, List, Optional, Tuple, Union |
3 | 3 |
|
4 | 4 | import torch |
5 | | -from compressed_tensors.utils import align_module_device |
| 5 | +from compressed_tensors.utils import align_module_device, match_modules_set |
6 | 6 | from loguru import logger |
7 | 7 | from pydantic import ConfigDict, Field |
8 | 8 | from torch.nn import Module |
|
13 | 13 | get_layer_mappings_from_architecture, |
14 | 14 | handle_mapping_resolution_errors, |
15 | 15 | ) |
| 16 | +from llmcompressor.typing import NamedModules |
16 | 17 | from llmcompressor.utils.fsdp.helpers import get_fsdp_parent |
17 | 18 | from llmcompressor.utils.pytorch.module import ( |
18 | 19 | get_layers, |
@@ -54,6 +55,7 @@ class SmoothQuantMapping: |
54 | 55 |
|
55 | 56 | smooth_name: str |
56 | 57 | smooth_layer: Module |
| 58 | + balance_names: List[str] |
57 | 59 | balance_layers: List[Module] |
58 | 60 |
|
59 | 61 |
|
@@ -178,6 +180,13 @@ def on_finalize(self, state: State, **kwargs) -> bool: |
178 | 180 |
|
179 | 181 | return True |
180 | 182 |
|
def get_targets(self, model: torch.nn.Module) -> NamedModules:
    """
    Iterate over the modules in ``model`` matched by this modifier's mappings.

    Each entry of ``self.mappings`` is unpacked as
    ``(balance_targets, smooth_target)``; the balance targets followed by the
    smooth target are passed together to ``match_modules_set``, and whatever
    that call produces is yielded through unchanged.

    The initialization check runs at call time rather than on first
    iteration, so a premature call fails immediately at the call site.

    :param model: model whose submodules are matched against the mappings
    :return: lazy iterator over the results of ``match_modules_set`` for
        every mapping, in mapping order
        (NOTE(review): annotated as ``NamedModules``; confirm that
        ``match_modules_set`` yields items of that shape)
    :raises ValueError: if the modifier has not been initialized yet
    """
    # A function containing `yield` defers its whole body until the first
    # `next()`. Keeping the guard in this outer, non-generator function makes
    # the error surface when get_targets() is called, not when it is consumed.
    if not self.initialized_:
        raise ValueError("Cannot get targets before modifier has been initialized")

    def _iter_targets():
        # self.mappings is read lazily, on first iteration of the result
        for balance_targets, smooth_target in self.mappings:
            yield from match_modules_set(model, (*balance_targets, smooth_target))

    return _iter_targets()
| 189 | + |
181 | 190 | def _infer_mappings_from_model( |
182 | 191 | self, |
183 | 192 | model: Module, |
@@ -207,18 +216,20 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: |
207 | 216 | to_smooth_layers = get_layers(to_smooth, model) |
208 | 217 | for layer_name, smooth_layer in to_smooth_layers.items(): |
209 | 218 | if not match_targets(layer_name, self.ignore)[0]: |
| 219 | + balance_names = [] |
210 | 220 | balance_layers = [] |
211 | 221 | for balance_suffix in to_balance: |
212 | 222 | # find the submodule that matches the activation layer |
213 | | - _, balance_layer = get_matching_layer( |
| 223 | + balance_name, balance_layer = get_matching_layer( |
214 | 224 | balance_suffix, layer_name, model |
215 | 225 | ) |
216 | 226 | if balance_layer: |
| 227 | + balance_names.append(balance_name) |
217 | 228 | balance_layers.append(balance_layer) |
218 | 229 | # each mapping can contain multiple layers to balance, but only |
219 | 230 | # one layer to smooth |
220 | 231 | mapping = SmoothQuantMapping( |
221 | | - layer_name, smooth_layer, balance_layers |
| 232 | + layer_name, smooth_layer, balance_names, balance_layers |
222 | 233 | ) |
223 | 234 | resolved_mappings.append(mapping) |
224 | 235 | return resolved_mappings |
|
0 commit comments