53 changes: 26 additions & 27 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,7 +2,7 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.utils import align_module_device
from compressed_tensors.utils import align_module_device, match_named_modules
from loguru import logger
from pydantic import ConfigDict, Field
from torch.nn import Module
@@ -14,11 +14,7 @@
handle_mapping_resolution_errors,
)
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.pytorch.module import (
get_layers,
get_matching_layer,
match_targets,
)
from llmcompressor.utils.pytorch.module import get_layer_by_name

MINIMUM_SMOOTHING_SCALE = 1e-5

@@ -196,31 +192,34 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.
For each activation in the mapping list, we find the corresponding weight to
balance by searching for the longest substring. For instance, if our balance
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
repeat for model.layer.1 and so on
For each activation in the mapping list, we find ALL corresponding weights to
balance by matching within the parent scope. This ensures all matching layers
are included, which is critical for MoE models where multiple experts need to
be balanced.
"""
resolved_mappings = []
for to_balance, to_smooth in self.mappings:
to_smooth_layers = get_layers(to_smooth, model)
for layer_name, smooth_layer in to_smooth_layers.items():
if not match_targets(layer_name, self.ignore)[0]:
balance_layers = []
for balance_suffix in to_balance:
# find the submodule that matches the activation layer
_, balance_layer = get_matching_layer(
balance_suffix, layer_name, model
)
if balance_layer:
balance_layers.append(balance_layer)
# each mapping can contain multiple layers to balance, but only
# one layer to smooth
mapping = SmoothQuantMapping(
layer_name, smooth_layer, balance_layers
to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth

for smooth_name, smooth_layer in match_named_modules(
model, to_smooth_list, self.ignore
):
# Search for balance layers within the parent scope
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)

balance_layers = [
balance_layer
for _, balance_layer in match_named_modules(
smooth_parent, to_balance, self.ignore
)
]

if balance_layers:
resolved_mappings.append(
SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
)
resolved_mappings.append(mapping)

return resolved_mappings

def _setup_scale_hooks(self):
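For context on the parent-scoped resolution described in the new docstring, here is a minimal sketch of how one mapping resolves against a toy MoE block. It assumes match_named_modules resolves "re:" name patterns and yields (name, module) pairs the way it is used in the diff above; the toy module names (layers, mlp, experts, w1) are illustrative only, not a real model definition.

import torch
from compressed_tensors.utils import match_named_modules

from llmcompressor.utils.pytorch.module import get_layer_by_name

hidden = 16
experts = torch.nn.ModuleList(
    [torch.nn.ModuleDict({"w1": torch.nn.Linear(hidden, hidden)}) for _ in range(2)]
)
layer = torch.nn.ModuleDict(
    {
        "input_layernorm": torch.nn.LayerNorm(hidden),
        "mlp": torch.nn.ModuleDict({"experts": experts}),
    }
)
model = torch.nn.ModuleDict({"layers": torch.nn.ModuleList([layer])})

# Resolve one mapping the way the new _resolve_mappings does: locate each
# smooth layer globally, then search for balance layers under its parent.
for smooth_name, _ in match_named_modules(model, ["re:.*input_layernorm"], []):
    parent_name = ".".join(smooth_name.split(".")[:-1])  # "layers.0"
    parent = get_layer_by_name(parent_name, model)
    balance_names = [
        name for name, _ in match_named_modules(parent, ["re:.*experts.*w1"], [])
    ]
    print(smooth_name, "->", balance_names)
    # e.g. layers.0.input_layernorm -> ['mlp.experts.0.w1', 'mlp.experts.1.w1']

Both experts resolve into the same mapping, which is the behavior the new MoE tests below check for.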
5 changes: 4 additions & 1 deletion src/llmcompressor/utils/pytorch/module.py
@@ -361,8 +361,11 @@ def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]:
def get_layer_by_name(layer_name: str, module: Module) -> Module:
"""
Get the layer of a module by name.
:param layer_name: Name of the layer to find.
:param layer_name: Name of the layer to find. Empty string returns the
module itself.
:param module: Module in which to search for layer_name
:return: Module, the layer with name layer_name
"""
if not layer_name:
return module
return attrgetter(layer_name)(module)
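A quick usage sketch of the empty-string behavior added above; the Sequential model here is only an illustrative assumption.

import torch

from llmcompressor.utils.pytorch.module import get_layer_by_name

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

# Empty name now returns the module itself (useful when a layer has no parent).
assert get_layer_by_name("", model) is model
# Dotted / attribute paths still resolve via attrgetter, as before.
assert get_layer_by_name("0", model) is model[0]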
129 changes: 129 additions & 0 deletions tests/llmcompressor/modifiers/smoothquant/test_base.py
@@ -1,4 +1,5 @@
import pytest
import torch

from llmcompressor.modifiers.factory import ModifierFactory
from llmcompressor.modifiers.smoothquant.base import SmoothQuantModifier
@@ -41,3 +42,131 @@ def test_override_defaults():

assert non_default_sq.smoothing_strength == strength
assert non_default_sq.mappings == dummy_map


@pytest.mark.unit
def test_moe_all_experts_smoothed():
"""
Test that SmoothQuant smooths ALL experts in MoE models, not just experts.0.
Verifies that all experts are included in balance_layers when resolving
mappings for MoE models with multiple experts.
"""
num_experts = 8
hidden_size = 256

experts = torch.nn.ModuleList(
[
torch.nn.ModuleDict(
{
"w1": torch.nn.Linear(hidden_size, hidden_size),
"w2": torch.nn.Linear(hidden_size, hidden_size),
}
)
for _ in range(num_experts)
]
)

model = torch.nn.ModuleDict(
{
"layers": torch.nn.ModuleList(
[
torch.nn.ModuleDict(
{
"input_layernorm": torch.nn.LayerNorm(hidden_size),
"mlp": torch.nn.ModuleDict(
{
"gate": torch.nn.Linear(hidden_size, num_experts),
"experts": experts,
}
),
}
)
]
)
}
)

sq = SmoothQuantModifier(
smoothing_strength=0.8,
mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")],
ignore=["re:.*gate"],
)

resolved_mappings = sq._resolve_mappings(model)

assert len(resolved_mappings) == 1
mapping = resolved_mappings[0]

assert "input_layernorm" in mapping.smooth_name
assert (
len(mapping.balance_layers) == num_experts
), f"Expected {num_experts} balance layers, got {len(mapping.balance_layers)}"

# Verify no duplicates
balance_layer_ids = [id(layer) for layer in mapping.balance_layers]
assert len(balance_layer_ids) == len(set(balance_layer_ids))

# Verify correct layers
expected_expert_w1s = {experts[i].w1 for i in range(num_experts)}
assert set(mapping.balance_layers) == expected_expert_w1s


@pytest.mark.unit
def test_moe_multiple_layers_all_experts_smoothed():
"""
Test SmoothQuant with multiple MoE layers to ensure all experts across
all layers are smoothed correctly.
"""
num_layers = 2
num_experts = 4
hidden_size = 128

def create_moe_layer():
experts = torch.nn.ModuleList(
[
torch.nn.ModuleDict(
{
"w1": torch.nn.Linear(hidden_size, hidden_size),
"w2": torch.nn.Linear(hidden_size, hidden_size),
}
)
for _ in range(num_experts)
]
)

return torch.nn.ModuleDict(
{
"input_layernorm": torch.nn.LayerNorm(hidden_size),
"mlp": torch.nn.ModuleDict(
{
"gate": torch.nn.Linear(hidden_size, num_experts),
"experts": experts,
}
),
}
)

model = torch.nn.ModuleDict(
{"layers": torch.nn.ModuleList([create_moe_layer() for _ in range(num_layers)])}
)

sq = SmoothQuantModifier(
smoothing_strength=0.8,
mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")],
ignore=["re:.*gate"],
)

resolved_mappings = sq._resolve_mappings(model)

assert len(resolved_mappings) == num_layers

for i, mapping in enumerate(resolved_mappings):
assert len(mapping.balance_layers) == num_experts, (
f"Layer {i}: Expected {num_experts} balance layers, "
f"got {len(mapping.balance_layers)}"
)

# Verify all balance layers are unique
balance_layer_ids = [id(layer) for layer in mapping.balance_layers]
assert len(balance_layer_ids) == len(set(balance_layer_ids))
@@ -8,15 +8,18 @@

@pytest.mark.unit
def test_log_equalization_mapping(state):
mappings = [(["seq.fc2"], "seq.block1.fc1")]
# Use regex patterns with parent-scoped search
# Searches for balance layers within the parent of the smooth layer
mappings = [(["re:^fc2$"], "re:.*block1\\.fc1$")]
modifier = LogarithmicEqualizationModifier(mappings=mappings)

modifier.ignore = []
modifier.resolved_mappings_ = modifier._resolve_mappings(state.model)

assert len(modifier.resolved_mappings_) == len(mappings)
assert len(modifier.resolved_mappings_) == 1

mapping = modifier.resolved_mappings_[0]
assert mapping.smooth_name == mappings[0][1]
assert mapping.smooth_name == "seq.block1.fc1"
assert isinstance(mapping.smooth_layer, Linear)
assert len(mapping.balance_layers) == 1
assert isinstance(mapping.balance_layers[0], Linear)
18 changes: 12 additions & 6 deletions tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py
@@ -6,15 +6,21 @@

@pytest.mark.unit
def test_smooth_quant_mapping(state):
mappings = [(["seq.fc1"], "seq.fc2")]
# Use regex patterns with parent-scoped search
# ^fc1$ matches only direct child "fc1", not nested "block1.fc1"
mappings = [(["re:^fc1$"], "re:.*fc2$")]
modifier = SmoothQuantModifier(mappings=mappings)

modifier.ignore = []
modifier.resolved_mappings_ = modifier._resolve_mappings(state.model)

assert len(modifier.resolved_mappings_) == len(mappings)
# Should match seq.fc2 and seq.block1.fc2 (both end with fc2)
assert len(modifier.resolved_mappings_) == 2

mapping = modifier.resolved_mappings_[0]
assert mapping.smooth_name == mappings[0][1]
assert isinstance(mapping.smooth_layer, Linear)
assert isinstance(mapping.balance_layers[0], Linear)
# Verify seq.fc2 mapping - should find only seq.fc1 (direct child)
seq_mapping = [
m for m in modifier.resolved_mappings_ if m.smooth_name == "seq.fc2"
][0]
assert isinstance(seq_mapping.smooth_layer, Linear)
assert len(seq_mapping.balance_layers) == 1
assert isinstance(seq_mapping.balance_layers[0], Linear)
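The comments in the test above hinge on how the "re:" patterns anchor against module names. A small standalone sketch, using Python's re directly and assuming the library's name matching behaves like re.match on the stripped pattern, shows why "^fc1$" only picks up the direct child while ".*fc2$" also matches the nested copy; the relative names listed are illustrative.

import re

relative_names = ["fc1", "fc2", "block1.fc1", "block1.fc2"]

# Parent-scoped search: "^fc1$" matches only the direct child.
print([n for n in relative_names if re.match(r"^fc1$", n)])   # ['fc1']

# Suffix pattern: ".*fc2$" matches the direct and the nested fc2.
print([n for n in relative_names if re.match(r".*fc2$", n)])  # ['fc2', 'block1.fc2']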
4 changes: 4 additions & 0 deletions tests/llmcompressor/utils/pytorch/test_module.py
@@ -29,6 +29,10 @@ def test_get_layer_by_name(example_nested_module):
layer = get_layer_by_name("2.1", example_nested_module)
assert layer == example_nested_module[2][1]

# Test that empty string returns the module itself
layer = get_layer_by_name("", example_nested_module)
assert layer == example_nested_module

# Test getting the parent of a non-existent layer
with pytest.raises(AttributeError):
get_layer_by_name("non_existent_layer", example_nested_module)