@@ -909,3 +909,72 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
     torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
 
     opcheck(torch.ops._moe_C.moe_sum, (input, actual))
+
+
+@pytest.mark.parametrize("m", [1, 33])
+@pytest.mark.parametrize("n,k", [(128, 128)])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("with_bias", [False, True])
+@pytest.mark.parametrize("activation", ["silu"])
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
+def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
+    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
+
+    device = "cpu"
+    torch.manual_seed(7)
+
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
+    b1 = b2 = None
+    if with_bias:
+        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
+        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10
+
+    ref = (
+        torch_moe(a, w13, w2, router_logits, topk, b1, b2)
+        if with_bias
+        else torch_moe(a, w13, w2, router_logits, topk)
+    )
+
+    class _Dummy(torch.nn.Module):
+        def __init__(self, w13, w2, b1=None, b2=None):
+            super().__init__()
+            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+            if b1 is not None:
+                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
+            if b2 is not None:
+                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)
+
+    layer = _Dummy(w13, w2, b1, b2).to(dtype)
+    fused = CPUFusedMOE(layer)
+    out = fused(
+        layer=layer,
+        x=a,
+        use_grouped_topk=False,
+        top_k=topk,
+        router_logits=router_logits,
+        renormalize=False,
+        global_num_experts=e,
+        expert_map=None,
+        custom_routing_function=None,
+        scoring_func="softmax",
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation=activation,
+    )
+
+    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
+    if dtype == torch.float32:
+        atol = 1e-3
+    elif with_bias:
+        atol = 8e-2
+    else:
+        atol = 5e-2
+    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
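Note on the reference: `torch_moe` is a helper defined elsewhere in this test file and is not shown in this hunk. As a minimal sketch of the computation it is expected to perform here (softmax routing scores, top-k expert selection without renormalization, SiLU-and-mul expert MLP, optional biases), something like the following would serve; the name `torch_moe_ref` and the per-token loop are illustrative assumptions, not the actual helper:

import torch
import torch.nn.functional as F

def torch_moe_ref(a, w13, w2, router_logits, topk, b1=None, b2=None):
    """Hypothetical stand-in for the torch_moe reference helper."""
    # Softmax routing scores, then top-k expert selection per token
    # (no renormalization of the top-k weights, matching renormalize=False).
    scores = F.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    out = torch.zeros_like(a)
    for t in range(a.shape[0]):      # tokens
        for j in range(topk):        # selected experts for this token
            eid = topk_ids[t, j]
            h = a[t] @ w13[eid].t()  # (k,) x (k, 2n) -> (2n,)
            if b1 is not None:
                h = h + b1[eid]
            gate, up = h.chunk(2, dim=-1)
            h = F.silu(gate) * up    # SiLU-and-mul ("silu" activation)
            y = h @ w2[eid].t()      # (n,) x (n, k) -> (k,)
            if b2 is not None:
                y = y + b2[eid]
            out[t] += topk_weights[t, j].to(a.dtype) * y
    return out

The split tolerances at the end follow from precision, not the kernel: fp32 results should match the eager reference closely, while bf16 accumulates rounding error through the two matmuls (plus one more low-precision add when biases are enabled, hence the slightly larger atol in that case).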