Skip to content

Commit df7a403

Browse files
authored
[PROTON-DEV] Add Proton Lowering Pass (#5847)
1 parent 58225df commit df7a403

File tree

9 files changed

+108
-2
lines changed

9 files changed

+108
-2
lines changed

bin/RegisterTritonDialects.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
44
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
55
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
6+
#include "third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h"
67
#include "triton/Dialect/Triton/IR/Dialect.h"
78
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
89
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -68,6 +69,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6869
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
6970
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
7071

72+
// Proton passes
73+
mlir::triton::registerProtonLoweringPass();
74+
7175
// TODO: register Triton & TritonGPU passes
7276
registry
7377
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

test/Proton/ops.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s
1+
// RUN: triton-opt --split-input-file %s -cse -canonicalize --proton-lowering-pass | FileCheck %s
22

33
module {
4-
// CHECK-LABEL: proton_record
54
tt.func @proton_record() {
65
// CHECK: proton.record() {isStart = true, regionId = 1 : i32}
76
// CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(Dialect)
2+
add_subdirectory(TritonProtonToLLVM)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonProtonToLLVM)
3+
add_public_tablegen_target(TritonProtonConversionPassIncGen)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef TRITON_THIRD_PARTY_PROTON_INCLUDE_TRITONPROTONGPUTOLLVM_PASSES_H_
2+
#define TRITON_THIRD_PARTY_PROTON_INCLUDE_TRITONPROTONGPUTOLLVM_PASSES_H_
3+
4+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "mlir/Transforms/DialectConversion.h"
7+
8+
namespace mlir {
9+
10+
class ModuleOp;
11+
template <typename T> class OperationPass;
12+
13+
} // namespace mlir
14+
15+
namespace mlir::triton {
16+
std::unique_ptr<OperationPass<ModuleOp>> createProtonLoweringPass();
17+
18+
#define GEN_PASS_REGISTRATION
19+
#include "../third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h.inc"
20+
21+
} // namespace mlir::triton
22+
23+
#endif
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef TRITONPROTON_CONVERSION_PASSES
2+
#define TRITONPROTON_CONVERSION_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def ProtonLoweringPass : Pass<"proton-lowering-pass", "mlir::ModuleOp"> {
7+
let constructor = "mlir::triton::createProtonLoweringPass()";
8+
9+
}
10+
#endif

third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
add_triton_library(TritonProtonToLLVM
22
RecordOpToLLVM.cpp
3+
ProtonLoweringPass.cpp
4+
5+
DEPENDS
6+
TritonProtonConversionPassIncGen
37

48
LINK_LIBS PUBLIC
59
ProtonIR
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2+
#include "mlir/Pass/Pass.h"
3+
#include "triton/Analysis/Allocation.h"
4+
#include "triton/Analysis/Utility.h"
5+
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
6+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
7+
#include "triton/Dialect/Triton/IR/Dialect.h"
8+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
9+
10+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::triton;
14+
15+
namespace mlir {
16+
namespace triton {
17+
#define GEN_PASS_DEF_PROTONLOWERINGPASS
18+
#include "../third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h.inc"
19+
} // namespace triton
20+
} // namespace mlir
21+
22+
namespace {
23+
struct ProtonLoweringPass
24+
: public mlir::triton::impl::ProtonLoweringPassBase<ProtonLoweringPass> {
25+
void runOnOperation() override {
26+
ModuleOp mod = getOperation();
27+
ModuleAllocation allocation(mod);
28+
29+
OpBuilder b(mod.getBodyRegion());
30+
MLIRContext *context = &getContext();
31+
auto loc = mod.getLoc();
32+
33+
/*Add Proton Op Lowerings Here*/
34+
}
35+
};
36+
37+
} // namespace
38+
39+
namespace mlir {
40+
41+
namespace triton {
42+
43+
std::unique_ptr<OperationPass<ModuleOp>> createProtonLoweringPass() {
44+
return std::make_unique<ProtonLoweringPass>();
45+
}
46+
47+
} // namespace triton
48+
49+
} // namespace mlir

third_party/proton/dialect/triton_proton.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,23 @@
55
#include <pybind11/stl.h>
66
#include <pybind11/stl_bind.h>
77

8+
#include "../third_party/proton/dialect/include/TritonProtonToLLVM/Passes.h"
9+
810
namespace py = pybind11;
911

12+
namespace {
13+
14+
void init_triton_proton_passes_ttgpuir(py::module &&m) {
15+
using namespace mlir::triton;
16+
m.def("add_proton_lowering_pass",
17+
[](mlir::PassManager &pm) { pm.addPass(createProtonLoweringPass()); });
18+
}
19+
} // namespace
20+
1021
void init_triton_proton(py::module &&m) {
22+
m.doc() = "Python bindings to the Proton backend";
1123
auto passes = m.def_submodule("passes");
24+
init_triton_proton_passes_ttgpuir(passes.def_submodule("ttgpuir"));
1225

1326
// load dialects
1427
m.def("load_dialects", [](mlir::MLIRContext &context) {

0 commit comments

Comments
 (0)