@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
    return hidden_states, shared_output


+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
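+    # Eager fallback: expand/permute tokens with npu_moe_init_routing, then
+    # quantize them per-token with npu_dynamic_quant as two separate kernels.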
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
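+    # Flatten the expanded row indices back to token-major order.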
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
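+    # Dynamic per-token quantization: returns int8 tokens plus their scales.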
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(

    num_tokens, _ = hidden_states.shape
    num_experts = w1.shape[0]
-    device = hidden_states.device

    if expert_map is not None:
        global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
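+        # Prefer the fused routing+quant kernel; fall back to the eager
+        # helper above on torch_npu builds that do not ship it.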
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
+
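+        # Exchange per-expert token counts so each rank learns how many
+        # tokens it will receive from every peer.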
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+
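+        # Reduce send/recv counts per rank in one stacked tensor so a single
+        # non-blocking D2H copy produces both split-size lists.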
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
+
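+        # Pre-allocate receive buffers sized by the total incoming tokens.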
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
+
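+        # Dispatch the int8 tokens and their per-token scales to peer ranks.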
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
+
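+        # Group the received tokens by local expert; also returns per-expert
+        # counts and the permutation needed to undo this sort later.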
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1
    else:
        row_idx_len = num_tokens * top_k
        row_idx = torch.arange(0,
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 0
+        dynamic_scale = None
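+        # Non-quantized path: no per-token scales are handed to apply_mlp.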

    # `hidden_states` will be disposed in the `apply_mlp` function
    hidden_states = apply_mlp(
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
        w2,
        w2_scale,
        expert_tokens,  #16
+        dynamic_scale=dynamic_scale,
        group_list_type=group_list_type)

    if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
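+        # Invert the re-routing permutation so rows match the original
+        # all-to-all receive order before the combine step.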
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)

    final_hidden_states = torch_npu.npu_moe_finalize_routing(
        hidden_states,
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
        bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
-        export_for_source_row=topk_ids,
-    )
+        export_for_source_row=None,
+        drop_pad_mode=2)
    else:
        # TODO: Reorder device memory 2 times here, replace the current
        # implementation here when suitable operators become available.