@@ -327,7 +327,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
 
-        if current_platform.is_cpu():
+        if current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                layer.w13_weight,
+                layer.w2_weight,
+                use_prepack=True,
+            )
+        elif current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                 dtype = layer.w13_weight.dtype
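For reference, the hunk above follows a "prepack once, call many times" pattern: the fused IPEX MoE module is built during weight post-processing and cached on the layer as `ipex_fusion`, so the hot forward path only has to invoke it. Below is a minimal, self-contained sketch of that pattern; `DummyMoELayer`, the weight shapes, and the CPU fallback are illustrative assumptions, not vLLM code, and only the `ipex.llm.modules.GatedMLPMOE(...)` call mirrors the diff itself.

```python
import torch


class DummyMoELayer(torch.nn.Module):
    """Illustrative stand-in for a FusedMoE layer (assumption, not vLLM code)."""

    def __init__(self, num_experts=8, hidden_size=64, intermediate_size=128):
        super().__init__()
        # w13 packs the gate/up projections; w2 is the down projection.
        self.w13_weight = torch.nn.Parameter(
            torch.randn(num_experts, 2 * intermediate_size, hidden_size))
        self.w2_weight = torch.nn.Parameter(
            torch.randn(num_experts, hidden_size, intermediate_size))
        self.ipex_fusion = None


def process_weights_after_loading(layer: DummyMoELayer) -> None:
    try:
        # Deferred import keeps intel_extension_for_pytorch optional for
        # builds that never hit this path.
        import intel_extension_for_pytorch as ipex
    except ImportError:
        return  # no IPEX available; another backend path would handle this
    # Prepack the expert weights once, at load time, exactly as in the diff.
    layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
        layer.w13_weight,
        layer.w2_weight,
        use_prepack=True,
    )


layer = DummyMoELayer()
process_weights_after_loading(layer)
print("ipex_fusion prepared:", layer.ipex_fusion is not None)
```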
@@ -509,6 +516,44 @@ def forward_cpu(
             activation,
         )
 
+    def forward_xpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ):
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for XPU.")
+        assert custom_routing_function is None
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
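The new `forward_xpu` rejects any expert-load-balancing (EPLB) state up front and then delegates routing entirely to the prepacked `layer.ipex_fusion` module. The sketch below isolates that guard as a standalone, runnable snippet; `check_no_eplb` is a hypothetical helper written only to illustrate the condition from the hunk, not a vLLM or IPEX API.

```python
from typing import Optional

import torch


def check_no_eplb(enable_eplb: bool = False,
                  expert_load_view: Optional[torch.Tensor] = None,
                  logical_to_physical_map: Optional[torch.Tensor] = None,
                  logical_replica_count: Optional[torch.Tensor] = None) -> None:
    # Mirrors the condition in forward_xpu: any EPLB-related argument,
    # whether the flag or the bookkeeping tensors, is unsupported on XPU.
    if enable_eplb is not False or expert_load_view is not None or \
            logical_to_physical_map is not None or \
            logical_replica_count is not None:
        raise NotImplementedError(
            "Expert load balancing is not supported for XPU.")


check_no_eplb()  # default arguments pass silently

try:
    check_no_eplb(enable_eplb=True)
except NotImplementedError as exc:
    print(exc)  # -> Expert load balancing is not supported for XPU.
```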