
Commit 5c4d2d6

Fix SmoothQuant to smooth all experts in MoE models
Replace get_matching_layer() with match_named_modules() to iterate over ALL matched layers instead of returning only the first match. This fixes a critical bug where only expert.0 was smoothed in MoE models, leaving all other experts unsmoothed and causing severe accuracy degradation.

Changes:
- Use match_named_modules from compressed_tensors.utils to iterate over all matching modules
- Search for balance layers within the parent module scope for better locality
- Follow the same pattern already proven to work in AWQModifier

This fix ensures all experts in MoE models (Mixtral, Qwen3, Phi, DeepSeek) are properly smoothed during quantization.

Signed-off-by: Rahul-Tuli <[email protected]>
1 parent 190337a commit 5c4d2d6
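
To make the failure mode concrete, here is a minimal, self-contained sketch (not from the commit) contrasting first-match resolution, which the commit message attributes to get_matching_layer(), with all-matches resolution as done by match_named_modules(). The module names and regex below are illustrative.

```python
import re

# Fully-qualified module names as they might appear in an MoE decoder layer.
names = [
    "model.layers.0.mlp.experts.0.gate_proj",
    "model.layers.0.mlp.experts.1.gate_proj",
    "model.layers.0.mlp.experts.2.gate_proj",
]
pattern = r".*experts\.\d+\.gate_proj"

# First-match resolution: stops at the first hit, as described for the old code.
first = next((n for n in names if re.match(pattern, n)), None)

# All-matches resolution: keeps every hit, as match_named_modules does.
hits = [n for n in names if re.match(pattern, n)]

print(first)      # model.layers.0.mlp.experts.0.gate_proj -> only expert 0
print(len(hits))  # 3 -> every expert gets a balance layer
```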

File tree

1 file changed: +30 −27 lines
  • src/llmcompressor/modifiers/smoothquant/base.py


src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 30 additions & 27 deletions
```diff
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.utils import align_module_device
+from compressed_tensors.utils import align_module_device, match_named_modules
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module
@@ -14,11 +14,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    match_targets,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name
 
 MINIMUM_SMOOTHING_SCALE = 1e-5
 
@@ -196,31 +192,38 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         Transforms the list of activations to smooth and their corresponding weights
         into SmoothQuantMapping objects, resolving regular expressions.
 
-        For each activation in the mapping list, we find the corresponding weight to
-        balance by searching for the longest substring. For instance, if our balance
-        weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
-        would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
-        repeat for model.layer.1 and so on
+        For each activation in the mapping list, we find ALL corresponding weights to
+        balance by matching within the parent scope. This ensures all matching layers
+        are included, which is critical for MoE models where multiple experts need to
+        be balanced.
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
-                    balance_layers = []
-                    for balance_suffix in to_balance:
-                        # find the submodule that matches the activation layer
-                        _, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if balance_layer:
-                            balance_layers.append(balance_layer)
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    mapping = SmoothQuantMapping(
-                        layer_name, smooth_layer, balance_layers
+            to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth
+
+            for smooth_name, smooth_layer in match_named_modules(
+                model, to_smooth_list, self.ignore
+            ):
+                # Search for balance layers within the parent scope
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = (
+                    get_layer_by_name(smooth_parent_name, model)
+                    if smooth_parent_name
+                    else model
+                )
+
+                balance_layers = []
+                for balance_regex in to_balance:
+                    for _, balance_layer in match_named_modules(
+                        smooth_parent, [balance_regex], self.ignore
+                    ):
+                        balance_layers.append(balance_layer)
+
+                if balance_layers:
+                    resolved_mappings.append(
+                        SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
                     )
-                    resolved_mappings.append(mapping)
+
         return resolved_mappings
 
     def _setup_scale_hooks(self):
```
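For reference, here is a small sketch of the parent-scope trick the new code uses, written against plain torch.nn rather than the compressed_tensors/llmcompressor helpers; ToyDecoderLayer and the regex are illustrative stand-ins.

```python
import re
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self, num_experts: int = 4):
        super().__init__()
        self.post_attention_layernorm = nn.LayerNorm(16)
        self.experts = nn.ModuleList(nn.Linear(16, 16) for _ in range(num_experts))

model = nn.ModuleDict(
    {"layers": nn.ModuleList(ToyDecoderLayer() for _ in range(2))}
)

smooth_name = "layers.0.post_attention_layernorm"

# Same derivation as the diff: the parent is everything before the last dot.
parent_name = ".".join(smooth_name.split(".")[:-1])  # -> "layers.0"
parent = model.get_submodule(parent_name)

# Searching only inside the parent keeps the mapping local to one decoder
# layer, so experts from layers.1 never leak into layers.0's mapping.
balance = [n for n, _ in parent.named_modules() if re.fullmatch(r"experts\.\d+", n)]
print(balance)  # ['experts.0', 'experts.1', 'experts.2', 'experts.3']
```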
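Finally, a hedged usage sketch of how the fix lands in a recipe, assuming the usual SmoothQuantModifier constructor arguments (smoothing_strength, mappings); the regexes are illustrative and depend on the target model's module names. With this commit, the balance patterns resolve to every expert inside each decoder layer rather than just experts.0.

```python
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

recipe = SmoothQuantModifier(
    smoothing_strength=0.8,
    mappings=[
        # Balance every expert's input projections against the shared norm.
        (
            [r"re:.*experts\.\d+\.gate_proj", r"re:.*experts\.\d+\.up_proj"],
            r"re:.*post_attention_layernorm",
        ),
    ],
)
```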