
Commit f756a68

zyongye and WoosukKwon authored
[gpt-oss] guard import when triton kernel is not installed (#22529)
Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent f0964e2 commit f756a68


vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py

Lines changed: 19 additions & 10 deletions
@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+from vllm.utils import has_triton_kernels
 
-if True:
+if has_triton_kernels():
     import triton_kernels.swiglu
-    from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
-                                           PrecisionConfig, matmul_ogs)
+    from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
     from triton_kernels.routing import routing
 
+if TYPE_CHECKING:
+    from triton_kernels.matmul_ogs import PrecisionConfig
+
 
 def triton_kernel_moe_forward(
         hidden_states: torch.Tensor,
@@ -33,8 +36,8 @@ def triton_kernel_moe_forward(
         w2_scale: Optional[torch.Tensor] = None,
         w1_bias: Optional[torch.Tensor] = None,
         w2_bias: Optional[torch.Tensor] = None,
-        w1_precision=None,  # PrecisionConfig or None
-        w2_precision=None,  # PrecisionConfig or None
+        w1_precision: Optional["PrecisionConfig"] = None,
+        w2_precision: Optional["PrecisionConfig"] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None,
@@ -90,8 +93,8 @@ def triton_kernel_fused_experts(
         w2_scale: Optional[torch.Tensor] = None,
         w1_bias: Optional[torch.Tensor] = None,
         w2_bias: Optional[torch.Tensor] = None,
-        w1_precision=None,  # PrecisionConfig or None
-        w2_precision=None,  # PrecisionConfig or None
+        w1_precision: Optional["PrecisionConfig"] = None,
+        w2_precision: Optional["PrecisionConfig"] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None,
@@ -141,8 +144,14 @@ def triton_kernel_fused_experts(
 
 class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
-    def __init__(self, quant_config, max_num_tokens: int, num_dispatchers: int,
-                 w1_precision: PrecisionConfig, w2_precision: PrecisionConfig):
+    def __init__(
+        self,
+        quant_config,
+        max_num_tokens: int,
+        num_dispatchers: int,
+        w1_precision: "PrecisionConfig",
+        w2_precision: "PrecisionConfig",
+    ):
         super().__init__(quant_config)
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
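
For context, the fix follows a standard pattern for optional dependencies: the heavy triton_kernels imports run only when has_triton_kernels() reports the package is installed, while PrecisionConfig is imported solely under typing.TYPE_CHECKING and referenced through string annotations, so the module still imports (and type-checks) cleanly when triton_kernels is absent. Below is a minimal, self-contained sketch of the same pattern against a hypothetical optional package "fastlib"; has_fastlib, fused_op, and Accumulator are illustrative placeholders, not vLLM or triton_kernels APIs.

import importlib.util
from typing import TYPE_CHECKING, Optional


def has_fastlib() -> bool:
    # Cheap availability probe; vLLM's has_triton_kernels() plays the same role.
    return importlib.util.find_spec("fastlib") is not None


if has_fastlib():
    # Runtime import happens only when the optional package is installed.
    from fastlib import fused_op

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime, so the annotation
    # below stays valid even when fastlib is missing.
    from fastlib import Accumulator


def run(x: list[float], acc: Optional["Accumulator"] = None) -> list[float]:
    # The quoted "Accumulator" annotation is never evaluated at runtime.
    if not has_fastlib():
        return x  # graceful fallback instead of an ImportError at import time
    return fused_op(x, acc)

The commit applies the same split in gpt_oss_triton_kernels_moe.py: matmul_ogs, routing, and swiglu are imported behind has_triton_kernels(), and PrecisionConfig moves into the TYPE_CHECKING block so the function and constructor signatures can keep their annotations as quoted forward references.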
