File tree Expand file tree Collapse file tree 9 files changed +108
-2
lines changed
third_party/proton/dialect Expand file tree Collapse file tree 9 files changed +108
-2
lines changed Original file line number Diff line number Diff line change 33#include " amd/include/TritonAMDGPUTransforms/Passes.h"
44#include " third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
55#include " third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
6+ #include " third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h"
67#include " triton/Dialect/Triton/IR/Dialect.h"
78#include " triton/Dialect/TritonGPU/IR/Dialect.h"
89#include " triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -68,6 +69,9 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) {
6869 mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints ();
6970 mlir::triton::registerTritonAMDGPULowerInstructionSchedHints ();
7071
72+ // Proton passes
73+ mlir::triton::registerProtonLoweringPass ();
74+
7175 // TODO: register Triton & TritonGPU passes
7276 registry
7377 .insert <mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Original file line number Diff line number Diff line change 1- // RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s
1+ // RUN: triton-opt --split-input-file %s -cse -canonicalize --proton-lowering-pass | FileCheck %s
22
33module {
4- // CHECK-LABEL: proton_record
54 tt.func @proton_record () {
65 // CHECK: proton.record() {isStart = true, regionId = 1 : i32}
76 // CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32}
Original file line number Diff line number Diff line change 11add_subdirectory (Dialect)
2+ add_subdirectory (TritonProtonToLLVM)
Original file line number Diff line number Diff line change 1+ set (LLVM_TARGET_DEFINITIONS Passes.td)
2+ mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonProtonToLLVM)
3+ add_public_tablegen_target(TritonProtonConversionPassIncGen)
Original file line number Diff line number Diff line change 1+ #ifndef TRITON_THIRD_PARTY_PROTON_INCLUDE_TRITONPROTONGPUTOLLVM_PASSES_H_
2+ #define TRITON_THIRD_PARTY_PROTON_INCLUDE_TRITONPROTONGPUTOLLVM_PASSES_H_
3+
4+ #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
5+ #include " mlir/Pass/Pass.h"
6+ #include " mlir/Transforms/DialectConversion.h"
7+
8+ namespace mlir {
9+
10+ class ModuleOp ;
11+ template <typename T> class OperationPass ;
12+
13+ } // namespace mlir
14+
15+ namespace mlir ::triton {
16+ std::unique_ptr<OperationPass<ModuleOp>> createProtonLoweringPass ();
17+
18+ #define GEN_PASS_REGISTRATION
19+ #include " ../third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h.inc"
20+
21+ } // namespace mlir::triton
22+
23+ #endif
Original file line number Diff line number Diff line change 1+ #ifndef TRITONPROTON_CONVERSION_PASSES
2+ #define TRITONPROTON_CONVERSION_PASSES
3+
4+ include "mlir/Pass/PassBase.td"
5+
6+ def ProtonLoweringPass : Pass<"proton-lowering-pass", "mlir::ModuleOp"> {
7+ let constructor = "mlir::triton::createProtonLoweringPass()";
8+
9+ }
10+ #endif
Original file line number Diff line number Diff line change 11add_triton_library(TritonProtonToLLVM
22 RecordOpToLLVM.cpp
3+ ProtonLoweringPass.cpp
4+
5+ DEPENDS
6+ TritonProtonConversionPassIncGen
37
48 LINK_LIBS PUBLIC
59 ProtonIR
Original file line number Diff line number Diff line change 1+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
2+ #include " mlir/Pass/Pass.h"
3+ #include " triton/Analysis/Allocation.h"
4+ #include " triton/Analysis/Utility.h"
5+ #include " triton/Conversion/TritonGPUToLLVM/Passes.h"
6+ #include " triton/Conversion/TritonGPUToLLVM/Utility.h"
7+ #include " triton/Dialect/Triton/IR/Dialect.h"
8+ #include " triton/Dialect/TritonGPU/IR/Dialect.h"
9+
10+ #include " third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
11+
12+ using namespace mlir ;
13+ using namespace mlir ::triton;
14+
15+ namespace mlir {
16+ namespace triton {
17+ #define GEN_PASS_DEF_PROTONLOWERINGPASS
18+ #include " ../third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h.inc"
19+ } // namespace triton
20+ } // namespace mlir
21+
22+ namespace {
23+ struct ProtonLoweringPass
24+ : public mlir::triton::impl::ProtonLoweringPassBase<ProtonLoweringPass> {
25+ void runOnOperation () override {
26+ ModuleOp mod = getOperation ();
27+ ModuleAllocation allocation (mod);
28+
29+ OpBuilder b (mod.getBodyRegion ());
30+ MLIRContext *context = &getContext ();
31+ auto loc = mod.getLoc ();
32+
33+ /* Add Proton Op Lowerings Here*/
34+ }
35+ };
36+
37+ } // namespace
38+
39+ namespace mlir {
40+
41+ namespace triton {
42+
43+ std::unique_ptr<OperationPass<ModuleOp>> createProtonLoweringPass () {
44+ return std::make_unique<ProtonLoweringPass>();
45+ }
46+
47+ } // namespace triton
48+
49+ } // namespace mlir
Original file line number Diff line number Diff line change 55#include < pybind11/stl.h>
66#include < pybind11/stl_bind.h>
77
8+ #include " ../third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h"
9+
810namespace py = pybind11;
911
12+ namespace {
13+
14+ void init_triton_proton_passes_ttgpuir (py::module &&m) {
15+ using namespace mlir ::triton;
16+ m.def (" add_proton_lowering_pass" ,
17+ [](mlir::PassManager &pm) { pm.addPass (createProtonLoweringPass ()); });
18+ }
19+ } // namespace
20+
1021void init_triton_proton (py::module &&m) {
22+ m.doc () = " Python bindings to the Proton backend" ;
1123 auto passes = m.def_submodule (" passes" );
24+ init_triton_proton_passes_ttgpuir (passes.def_submodule (" ttgpuir" ));
1225
1326 // load dialects
1427 m.def (" load_dialects" , [](mlir::MLIRContext &context) {
You can’t perform that action at this time.
0 commit comments