
Commit c02058c

cfRod, Isotr0py, and isharif168 authored
Add bias handling to CPUFusedMOE kernel (#26289)
Signed-off-by: Crefeda Rodrigues <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Crefeda Rodrigues <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Sharif Inamdar <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent b2ea5ba commit c02058c

File tree

tests/kernels/moe/test_moe.py
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py

2 files changed: +78 −2 lines


tests/kernels/moe/test_moe.py

Lines changed: 69 additions & 0 deletions
@@ -909,3 +909,72 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
     torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
 
     opcheck(torch.ops._moe_C.moe_sum, (input, actual))
+
+
+@pytest.mark.parametrize("m", [1, 33])
+@pytest.mark.parametrize("n,k", [(128, 128)])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("with_bias", [False, True])
+@pytest.mark.parametrize("activation", ["silu"])
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
+def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
+    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
+
+    device = "cpu"
+    torch.manual_seed(7)
+
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
+    b1 = b2 = None
+    if with_bias:
+        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
+        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10
+
+    ref = (
+        torch_moe(a, w13, w2, router_logits, topk, b1, b2)
+        if with_bias
+        else torch_moe(a, w13, w2, router_logits, topk)
+    )
+
+    class _Dummy(torch.nn.Module):
+        def __init__(self, w13, w2, b1=None, b2=None):
+            super().__init__()
+            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+            if b1 is not None:
+                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
+            if b2 is not None:
+                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)
+
+    layer = _Dummy(w13, w2, b1, b2).to(dtype)
+    fused = CPUFusedMOE(layer)
+    out = fused(
+        layer=layer,
+        x=a,
+        use_grouped_topk=False,
+        top_k=topk,
+        router_logits=router_logits,
+        renormalize=False,
+        global_num_experts=e,
+        expert_map=None,
+        custom_routing_function=None,
+        scoring_func="softmax",
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation=activation,
+    )
+
+    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
+    if dtype == torch.float32:
+        atol = 1e-3
+    elif with_bias:
+        atol = 8e-2
+    else:
+        atol = 5e-2
+    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
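For readers without the test file at hand, the reference the new test compares against is a naive per-expert MoE with optional biases. The sketch below is illustrative only and is not the torch_moe helper from the vLLM test suite; the function name naive_moe_with_bias is hypothetical, shapes follow the test (w13: (e, 2n, k), w2: (e, k, n), b1: (e, 2n), b2: (e, k)), and routing is assumed to be plain softmax plus top-k without renormalization, matching renormalize=False above.

import torch
import torch.nn.functional as F

def naive_moe_with_bias(a, w13, w2, router_logits, topk, b1=None, b2=None):
    # Softmax routing, top-k expert selection, no weight renormalization.
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    out = torch.zeros_like(a)
    for expert in range(w13.shape[0]):
        # Tokens routed to this expert and the top-k slot they used.
        token_idx, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = a[token_idx]
        # Gate/up projection with this expert's optional bias row.
        gate_up = F.linear(x, w13[expert], None if b1 is None else b1[expert])
        gate, up = gate_up.chunk(2, dim=-1)
        hidden = F.silu(gate) * up
        # Down projection with this expert's optional bias row.
        expert_out = F.linear(hidden, w2[expert], None if b2 is None else b2[expert])
        out[token_idx] += topk_weights[token_idx, slot].to(a.dtype).unsqueeze(-1) * expert_out
    return out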

vllm/model_executor/layers/fused_moe/cpu_fused_moe.py

Lines changed: 9 additions & 2 deletions
@@ -276,18 +276,25 @@ def __call__(
 
         outputs = []
         start_idx = 0
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
         for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
            layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
            layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
 
-            gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+            gate_up = F.linear(
+                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
+            )
            gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight)
+            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
            outputs.append(expert_out)
            start_idx = end_idx
