Commit 190337a

Add tests for MoE support in SmoothQuant
Add comprehensive unit tests to verify that SmoothQuant correctly handles
Mixture of Experts (MoE) models by smoothing all experts, not just the first
one.

Tests added:
- test_moe_all_experts_smoothed: verifies all 8 experts in a single MoE
  layer are included in balance_layers
- test_moe_multiple_layers_all_experts_smoothed: verifies correct behavior
  across multiple transformer layers with 4 experts each

These tests currently fail with the existing implementation, which only
matches the first expert because get_matching_layer() returns a single match
instead of iterating over all matches.

Signed-off-by: Rahul-Tuli <[email protected]>
1 parent db0b68d commit 190337a
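
For context on the failure mode described above, here is a minimal sketch of the difference between first-match-only resolution and all-matches resolution. This is illustrative only: the helper names and the regex-over-named_modules matching are assumptions for the sketch, not llmcompressor's actual get_matching_layer() implementation.

import re
import torch


def resolve_balance_layers_first_only(model: torch.nn.Module, pattern: str):
    # Buggy shape of the logic: return at the first module whose name
    # matches, so only experts.0.w1 is ever smoothed in an MoE layer.
    regex = re.compile(pattern)
    for name, module in model.named_modules():
        if regex.match(name):
            return [module]
    return []


def resolve_balance_layers_all(model: torch.nn.Module, pattern: str):
    # Fixed shape of the logic: collect every matching module, so
    # experts.0.w1 through experts.N.w1 all become balance layers.
    regex = re.compile(pattern)
    return [module for name, module in model.named_modules() if regex.match(name)]

The tests below are written against the second behavior: one balance layer per expert, with no duplicates.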

File tree

1 file changed: +129 −0 lines changed


tests/llmcompressor/modifiers/smoothquant/test_base.py

Lines changed: 129 additions & 0 deletions
@@ -1,4 +1,5 @@
 import pytest
+import torch
 
 from llmcompressor.modifiers.factory import ModifierFactory
 from llmcompressor.modifiers.smoothquant.base import SmoothQuantModifier
@@ -41,3 +42,131 @@ def test_override_defaults():
 
     assert non_default_sq.smoothing_strength == strength
     assert non_default_sq.mappings == dummy_map
+
+
+@pytest.mark.unit
+def test_moe_all_experts_smoothed():
+    """
+    Test that SmoothQuant smooths ALL experts in MoE models, not just expert.0.
+
+    Verifies that all experts are included in balance_layers when resolving
+    mappings for MoE models with multiple experts.
+    """
+    num_experts = 8
+    hidden_size = 256
+
+    experts = torch.nn.ModuleList(
+        [
+            torch.nn.ModuleDict(
+                {
+                    "w1": torch.nn.Linear(hidden_size, hidden_size),
+                    "w2": torch.nn.Linear(hidden_size, hidden_size),
+                }
+            )
+            for _ in range(num_experts)
+        ]
+    )
+
+    model = torch.nn.ModuleDict(
+        {
+            "layers": torch.nn.ModuleList(
+                [
+                    torch.nn.ModuleDict(
+                        {
+                            "input_layernorm": torch.nn.LayerNorm(hidden_size),
+                            "mlp": torch.nn.ModuleDict(
+                                {
+                                    "gate": torch.nn.Linear(hidden_size, num_experts),
+                                    "experts": experts,
+                                }
+                            ),
+                        }
+                    )
+                ]
+            )
+        }
+    )
+
+    sq = SmoothQuantModifier(
+        smoothing_strength=0.8,
+        mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")],
+        ignore=["re:.*gate"],
+    )
+
+    resolved_mappings = sq._resolve_mappings(model)
+
+    assert len(resolved_mappings) == 1
+    mapping = resolved_mappings[0]
+
+    assert "input_layernorm" in mapping.smooth_name
+    assert (
+        len(mapping.balance_layers) == num_experts
+    ), f"Expected {num_experts} balance layers, got {len(mapping.balance_layers)}"
+
+    # Verify no duplicates
+    balance_layer_ids = [id(layer) for layer in mapping.balance_layers]
+    assert len(balance_layer_ids) == len(set(balance_layer_ids))
+
+    # Verify correct layers
+    expected_expert_w1s = {experts[i].w1 for i in range(num_experts)}
+    assert set(mapping.balance_layers) == expected_expert_w1s
+
+
+@pytest.mark.unit
+def test_moe_multiple_layers_all_experts_smoothed():
+    """
+    Test SmoothQuant with multiple MoE layers to ensure all experts across
+    all layers are smoothed correctly.
+    """
+    num_layers = 2
+    num_experts = 4
+    hidden_size = 128
+
+    def create_moe_layer():
+        experts = torch.nn.ModuleList(
+            [
+                torch.nn.ModuleDict(
+                    {
+                        "w1": torch.nn.Linear(hidden_size, hidden_size),
+                        "w2": torch.nn.Linear(hidden_size, hidden_size),
+                    }
+                )
+                for _ in range(num_experts)
+            ]
+        )
+
+        return torch.nn.ModuleDict(
+            {
+                "input_layernorm": torch.nn.LayerNorm(hidden_size),
+                "mlp": torch.nn.ModuleDict(
+                    {
+                        "gate": torch.nn.Linear(hidden_size, num_experts),
+                        "experts": experts,
+                    }
+                ),
+            }
+        )
+
+    model = torch.nn.ModuleDict(
+        {"layers": torch.nn.ModuleList([create_moe_layer() for _ in range(num_layers)])}
+    )
+
+    sq = SmoothQuantModifier(
+        smoothing_strength=0.8,
+        mappings=[(["re:.*experts.*w1"], "re:.*input_layernorm")],
+        ignore=["re:.*gate"],
+    )
+
+    resolved_mappings = sq._resolve_mappings(model)
+
+    assert len(resolved_mappings) == num_layers
+
+    for i, mapping in enumerate(resolved_mappings):
+        assert len(mapping.balance_layers) == num_experts, (
+            f"Layer {i}: Expected {num_experts} balance layers, "
+            f"got {len(mapping.balance_layers)}"
+        )
+
+        # Verify all balance layers are unique
+        balance_layer_ids = [id(layer) for layer in mapping.balance_layers]
+        assert len(balance_layer_ids) == len(set(balance_layer_ids))
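
As a sanity check on the mapping regex itself, the snippet below shows that `.*experts.*w1` matches exactly one Linear per expert, which is why the tests expect len(mapping.balance_layers) == num_experts. It is self-contained and assumes only torch plus the toy module layout used by the tests, with the "re:" prefix of the mapping string stripped; how llmcompressor resolves these patterns internally is not shown here.

import re
import torch

num_experts = 8
mlp = torch.nn.ModuleDict(
    {
        "experts": torch.nn.ModuleList(
            [
                torch.nn.ModuleDict({"w1": torch.nn.Linear(8, 8)})
                for _ in range(num_experts)
            ]
        )
    }
)

# named_modules() yields "experts.0.w1" through "experts.7.w1";
# the mapping regex (minus the "re:" prefix) matches each of them.
pattern = re.compile(r".*experts.*w1")
matched = [name for name, _ in mlp.named_modules() if pattern.match(name)]
assert len(matched) == num_experts, matched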

0 commit comments