Skip to content

Commit 734d9f2

Browse files
authored
[PIPELINER] Implement a loop nest fusion pass (#5550)
This PR adds a `FuseNestedLoopsPass` that analyzes loop nests in the program and attempts to fuse them into a single loop. This pass is meant to work together with the pipeliner to enable pipelining loop nests without manual fusion on the user's part. Eventually, the logic in this pass will get spliced into the pipeliner, which will allow fusion and pipelining of loop nests with data-dependent inner loop bounds. The pass is written to generate IR amenable to the pipeliner, but it is not yet turned on. Once enabled, it will be placed right before the loop scheduling pass.
1 parent d8ae341 commit 734d9f2

File tree

5 files changed

+1266
-14
lines changed

5 files changed

+1266
-14
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,22 @@ def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-l
5555
"mlir::arith::ArithDialect"];
5656
}
5757

58+
def TritonGPUFuseNestedLoops : Pass<"tritongpu-fuse-nested-loops", "mlir::ModuleOp"> {
59+
let summary = "fuse nested loops for pipelining";
60+
61+
let description = [{
62+
The `tritongpu-fuse-nested-loops` pass will analyze loop nests in the module
63+
that need to be pipelined and fuse them into a single loop. This composes
64+
with the pipeliner to pipeline loop nests.
65+
}];
66+
67+
let dependentDialects = [
68+
"mlir::triton::gpu::TritonGPUDialect",
69+
"mlir::arith::ArithDialect",
70+
"mlir::ub::UBDialect",
71+
];
72+
}
73+
5874
def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
5975
let summary = "3xTF32 trick";
6076

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
add_triton_library(TritonGPUToLLVM
22
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
33
DotOpToLLVM/FMA.cpp
4-
GlobalScratchMemoryAllocation.cpp
5-
TypeConverter.cpp
6-
Utility.cpp
7-
ElementwiseOpToLLVM.cpp
8-
MemoryOpToLLVM.cpp
4+
AllocateSharedMemory.cpp
95
AssertOpToLLVM.cpp
10-
ViewOpToLLVM.cpp
11-
MakeRangeOpToLLVM.cpp
6+
ControlFlowOpToLLVM.cpp
7+
ConvertLayoutOpToLLVM.cpp
8+
DecomposeUnsupportedConversions.cpp
9+
ElementwiseOpToLLVM.cpp
10+
FuncOpToLLVM.cpp
11+
GatherOpToLLVM.cpp
12+
GlobalScratchMemoryAllocation.cpp
1213
HistogramOpToLLVM.cpp
13-
AllocateSharedMemory.cpp
14+
MakeRangeOpToLLVM.cpp
15+
MemoryOpToLLVM.cpp
16+
PrintOpToLLVM.cpp
1417
ReduceOpToLLVM.cpp
1518
ScanOpToLLVM.cpp
16-
GatherOpToLLVM.cpp
17-
ConvertLayoutOpToLLVM.cpp
18-
ControlFlowOpToLLVM.cpp
19-
FuncOpToLLVM.cpp
2019
SPMDOpToLLVM.cpp
21-
DecomposeUnsupportedConversions.cpp
22-
PrintOpToLLVM.cpp
20+
TypeConverter.cpp
21+
Utility.cpp
22+
ViewOpToLLVM.cpp
2323

2424
DEPENDS
2525
TritonGPUConversionPassIncGen

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_triton_library(TritonGPUTransforms
22
AccelerateMatmul.cpp
33
Coalesce.cpp
44
F32DotTC.cpp
5+
FuseNestedLoops.cpp
56
CombineTensorSelectAndIf.cpp
67
LoopScheduling.cpp
78
ReduceDataDuplication.cpp

0 commit comments

Comments
 (0)