Skip to content

Commit 9912b8c

Browse files
varun-sundar-rabindranath and Varun Sundar Rabindranath authored
[Build] Add OpenAI triton_kernels (#28788)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 49ef847 commit 9912b8c

File tree

6 files changed

+119
-1
lines changed

6 files changed

+119
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
# vllm-flash-attn built from source
55
vllm/vllm_flash_attn/*
66

7+
# OpenAI triton kernels copied from source
8+
vllm/third_party/triton_kernels/*
9+
710
# triton jit
811
.triton
912

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
10301030
WITH_SOABI)
10311031
endif()
10321032

1033+
# For CUDA and HIP builds also build the triton_kernels external package.
1034+
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
1035+
include(cmake/external_projects/triton_kernels.cmake)
1036+
endif()
1037+
10331038
# For CUDA we also build and ship some external projects.
10341039
if (VLLM_GPU_LANG STREQUAL "CUDA")
10351040
include(cmake/external_projects/flashmla.cmake)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
2+
3+
set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")
4+
5+
# Set the TRITON_KERNELS_SRC_DIR environment variable for local development with vLLM. It is expected to
6+
# point directly at the triton_kernels python directory.
7+
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
8+
message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}")
9+
FetchContent_Declare(
10+
triton_kernels
11+
SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR}
12+
)
13+
14+
else()
15+
set(TRITON_GIT "https://github.com/triton-lang/triton.git")
16+
message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}")
17+
FetchContent_Declare(
18+
triton_kernels
19+
# TODO (varun) : Fetch just the triton_kernels directory from Triton
20+
GIT_REPOSITORY https://github.com/triton-lang/triton.git
21+
GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG}
22+
GIT_PROGRESS TRUE
23+
SOURCE_SUBDIR python/triton_kernels/triton_kernels
24+
)
25+
endif()
26+
27+
# Fetch content
28+
FetchContent_MakeAvailable(triton_kernels)
29+
30+
if (NOT triton_kernels_SOURCE_DIR)
31+
message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR")
32+
endif()
33+
34+
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
35+
set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/")
36+
else()
37+
set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/")
38+
endif()
39+
40+
message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}")
41+
42+
add_custom_target(triton_kernels)
43+
44+
# Ensure the vllm/third_party/triton_kernels directory exists before installation
45+
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")")
46+
47+
# Copy the triton_kernels .py files into the install directory.
48+
install(DIRECTORY
49+
${TRITON_KERNELS_PYTHON_DIR}
50+
DESTINATION
51+
vllm/third_party/triton_kernels/
52+
COMPONENT triton_kernels
53+
FILES_MATCHING PATTERN "*.py")

setup.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,20 @@ def run(self):
299299
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
300300
self.copy_file(file, dst_file)
301301

302+
if _is_cuda() or _is_hip():
303+
# copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
304+
# to current directory so that they can be included in the editable
305+
# build
306+
print(
307+
f"Copying {self.build_lib}/vllm/third_party/triton_kernels "
308+
"to vllm/third_party/triton_kernels"
309+
)
310+
shutil.copytree(
311+
f"{self.build_lib}/vllm/third_party/triton_kernels",
312+
"vllm/third_party/triton_kernels",
313+
dirs_exist_ok=True,
314+
)
315+
302316

303317
class precompiled_build_ext(build_ext):
304318
"""Disables extension building when using precompiled binaries."""
@@ -633,6 +647,9 @@ def _read_requirements(filename: str) -> list[str]:
633647
if _is_cuda() or _is_hip():
634648
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
635649
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
650+
# Optional since this doesn't get built (it does not produce a .so file). This just
651+
# copies the relevant .py files from the source repository.
652+
ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
636653

637654
if _is_hip():
638655
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from vllm.logger import init_logger
99
from vllm.platforms import current_platform
1010
from vllm.triton_utils import triton
11+
from vllm.utils.import_utils import has_triton_kernels
1112
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
1213

1314
logger = init_logger(__name__)
1415

1516

1617
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
1718
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
19+
assert has_triton_kernels()
1820
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
1921
from triton_kernels.numerics import InFlexData
2022
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor

vllm/utils/import_utils.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import regex as re
1919
from typing_extensions import Never
2020

21+
from vllm.logger import init_logger
22+
23+
logger = init_logger(__name__)
24+
2125

2226
# TODO: This function can be removed if transformer_modules classes are
2327
# serialized by value when communicating between processes
@@ -62,6 +66,35 @@ def import_pynvml():
6266
return pynvml
6367

6468

69+
@cache
70+
def import_triton_kernels():
71+
"""
72+
For convenience, prioritize triton_kernels that is available in
73+
`site-packages`. Use `vllm.third_party.triton_kernels` as a fall-back.
74+
"""
75+
if _has_module("triton_kernels"):
76+
import triton_kernels
77+
78+
logger.debug_once(
79+
f"Loading module triton_kernels from {triton_kernels.__file__}.",
80+
scope="local",
81+
)
82+
elif _has_module("vllm.third_party.triton_kernels"):
83+
import vllm.third_party.triton_kernels as triton_kernels
84+
85+
logger.debug_once(
86+
f"Loading module triton_kernels from {triton_kernels.__file__}.",
87+
scope="local",
88+
)
89+
sys.modules["triton_kernels"] = triton_kernels
90+
else:
91+
logger.info_once(
92+
"triton_kernels unavailable in this build. "
93+
"Please consider installing triton_kernels from "
94+
"https://github.com/triton-lang/triton/tree/main/python/triton_kernels"
95+
)
96+
97+
6598
def import_from_path(module_name: str, file_path: str | os.PathLike):
6699
"""
67100
Import a Python file according to its file path.
@@ -397,7 +430,12 @@ def has_deep_gemm() -> bool:
397430

398431
def has_triton_kernels() -> bool:
399432
"""Whether the optional `triton_kernels` package is available."""
400-
return _has_module("triton_kernels")
433+
is_available = _has_module("triton_kernels") or _has_module(
434+
"vllm.third_party.triton_kernels"
435+
)
436+
if is_available:
437+
import_triton_kernels()
438+
return is_available
401439

402440

403441
def has_tilelang() -> bool:

0 commit comments

Comments
 (0)