diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index c2b4a4ce36..dcefa2fa43 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.utils import align_module_device
+from compressed_tensors.utils import align_module_device, match_named_modules
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module
@@ -14,11 +14,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    match_targets,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name
 
 MINIMUM_SMOOTHING_SCALE = 1e-5
 
@@ -196,31 +192,34 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         Transforms the list of activations to smooth and their corresponding weights
         into SmoothQuantMapping objects, resolving regular expressions.
 
-        For each activation in the mapping list, we find the corresponding weight to
-        balance by searching for the longest substring. For instance, if our balance
-        weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
-        would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
-        repeat for model.layer.1 and so on
+        For each activation in the mapping list, we find ALL corresponding weights to
+        balance by matching within the parent scope. This ensures all matching layers
+        are included, which is critical for MoE models where multiple experts need to
+        be balanced.
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
-                    balance_layers = []
-                    for balance_suffix in to_balance:
-                        # find the submodule that matches the activation layer
-                        _, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if balance_layer:
-                            balance_layers.append(balance_layer)
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    mapping = SmoothQuantMapping(
-                        layer_name, smooth_layer, balance_layers
+            to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth
+
+            for smooth_name, smooth_layer in match_named_modules(
+                model, to_smooth_list, self.ignore
+            ):
+                # Search for balance layers within the parent scope
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+
+                balance_layers = [
+                    balance_layer
+                    for _, balance_layer in match_named_modules(
+                        smooth_parent, to_balance, self.ignore
+                    )
+                ]
+
+                if balance_layers:
+                    resolved_mappings.append(
+                        SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
                     )
-                    resolved_mappings.append(mapping)
+
         return resolved_mappings
 
     def _setup_scale_hooks(self):
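The behavioural change above is easiest to see in isolation. The sketch below is illustrative only and not part of the patch: it stands in for match_named_modules with a plain regular expression over named_modules() to show that a parent-scoped search collects every expert projection rather than only the first match; the toy module and pattern are hypothetical.

import re

import torch

hidden = 8
parent = torch.nn.ModuleDict(
    {
        "input_layernorm": torch.nn.LayerNorm(hidden),
        "experts": torch.nn.ModuleList(
            [
                torch.nn.ModuleDict({"w1": torch.nn.Linear(hidden, hidden)})
                for _ in range(4)
            ]
        ),
    }
)

# Stand-in for match_named_modules: collect every submodule of the parent scope
# whose qualified name matches the balance pattern.
pattern = re.compile(r".*experts.*w1")
balance_layers = [m for name, m in parent.named_modules() if pattern.match(name)]
assert len(balance_layers) == 4  # all expert w1 projections, not just experts.0.w1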
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 72144cc0f7..6d2152fe47 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -361,8 +361,11 @@ def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]:
 def get_layer_by_name(layer_name: str, module: Module) -> Module:
     """
     Get the layer of a module by name.
-    :param layer_name: Name of the layer to find.
+    :param layer_name: Name of the layer to find. Empty string returns the
+        module itself.
     :param module: Module in which to search for layer_name
     :return: Module, the layer with name layer_name
     """
+    if not layer_name:
+        return module
     return attrgetter(layer_name)(module)
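A small sketch of why the empty-string guard matters (illustrative, not part of the patch): when a smooth layer is a direct child of the model, the parent name computed in _resolve_mappings is the empty string, which attrgetter cannot resolve to the model itself, so the guard returns the module directly. The toy module below is hypothetical.

from operator import attrgetter

import torch


def get_layer_by_name(layer_name, module):
    # Mirrors the patched helper: an empty name means "the module itself".
    if not layer_name:
        return module
    return attrgetter(layer_name)(module)


model = torch.nn.ModuleDict({"norm": torch.nn.LayerNorm(4)})
parent_name = ".".join("norm".split(".")[:-1])  # "" for a top-level smooth layer
assert get_layer_by_name(parent_name, model) is model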
+ """ + num_layers = 2 + num_experts = 4 + hidden_size = 128 + + def create_moe_layer(): + experts = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "w1": torch.nn.Linear(hidden_size, hidden_size), + "w2": torch.nn.Linear(hidden_size, hidden_size), + } + ) + for _ in range(num_experts) + ] + ) + + return torch.nn.ModuleDict( + { + "input_layernorm": torch.nn.LayerNorm(hidden_size), + "mlp": torch.nn.ModuleDict( + { + "gate": torch.nn.Linear(hidden_size, num_experts), + "experts": experts, + } + ), + } + ) + + model = torch.nn.ModuleDict( + {"layers": torch.nn.ModuleList([create_moe_layer() for _ in range(num_layers)])} + ) + + sq = SmoothQuantModifier( + smoothing_strength=0.8, + mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")], + ignore=["re:.*gate"], + ) + + resolved_mappings = sq._resolve_mappings(model) + + assert len(resolved_mappings) == num_layers + + for i, mapping in enumerate(resolved_mappings): + assert len(mapping.balance_layers) == num_experts, ( + f"Layer {i}: Expected {num_experts} balance layers, " + f"got {len(mapping.balance_layers)}" + ) + + # Verify all balance layers are unique + balance_layer_ids = [id(layer) for layer in mapping.balance_layers] + assert len(balance_layer_ids) == len(set(balance_layer_ids)) diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index 49d0cdda3a..f3e948469e 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -8,15 +8,18 @@ @pytest.mark.unit def test_log_equalization_mapping(state): - mappings = [(["seq.fc2"], "seq.block1.fc1")] + # Use regex patterns with parent-scoped search + # Searches for balance layers within the parent of smooth layer + mappings = [(["re:^fc2$"], "re:.*block1\\.fc1$")] modifier = LogarithmicEqualizationModifier(mappings=mappings) modifier.ignore = [] modifier.resolved_mappings_ = modifier._resolve_mappings(state.model) - assert len(modifier.resolved_mappings_) == len(mappings) + assert len(modifier.resolved_mappings_) == 1 mapping = modifier.resolved_mappings_[0] - assert mapping.smooth_name == mappings[0][1] + assert mapping.smooth_name == "seq.block1.fc1" assert isinstance(mapping.smooth_layer, Linear) + assert len(mapping.balance_layers) == 1 assert isinstance(mapping.balance_layers[0], Linear) diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index ee8844041e..aefa6b9579 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -6,15 +6,21 @@ @pytest.mark.unit def test_smooth_quant_mapping(state): - mappings = [(["seq.fc1"], "seq.fc2")] + # Use regex patterns with parent-scoped search + # ^fc1$ matches only direct child "fc1", not nested "block1.fc1" + mappings = [(["re:^fc1$"], "re:.*fc2$")] modifier = SmoothQuantModifier(mappings=mappings) modifier.ignore = [] modifier.resolved_mappings_ = modifier._resolve_mappings(state.model) - assert len(modifier.resolved_mappings_) == len(mappings) + # Should match seq.fc2 and block1.fc2 (both end with fc2) + assert len(modifier.resolved_mappings_) == 2 - mapping = modifier.resolved_mappings_[0] - assert mapping.smooth_name == mappings[0][1] - assert isinstance(mapping.smooth_layer, Linear) - assert 
isinstance(mapping.balance_layers[0], Linear) + # Verify seq.fc2 mapping - should find only seq.fc1 (direct child) + seq_mapping = [ + m for m in modifier.resolved_mappings_ if m.smooth_name == "seq.fc2" + ][0] + assert isinstance(seq_mapping.smooth_layer, Linear) + assert len(seq_mapping.balance_layers) == 1 + assert isinstance(seq_mapping.balance_layers[0], Linear) diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py index 1ab40aa159..05cce846b4 100644 --- a/tests/llmcompressor/utils/pytorch/test_module.py +++ b/tests/llmcompressor/utils/pytorch/test_module.py @@ -29,6 +29,10 @@ def test_get_layer_by_name(example_nested_module): layer = get_layer_by_name("2.1", example_nested_module) assert layer == example_nested_module[2][1] + # Test that empty string returns the module itself + layer = get_layer_by_name("", example_nested_module) + assert layer == example_nested_module + # Test getting the parent of a non-existent layer with pytest.raises(AttributeError): get_layer_by_name("non_existent_layer", example_nested_module)
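One more illustrative sketch (not part of the patch) of the pattern convention the updated tests rely on: with parent-scoped matching, an unanchored balance pattern would also pick up same-named layers nested deeper in the scope, so the tests anchor the pattern with ^...$ to select only the direct child. Plain re over named_modules() is used here as a stand-in for the library's matching; the toy module is hypothetical.

import re

import torch

parent = torch.nn.ModuleDict(
    {
        "fc1": torch.nn.Linear(4, 4),
        "block1": torch.nn.ModuleDict({"fc1": torch.nn.Linear(4, 4)}),
    }
)

names = [name for name, _ in parent.named_modules() if name]

# Anchored: only the direct child "fc1" within the parent scope.
assert [n for n in names if re.match(r"^fc1$", n)] == ["fc1"]

# Unanchored: the nested "block1.fc1" would be swept in as well.
assert sorted(n for n in names if re.search(r"fc1", n)) == ["block1.fc1", "fc1"]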