|
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.utils import align_module_device
+from compressed_tensors.utils import align_module_device, match_named_modules
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module
@@ -14,11 +14,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    match_targets,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name
 
 MINIMUM_SMOOTHING_SCALE = 1e-5
 
@@ -196,31 +192,38 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         Transforms the list of activations to smooth and their corresponding weights
         into SmoothQuantMapping objects, resolving regular expressions.
 
-        For each activation in the mapping list, we find the corresponding weight to
-        balance by searching for the longest substring. For instance, if our balance
-        weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
-        would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
-        repeat for model.layer.1 and so on
+        For each activation in the mapping list, we find ALL corresponding weights to
+        balance by matching within the parent scope. This ensures all matching layers
+        are included, which is critical for MoE models where multiple experts need to
+        be balanced.
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
-                    balance_layers = []
-                    for balance_suffix in to_balance:
-                        # find the submodule that matches the activation layer
-                        _, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if balance_layer:
-                            balance_layers.append(balance_layer)
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    mapping = SmoothQuantMapping(
-                        layer_name, smooth_layer, balance_layers
+            to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth
+
+            for smooth_name, smooth_layer in match_named_modules(
+                model, to_smooth_list, self.ignore
+            ):
+                # Search for balance layers within the parent scope
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = (
+                    get_layer_by_name(smooth_parent_name, model)
+                    if smooth_parent_name
+                    else model
+                )
+
+                balance_layers = []
+                for balance_regex in to_balance:
+                    for _, balance_layer in match_named_modules(
+                        smooth_parent, [balance_regex], self.ignore
+                    ):
+                        balance_layers.append(balance_layer)
+
+                if balance_layers:
+                    resolved_mappings.append(
+                        SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
                     )
-                    resolved_mappings.append(mapping)
+
         return resolved_mappings
 
     def _setup_scale_hooks(self):
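To illustrate why parent-scoped matching matters, here is a minimal, self-contained sketch of the same idea using plain `torch.nn` and `re` instead of the library's `match_named_modules` / `get_layer_by_name` helpers. The toy module names (`ToyMoEBlock`, `post_attention_layernorm`, `experts`) and the `resolve_parent_scoped` helper are illustrative assumptions, not part of this change:

```python
import re
from types import SimpleNamespace

from torch.nn import LayerNorm, Linear, Module, ModuleList


class ToyMoEBlock(Module):
    """Minimal stand-in for a decoder layer that owns several experts."""

    def __init__(self, hidden: int = 8, num_experts: int = 4):
        super().__init__()
        self.post_attention_layernorm = LayerNorm(hidden)
        self.experts = ModuleList([Linear(hidden, hidden) for _ in range(num_experts)])


def resolve_parent_scoped(model: Module, smooth_regex: str, balance_regex: str):
    """Pair each smoothing layer with ALL balance layers under the same parent."""
    named = dict(model.named_modules())
    mappings = []
    for smooth_name, smooth_layer in named.items():
        if not re.search(smooth_regex, smooth_name):
            continue
        # walk up one level: "layer.post_attention_layernorm" -> "layer"
        parent_name = ".".join(smooth_name.split(".")[:-1])
        parent = named[parent_name] if parent_name else model
        # collect every matching submodule of the parent, not just the first hit
        balance_layers = [
            module
            for name, module in parent.named_modules()
            if re.search(balance_regex, name)
        ]
        if balance_layers:
            mappings.append(
                SimpleNamespace(
                    smooth_name=smooth_name,
                    smooth_layer=smooth_layer,
                    balance_layers=balance_layers,
                )
            )
    return mappings


root = Module()
root.layer = ToyMoEBlock()
resolved = resolve_parent_scoped(root, r"post_attention_layernorm$", r"experts\.\d+$")
print(resolved[0].smooth_name)          # layer.post_attention_layernorm
print(len(resolved[0].balance_layers))  # 4 -> every expert gets balanced
```

With a single-match resolution per balance pattern, only one `experts.N` projection would be paired with the norm layer; scoping the regex search to the smooth layer's parent collects all experts into one mapping, which is the behavior the new docstring describes.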
|