@@ -47,18 +47,21 @@ def test_set_resolved_mappings():
4747 "o_proj" : Linear (4 , 4 ),
4848 }
4949 )
50- mlp = torch .nn .ModuleList (
51- "experts" : torch .nn .ModuleList (
52- [
53- torch .nn .ModuleDict (
54- {
55- "gate_proj" : Linear (4 , 2 ),
56- "down_proj" : Linear (4 , 2 ),
57- }
58- )
59- for _ in range (3 )
60- ]
61- )
50+ mlp = torch .nn .ModuleDict (
51+ {
52+ "experts" : torch .nn .ModuleList (
53+ [
54+ torch .nn .ModuleDict (
55+ {
56+ "gate_proj" : Linear (4 , 2 ),
57+ "up_proj" : Linear (4 , 2 ),
58+ "down_proj" : Linear (2 , 4 ),
59+ }
60+ )
61+ for _ in range (3 )
62+ ]
63+ )
64+ }
6265 )
6366 model = torch .nn .ModuleDict (
6467 {
@@ -89,9 +92,12 @@ def test_set_resolved_mappings():
8992 if "self_attn.v_proj" in mapping .smooth_name :
9093 assert set (mapping .balance_names ) == {"decoder.self_attn.o_proj" }
9194 assert mapping .parent_name == "decoder.self_attn.o_proj"
92- if "mlp.up_proj" in mapping .smooth_name :
93- assert set (mapping .balance_names ) == {"decoder.mlp.0.down_proj" , "decoder.mlp.0.down_proj" , "decoder.mlp.0.down_proj" }
94- assert mapping .parent_name == "decoder.mlp.down_proj" # TODODODO
95+ if "mlp.experts" in mapping .smooth_name and "up_proj" in mapping .smooth_name :
96+ expert_idx = mapping .smooth_name .split ("." )[- 2 ]
97+ expected_down_proj = f"decoder.mlp.experts.{ expert_idx } .down_proj"
98+ assert set (mapping .balance_names ) == {expected_down_proj }
99+ assert mapping .parent_name == expected_down_proj
100+ assert mapping .parent == mlp ["experts" ][int (expert_idx )]["down_proj" ]
95101
96102 awq = AWQModifier (
97103 mappings = [
@@ -223,7 +229,7 @@ def test_get_lowest_non_module_list_ancestor():
223229 assert ancestor_name == "" and ancestor == model
224230
225231 ancestor_name , ancestor = get_lowest_non_module_list_ancestor (
226- [ "experts" ] , model
232+ "experts" , model
227233 )
228234 assert ancestor_name == "" and ancestor == model
229235
@@ -232,3 +238,64 @@ def test_get_lowest_non_module_list_ancestor():
232238 )
233239 assert ancestor_name == "experts.1.gate_proj" and ancestor == model ["experts" ][1 ]["gate_proj" ]
234240
241+
@pytest.mark.unit
def test_moe_multiple_balance_layers():
    """Check mapping resolution when one smooth layer balances many MoE layers.

    A single AWQMapping pairs the ``input_layernorm`` pattern with the
    ``gate_proj`` and ``up_proj`` patterns. The model built here holds two
    experts under ``layer.mlp.experts``, each with its own ``gate_proj``,
    ``up_proj`` and ``down_proj``. After ``_set_resolved_mappings`` runs, the
    test expects exactly one resolved mapping whose balance names cover the
    gate/up projections of both experts and whose parent is ``layer.mlp``.
    """
    num_experts = 2
    balanced_projections = ("gate_proj", "up_proj")

    # One mapping: input_layernorm -> every expert's gate_proj and up_proj.
    layernorm_to_experts = AWQMapping(
        "re:.*input_layernorm",
        ["re:.*gate_proj", "re:.*up_proj"],
    )
    awq = AWQModifier(
        mappings=[layernorm_to_experts],
        scheme="W4A16_ASYM",
    )

    def build_expert():
        # Each expert is a dict of three square 4x4 projections.
        return torch.nn.ModuleDict(
            {
                "gate_proj": Linear(4, 4),
                "up_proj": Linear(4, 4),
                "down_proj": Linear(4, 4),
            }
        )

    # Simplified MoE structure: layer -> {input_layernorm, mlp -> experts[i]}.
    experts = torch.nn.ModuleList([build_expert() for _ in range(num_experts)])
    mlp = torch.nn.ModuleDict({"experts": experts})
    layer = torch.nn.ModuleDict(
        {
            "input_layernorm": torch.nn.LayerNorm(4),
            "mlp": mlp,
        }
    )
    model = torch.nn.ModuleDict({"layer": layer})

    awq._set_resolved_mappings(model)

    # The single input_layernorm should yield exactly one resolved mapping.
    resolved = awq._resolved_mappings
    assert len(resolved) == 1
    mapping = resolved[0]

    # Balance layers span gate_proj and up_proj of every expert.
    expected_balance_names = {
        f"layer.mlp.experts.{expert_idx}.{projection}"
        for expert_idx in range(num_experts)
        for projection in balanced_projections
    }
    assert set(mapping.balance_names) == expected_balance_names

    # The expected parent is the mlp module itself, not the ModuleList.
    assert mapping.parent_name == "layer.mlp"
    assert mapping.parent == mlp
301+
0 commit comments