 import functools
 import json
 import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import torch

@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool,
                             top_k: int,
-                            config: Dict[str, Any],
+                            config: dict[str, Any],
                             compute_type: tl.dtype,
                             use_fp8_w8a8: bool,
                             use_int8_w8a8: bool,
                             use_int8_w8a16: bool,
                             use_int4_w4a16: bool,
                             per_channel_quant: bool,
-                            block_shape: Optional[List[int]] = None) -> None:
+                            block_shape: Optional[list[int]] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
 def get_config_file_name(E: int,
                          N: int,
                          dtype: Optional[str],
-                         block_shape: Optional[List[int]] = None) -> str:
+                         block_shape: Optional[list[int]] = None) -> str:
     device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
     block_shape_selector = ("" if not block_shape or not all(block_shape) else
@@ -638,7 +638,7 @@ def get_moe_configs(
     dtype: Optional[str],
     block_n: Optional[int] = None,
     block_k: Optional[int] = None,
-) -> Optional[Dict[int, Any]]:
+) -> Optional[dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.

@@ -670,7 +670,7 @@ def get_moe_configs(
     return None


-def get_moe_wna16_block_config(config: Dict[str,
+def get_moe_wna16_block_config(config: dict[str,
                                             int], use_moe_wna16_cuda: bool,
                                num_valid_tokens: int, size_k: int, size_n: int,
                                num_experts: int, group_size: int,
@@ -742,8 +742,8 @@ def get_default_config(
     topk: int,
     dtype: Optional[str],
     is_marlin: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Dict[str, int]:
+    block_shape: Optional[list[int]] = None,
+) -> dict[str, int]:
     if dtype == "fp8_w8a8" and block_shape is not None:
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
         # BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -795,13 +795,13 @@ def get_default_config(


 def try_get_optimal_moe_config(
-    w1_shape: Tuple[int, ...],
-    w2_shape: Tuple[int, ...],
+    w1_shape: tuple[int, ...],
+    w2_shape: tuple[int, ...],
     top_k: int,
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
-    block_shape: Optional[List[int]] = None,
+    block_shape: Optional[list[int]] = None,
 ):
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
@@ -855,7 +855,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")

@@ -895,7 +895,7 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
     e_score_correction_bias: Optional[torch.Tensor] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:

     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           w2_zp: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
-                          block_shape: Optional[List[int]] = None) -> None:
+                          block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> None:
+        block_shape: Optional[list[int]] = None) -> None:
     pass


@@ -1046,7 +1046,7 @@ def outplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, activation, apply_router_weight_on_input,
                               use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)


@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None,
+                  block_shape: Optional[list[int]] = None,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
     if (allow_deep_gemm and use_fp8_w8a8
             and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
     per_channel_quant: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     if use_fp8_w8a8:
         assert B_scale is not None
         if block_shape is None:
@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        w2_zp: Optional[torch.Tensor] = None,
                        a1_scale: Optional[torch.Tensor] = None,
                        a2_scale: Optional[torch.Tensor] = None,
-                       block_shape: Optional[List[int]] = None):
+                       block_shape: Optional[list[int]] = None):
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
@@ -1452,7 +1452,7 @@ def fused_moe(
     w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
+    block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1497,7 +1497,7 @@ def fused_moe(
         a1.
     - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
         a2.
-    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+    - block_shape: (Optional[list[int]]): Optional block size for block-wise
         quantization.

     Returns:
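
Note on the pattern applied throughout this diff: since PEP 585 (Python 3.9), the built-in container types are subscriptable, so annotations such as typing.Dict[str, Any], List[int], and Tuple[int, ...] can be written as dict[str, Any], list[int], and tuple[int, ...], while Optional, Callable, and Any still come from typing. A minimal sketch of the before/after annotation style is below; the helper function is purely illustrative and is not part of the vLLM API.

from typing import Any, Optional


def describe_block_config(config: dict[str, Any],
                          block_shape: Optional[list[int]] = None) -> str:
    # Built-in generics (dict/list) replace typing.Dict/typing.List here;
    # Optional still comes from typing. Hypothetical helper, for illustration.
    shape = block_shape if block_shape is not None else [0, 0]
    return f"BLOCK_SIZE_M={config.get('BLOCK_SIZE_M', 64)}, block_shape={shape}"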