@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+from vllm.utils import has_triton_kernels
 
-if True:
+if has_triton_kernels():
     import triton_kernels.swiglu
-    from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
-                                           PrecisionConfig, matmul_ogs)
+    from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
     from triton_kernels.routing import routing
 
+if TYPE_CHECKING:
+    from triton_kernels.matmul_ogs import PrecisionConfig
+
 
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
@@ -33,8 +36,8 @@ def triton_kernel_moe_forward(
     w2_scale: Optional[torch.Tensor] = None,
     w1_bias: Optional[torch.Tensor] = None,
     w2_bias: Optional[torch.Tensor] = None,
-    w1_precision=None,  # PrecisionConfig or None
-    w2_precision=None,  # PrecisionConfig or None
+    w1_precision: Optional["PrecisionConfig"] = None,
+    w2_precision: Optional["PrecisionConfig"] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
@@ -90,8 +93,8 @@ def triton_kernel_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     w1_bias: Optional[torch.Tensor] = None,
     w2_bias: Optional[torch.Tensor] = None,
-    w1_precision=None,  # PrecisionConfig or None
-    w2_precision=None,  # PrecisionConfig or None
+    w1_precision: Optional["PrecisionConfig"] = None,
+    w2_precision: Optional["PrecisionConfig"] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
@@ -141,8 +144,14 @@ def triton_kernel_fused_experts(
 
 class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
-    def __init__(self, quant_config, max_num_tokens: int, num_dispatchers: int,
-                 w1_precision: PrecisionConfig, w2_precision: PrecisionConfig):
+    def __init__(
+        self,
+        quant_config,
+        max_num_tokens: int,
+        num_dispatchers: int,
+        w1_precision: "PrecisionConfig",
+        w2_precision: "PrecisionConfig",
+    ):
         super().__init__(quant_config)
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
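The change above guards the triton_kernels imports behind a runtime availability check, moves the PrecisionConfig import under TYPE_CHECKING, and switches annotations to string forward references so the module still imports when triton_kernels is absent. Below is a minimal, self-contained sketch of that pattern under assumed names: `fancylib`, `has_fancylib`, and `forward` are hypothetical stand-ins for triton_kernels, has_triton_kernels, and the MoE entry points, not vLLM or triton_kernels APIs.

# Sketch only: illustrates the optional-dependency import pattern used in the
# diff. `fancylib` is a hypothetical package; everything else is stdlib.
from importlib.util import find_spec
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    from fancylib import PrecisionConfig  # hypothetical optional dependency


def has_fancylib() -> bool:
    """Return True if the optional dependency is installed."""
    return find_spec("fancylib") is not None


if has_fancylib():
    # Real import happens only when the package is actually available.
    from fancylib import matmul  # noqa: F401


def forward(x, w1_precision: Optional["PrecisionConfig"] = None):
    # The string annotation keeps this module importable even when
    # fancylib (and therefore PrecisionConfig) is not installed.
    ...

The string annotations cost nothing at runtime, while type checkers still resolve them through the TYPE_CHECKING block, which is presumably why the diff quotes "PrecisionConfig" everywhere it appears in a signature.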