Skip to content

Commit e9db186

Browse files
fywkevinYuanwei FangJokeren
authored
[PROTON] Introduce the Proton dialect as a third-party plugin for intra-kernel perf tooling (#5119)
This PR introduces the `Proton Dialect` to enable intra kernel profiling and tooling for Triton. As a third-party dialect, it serves as the building blocks to create 3rd-party perf tools (e.g., profilers, analysis, modeling) for Triton compiler developers in a compiler-centric way, such as an intra-kernel latency profiler to understand software pipelining, warp specialization, and CTA fine-grained orchestration (e.g., cuda core, tensor core, TMA). Future developments would integrate this dialect with the existing Proton backend profiling infrastructure to make it a powerful and general perf tool utility. As a first step, this PR adds some basic boilerplate code and mechanics, and the `proton.record` op for the `Proton Dialect`. --------- Co-authored-by: Yuanwei Fang <[email protected]> Co-authored-by: Keren Zhou <[email protected]>
1 parent ad28e6c commit e9db186

File tree

19 files changed

+269
-8
lines changed

19 files changed

+269
-8
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ if(TRITON_BUILD_PYTHON_MODULE)
206206
if (TRITON_BUILD_PROTON)
207207
add_subdirectory(third_party/proton)
208208
endif()
209+
# We always build proton dialect
210+
list(APPEND TRITON_PLUGIN_NAMES "proton")
211+
add_subdirectory(third_party/proton/dialect)
209212

210213
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
211214
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
@@ -311,6 +314,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
311314
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
312315
add_subdirectory(third_party/${CODEGEN_BACKEND})
313316
endforeach()
317+
add_subdirectory(third_party/proton/dialect)
314318
endif()
315319

316320
add_subdirectory(third_party/f2reduce)

bin/RegisterTritonDialects.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
33
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
44
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
5+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
56
#include "triton/Dialect/Triton/IR/Dialect.h"
67
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
78
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -68,12 +69,13 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6869
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
6970

7071
// TODO: register Triton & TritonGPU passes
71-
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
72-
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
73-
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
74-
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
75-
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
76-
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
77-
mlir::triton::amdgpu::TritonAMDGPUDialect,
78-
mlir::ROCDL::ROCDLDialect>();
72+
registry
73+
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
74+
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
75+
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
76+
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
77+
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
78+
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
79+
mlir::triton::amdgpu::TritonAMDGPUDialect,
80+
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect>();
7981
}

test/Proton/ops.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s
2+
3+
module {
4+
// CHECK-LABEL: proton_record
5+
tt.func @proton_record() {
6+
// CHECK: proton.record() {isStart = true, regionId = 1 : i32}
7+
// CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32}
8+
// CHECK-NEXT: tt.return
9+
proton.record() {isStart = true, regionId = 1 : i32}
10+
proton.record() {isStart = false, regionId = 1 : i32}
11+
tt.return
12+
}
13+
} // end module
14+
15+
// -----
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
2+
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
3+
add_subdirectory(include)
4+
add_subdirectory(lib)
5+
if(TRITON_BUILD_PYTHON_MODULE)
6+
add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ProtonIR)
7+
endif()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(Dialect)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(Proton)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
2+
3+
set(LLVM_TARGET_DEFINITIONS ProtonOps.td)
4+
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton)
5+
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton)
6+
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
7+
mlir_tablegen(Ops.h.inc -gen-op-decls)
8+
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
9+
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
10+
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
11+
add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc)
12+
add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc)
13+
add_public_tablegen_target(ProtonTableGen)
14+
15+
set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td)
16+
mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls)
17+
mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs)
18+
add_public_tablegen_target(ProtonAttrDefsIncGen)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_
2+
#define TRITON_DIALECT_PROTON_IR_DIALECT_H_
3+
4+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5+
#include "mlir/IR/BuiltinOps.h"
6+
#include "mlir/IR/Dialect.h"
7+
#include "mlir/IR/PatternMatch.h"
8+
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
9+
#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc"
10+
11+
#define GET_ATTRDEF_CLASSES
12+
#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc"
13+
14+
#define GET_OP_CLASSES
15+
#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc"
16+
17+
namespace mlir {
18+
namespace triton {
19+
namespace proton {} // namespace proton
20+
} // namespace triton
21+
} // namespace mlir
22+
23+
#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef PROTON_ATTRDEFS
2+
#define PROTON_ATTRDEFS
3+
4+
include "mlir/IR/AttrTypeBase.td"
5+
include "ProtonDialect.td"
6+
7+
class Proton_Attr<string name, list<Trait> traits = [],
8+
string baseCppClass = "::mlir::Attribute">
9+
: AttrDef<Proton_Dialect, name, traits, baseCppClass> {
10+
}
11+
12+
#endif // PROTON_ATTRDEFS

0 commit comments

Comments
 (0)