Skip to content

Commit 2b138a7

Browse files
committed
tests
Summary Signed-off-by: HDCharles <[email protected]>
1 parent 0f265f9 commit 2b138a7

File tree

2 files changed

+90
-22
lines changed

2 files changed

+90
-22
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
371371
else:
372372
# for multiple balance layers, find lowest common parent
373373
ancestor_name = get_lowest_common_ancestor_name(balance_names)
374-
ancestor, ancestor_name = get_lowest_non_module_list_ancestor(ancestor_name, )
374+
ancestor_name, ancestor = get_lowest_non_module_list_ancestor(ancestor_name, model)
375375

376376
resolved_mappings.append(
377377
ResolvedMapping(
@@ -807,7 +807,8 @@ def _accumulate_mean(
807807

808808
def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]:
809809
"""
810-
Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList
810+
Given a name and a model, finds lowest ancestor of
811+
named module that's not a ModuleList
811812
i.e. module_list.module_dict.module_list -> module_list.module_dict
812813
i.e. module_list.module_dict -> module_list.module_dict
813814
(self is an ancestor of self)
@@ -823,7 +824,7 @@ def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Modu
823824
while True:
824825
if name == "":
825826
return "", module
826-
module = get_layer_by_name(name, module)
827-
if not isinstance(module, torch.nn.ModuleList):
828-
return name, module
829-
name = ".".join(parent_name.split(".")[:-1])
827+
current_module = get_layer_by_name(name, module)
828+
if not isinstance(current_module, torch.nn.ModuleList):
829+
return name, current_module
830+
name = ".".join(name.split(".")[:-1])

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 83 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,21 @@ def test_set_resolved_mappings():
4747
"o_proj": Linear(4, 4),
4848
}
4949
)
50-
mlp = torch.nn.ModuleList(
51-
"experts": torch.nn.ModuleList(
52-
[
53-
torch.nn.ModuleDict(
54-
{
55-
"gate_proj": Linear(4, 2),
56-
"down_proj": Linear(4, 2),
57-
}
58-
)
59-
for _ in range(3)
60-
]
61-
)
50+
mlp = torch.nn.ModuleDict(
51+
{
52+
"experts": torch.nn.ModuleList(
53+
[
54+
torch.nn.ModuleDict(
55+
{
56+
"gate_proj": Linear(4, 2),
57+
"up_proj": Linear(4, 2),
58+
"down_proj": Linear(2, 4),
59+
}
60+
)
61+
for _ in range(3)
62+
]
63+
)
64+
}
6265
)
6366
model = torch.nn.ModuleDict(
6467
{
@@ -89,9 +92,12 @@ def test_set_resolved_mappings():
8992
if "self_attn.v_proj" in mapping.smooth_name:
9093
assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"}
9194
assert mapping.parent_name == "decoder.self_attn.o_proj"
92-
if "mlp.up_proj" in mapping.smooth_name:
93-
assert set(mapping.balance_names) == {"decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj"}
94-
assert mapping.parent_name == "decoder.mlp.down_proj" # TODODODO
95+
if "mlp.experts" in mapping.smooth_name and "up_proj" in mapping.smooth_name:
96+
expert_idx = mapping.smooth_name.split(".")[-2]
97+
expected_down_proj = f"decoder.mlp.experts.{expert_idx}.down_proj"
98+
assert set(mapping.balance_names) == {expected_down_proj}
99+
assert mapping.parent_name == expected_down_proj
100+
assert mapping.parent == mlp["experts"][int(expert_idx)]["down_proj"]
95101

96102
awq = AWQModifier(
97103
mappings=[
@@ -223,7 +229,7 @@ def test_get_lowest_non_module_list_ancestor():
223229
assert ancestor_name == "" and ancestor == model
224230

225231
ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
226-
["experts"], model
232+
"experts", model
227233
)
228234
assert ancestor_name == "" and ancestor == model
229235

@@ -232,3 +238,64 @@ def test_get_lowest_non_module_list_ancestor():
232238
)
233239
assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"]
234240

241+
242+
@pytest.mark.unit
243+
def test_moe_multiple_balance_layers():
244+
"""Test AWQ mapping with multiple balance layers in MoE architecture"""
245+
awq = AWQModifier(
246+
mappings=[
247+
# Map input_layernorm to multiple experts' gate_proj and up_proj
248+
AWQMapping(
249+
"re:.*input_layernorm",
250+
["re:.*gate_proj", "re:.*up_proj"],
251+
),
252+
],
253+
scheme="W4A16_ASYM",
254+
)
255+
256+
# Create a simplified MoE model structure
257+
mlp = torch.nn.ModuleDict(
258+
{
259+
"experts": torch.nn.ModuleList(
260+
[
261+
torch.nn.ModuleDict(
262+
{
263+
"gate_proj": Linear(4, 4),
264+
"up_proj": Linear(4, 4),
265+
"down_proj": Linear(4, 4),
266+
}
267+
)
268+
for _ in range(2)
269+
]
270+
)
271+
}
272+
)
273+
model = torch.nn.ModuleDict(
274+
{
275+
"layer": torch.nn.ModuleDict(
276+
{
277+
"input_layernorm": torch.nn.LayerNorm(4),
278+
"mlp": mlp,
279+
}
280+
)
281+
}
282+
)
283+
284+
awq._set_resolved_mappings(model)
285+
286+
# Should have one mapping for input_layernorm
287+
assert len(awq._resolved_mappings) == 1
288+
mapping = awq._resolved_mappings[0]
289+
290+
# Should map to all gate_proj and up_proj across all experts
291+
expected_balance_names = {
292+
"layer.mlp.experts.0.gate_proj",
293+
"layer.mlp.experts.0.up_proj",
294+
"layer.mlp.experts.1.gate_proj",
295+
"layer.mlp.experts.1.up_proj",
296+
}
297+
assert set(mapping.balance_names) == expected_balance_names
298+
299+
assert mapping.parent_name == "layer.mlp"
300+
assert mapping.parent == mlp
301+

0 commit comments

Comments
 (0)