
Conversation

@rahul-tuli
Collaborator

Problem

SmoothQuant smoothed only the first expert (expert.0) in Mixture of Experts (MoE) models, leaving all other experts unsmoothed. This caused severe accuracy degradation for MoE models.

Root Cause

The _resolve_mappings() method in SmoothQuantModifier used get_matching_layer(), which returns only the first regex match instead of iterating over all matches. For MoE models with regex patterns like "re:.*experts.*w1", this meant only expert.0.w1 was smoothed while experts 1-N were ignored.

# Before (BUGGY)
for balance_suffix in to_balance:
    _, balance_layer = get_matching_layer(balance_suffix, layer_name, model)
    # ❌ Returns only expert.0, ignores experts 1-15

Solution

Replace get_matching_layer() with match_named_modules() from compressed_tensors.utils to iterate over ALL matched layers. This follows the same proven pattern used in AWQModifier.

# After (FIXED)
for balance_regex in to_balance:
    for _, balance_layer in match_named_modules(smooth_parent, [balance_regex], self.ignore):
        balance_layers.append(balance_layer)
    # ✅ Returns ALL experts (expert.0, expert.1, ..., expert.15)

Key Changes

  1. Updated imports: Use match_named_modules from compressed_tensors.utils
  2. Rewrote _resolve_mappings(): Iterate over all matched layers instead of just the first (see the sketch below)
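
Roughly, the rewritten logic looks like the sketch below. This is an illustrative sketch rather than the verbatim implementation: the mapping iteration, the parent-name handling, and the fact that SmoothQuantMapping and the get_layer_by_name helper are in scope are assumptions based on the snippets quoted in this PR.

# Sketch only: resolve each (to_balance, to_smooth) mapping against ALL matching modules
from compressed_tensors.utils import match_named_modules

def _resolve_mappings(self, model):
    resolved = []
    for to_balance, to_smooth in self.mappings:
        for layer_name, smooth_layer in match_named_modules(model, [to_smooth], self.ignore):
            # Search for balance layers within the smooth layer's parent for better locality
            smooth_parent_name = layer_name.rsplit(".", 1)[0] if "." in layer_name else ""
            smooth_parent = (
                get_layer_by_name(smooth_parent_name, model)  # helper discussed in the review below
                if smooth_parent_name
                else model
            )
            balance_layers = [
                layer
                for _, layer in match_named_modules(smooth_parent, to_balance, self.ignore)
            ]
            resolved.append(SmoothQuantMapping(layer_name, smooth_layer, balance_layers))
    return resolved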

Tests Added

Added unit tests covering the issue to verify MoE support; these tests fail on main and pass with the current diff:

1. test_moe_all_experts_smoothed

Verifies all 8 experts in a single MoE layer are included in balance_layers:

num_experts = 8
# ... create MoE model with 8 experts and configure the modifier (sq) ...
resolved_mappings = sq._resolve_mappings(model)
assert len(mapping.balance_layers) == num_experts  # All 8 experts
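
For reference, a toy fixture along these lines is enough to exercise the regex matching. The class and attribute names here are hypothetical, chosen only so patterns like "re:.*experts.*w1" and "re:.*post_attention_layernorm" resolve; this is not the PR's actual test fixture.

import torch.nn as nn

# Hypothetical MoE stand-in for illustration only
class ToyExpert(nn.Module):
    def __init__(self, hidden=16):
        super().__init__()
        self.w1 = nn.Linear(hidden, hidden)

class ToyMoEBlock(nn.Module):
    def __init__(self, num_experts=8, hidden=16):
        super().__init__()
        self.post_attention_layernorm = nn.LayerNorm(hidden)
        self.experts = nn.ModuleList([ToyExpert(hidden) for _ in range(num_experts)])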

2. test_moe_multiple_layers_all_experts_smoothed

Verifies correct behavior across multiple transformer layers:

num_layers = 2
num_experts = 4
# ... create model with 2 layers, 4 experts each ...
resolved_mappings = sq._resolve_mappings(model)
assert len(resolved_mappings) == num_layers
for mapping in resolved_mappings:
    assert len(mapping.balance_layers) == num_experts  # All 4 experts per layer
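
Putting it together, a hedged usage example: the mapping below follows the modifier's [balance_regexes, smooth_regex] convention, but treat the exact recipe, regexes, and the toy fixture as illustrative rather than a recommended configuration.

from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Illustrative only: balance every expert's w1 against the preceding norm
sq = SmoothQuantModifier(
    smoothing_strength=0.5,
    mappings=[[["re:.*experts.*w1"], "re:.*post_attention_layernorm"]],
)

model = ToyMoEBlock(num_experts=8)  # hypothetical fixture from the sketch above
resolved_mappings = sq._resolve_mappings(model)
assert len(resolved_mappings[0].balance_layers) == 8  # all experts, not just expert.0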

Test Results

All tests pass successfully:

$ python -m pytest tests/llmcompressor/modifiers/smoothquant/test_base.py -v

test_smooth_quant_is_registered                          ✅ PASSED
test_smooth_quant_defaults                               ✅ PASSED
test_override_defaults                                   ✅ PASSED
test_moe_all_experts_smoothed                            ✅ PASSED
test_moe_multiple_layers_all_experts_smoothed            ✅ PASSED

========================= 5 passed in 0.41s =========================

Before Fix (Tests Failed)

AssertionError: Expected 8 balance layers, got 1
# Only expert.0 was smoothed ❌

After Fix (Tests Pass)

All 8 experts smoothed ✅
All tests passing ✅

Related Issues

Fixes the SmoothQuant MoE bug reported in the community discussion about MoE quantization support.

@github-actions

github-actions bot commented Dec 2, 2025

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Add comprehensive unit tests to verify that SmoothQuant correctly
handles Mixture of Experts (MoE) models by smoothing all experts,
not just the first one.

Tests added:
- test_moe_all_experts_smoothed: Verifies all 8 experts in a single
  MoE layer are included in balance_layers
- test_moe_multiple_layers_all_experts_smoothed: Verifies correct
  behavior across multiple transformer layers with 4 experts each

These tests currently fail with the existing implementation, which
only matches the first expert due to get_matching_layer() returning
a single match instead of iterating over all matches.

Signed-off-by: Rahul-Tuli <[email protected]>

Replace get_matching_layer() with match_named_modules() to iterate
over ALL matched layers instead of returning only the first match.
This fixes a critical bug where only expert.0 was smoothed in MoE
models, leaving all other experts unsmoothed and causing severe
accuracy degradation.

Changes:
- Use match_named_modules from compressed_tensors.utils to iterate
  over all matching modules
- Search for balance layers within the parent module scope for
  better locality
- Follow the same pattern already proven to work in AWQModifier

This fix ensures all experts in MoE models (Mixtral, Qwen3, Phi,
DeepSeek) are properly smoothed during quantization.

Signed-off-by: Rahul-Tuli <[email protected]>
@rahul-tuli rahul-tuli marked this pull request as ready for review December 2, 2025 15:50
@rahul-tuli rahul-tuli requested review from HDCharles, dsikka, fynnsu, kylesayrs and shanjiaz and removed request for fynnsu December 2, 2025 15:50
@rahul-tuli rahul-tuli self-assigned this Dec 2, 2025
@rahul-tuli rahul-tuli added bug Something isn't working ready When a PR is ready for review labels Dec 2, 2025
Collaborator

@fynnsu fynnsu left a comment


Looks good!

Comment on lines +216 to +220

for balance_regex in to_balance:
    for _, balance_layer in match_named_modules(
        smooth_parent, [balance_regex], self.ignore
    ):
        balance_layers.append(balance_layer)

I think we can just do

for _, balance_layer in match_named_modules(
    smooth_parent, to_balance, self.ignore
):
    balance_layers.append(balance_layer)

since match_named_modules already contains a loop over targets (link).

or maybe even

balance_layers = [balance_layer for _, balance_layer in match_named_modules(smooth_parent, to_balance, self.ignore)]

Comment on lines +209 to +213

smooth_parent = (
    get_layer_by_name(smooth_parent_name, model)
    if smooth_parent_name
    else model
)

This is totally fine as is but it might be cleaner if we updated get_layer_by_name so that get_layer_by_name("", model) returns model.

This could be done by adding an if statement before this line to handle the special case.

That would allow us to remove the if smooth_parent_name else model logic.
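
A rough sketch of that special case (hypothetical; the helper's real signature and lookup logic may differ, attrgetter is only a stand-in for the existing dotted-name lookup):

from operator import attrgetter

def get_layer_by_name(layer_name: str, module):
    # Treat an empty name as referring to the module itself, so callers
    # can drop the `if smooth_parent_name else model` branch
    if not layer_name:
        return module
    return attrgetter(layer_name)(module)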

Collaborator

@HDCharles HDCharles left a comment


seems good to me

# one layer to smooth
mapping = SmoothQuantMapping(
    layer_name, smooth_layer, balance_layers
)
to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth

I'm working on a matching util that will make this mapping-matching easier, FYI.
