// Patch summary (commentary only; everything from the "diff --git" header onward is unchanged).
//
// lib/Dialect/Triton/Transforms/LoopUnroll.cpp
//   * Removes the namespace-scope `static const char *loopUnrollFactorAttrName`
//     and re-declares it as a data member of LoopUnrollPass (placed just above
//     `public:`), alongside a new member `pipelineStagesAttrName` = "tt.num_stages".
//   * getUnrollFactorOrDefault() and the removeAttr() call now use the member name
//     instead of the `mlir::triton::`-qualified global.
//   * The result of loopUnrollByFactor() is no longer discarded: when the call
//     succeeded and reports an epilogue loop, that loop receives a "tt.num_stages"
//     attribute holding a 32-bit IntegerAttr of value 1. Per the in-patch comment,
//     the intent is to keep the epilogue loop from being pipelined.
//
// test/Triton/loop-unroll.mlir
//   * Adds a CHECK line expecting `tt.num_stages = 1 : i32` after the existing
//     load checks.
//
// NOTE(review): this copy of the patch has lost its original line breaks, and
// template argument lists appear to have been stripped (e.g. `getAttrOfType(`,
// `TritonLoopUnrollBase {`, `!tt.ptr>`) -- presumably an extraction artifact.
// It will not apply as-is; recover the patch from the original commit before use.
diff --git a/lib/Dialect/Triton/Transforms/LoopUnroll.cpp b/lib/Dialect/Triton/Transforms/LoopUnroll.cpp index 257e734b7f88..cb25d41a2548 100644 --- a/lib/Dialect/Triton/Transforms/LoopUnroll.cpp +++ b/lib/Dialect/Triton/Transforms/LoopUnroll.cpp @@ -22,8 +22,6 @@ namespace mlir::triton { -static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; - namespace { class LoopUnrollPass : public TritonLoopUnrollBase { @@ -31,12 +29,15 @@ class LoopUnrollPass : public TritonLoopUnrollBase { int getUnrollFactorOrDefault(scf::ForOp forOp) { // Use the attribute attached to the loop if it exists otherwise set the // factor to 1 to suppress the unrolling. - if (auto factor = forOp->getAttrOfType( - mlir::triton::loopUnrollFactorAttrName)) + if (auto factor = + forOp->getAttrOfType(loopUnrollFactorAttrName)) return factor.getInt(); return 1; } + const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; + const char *pipelineStagesAttrName = "tt.num_stages"; + public: LoopUnrollPass() = default; LoopUnrollPass(const LoopUnrollPass &) {} @@ -49,11 +50,18 @@ class LoopUnrollPass : public TritonLoopUnrollBase { loops.push_back(forOp); }); + auto ctx = getOperation()->getContext(); for (auto loop : loops) { auto unrollFactor = getUnrollFactorOrDefault(loop); - loop->removeAttr(mlir::triton::loopUnrollFactorAttrName); + loop->removeAttr(loopUnrollFactorAttrName); LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop); - (void)loopUnrollByFactor(loop, unrollFactor); + auto resultLoops = loopUnrollByFactor(loop, unrollFactor); + // Do not pipeline the epilog loop. 
+ if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) { + (*resultLoops->epilogueLoopOp) + ->setAttr(pipelineStagesAttrName, + mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1)); + } } } }; diff --git a/test/Triton/loop-unroll.mlir b/test/Triton/loop-unroll.mlir index 9166630281e6..531a14fffad3 100644 --- a/test/Triton/loop-unroll.mlir +++ b/test/Triton/loop-unroll.mlir @@ -13,6 +13,7 @@ tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr>, %arg1: i32) { // CHECK: scf.for // CHECK: tt.load // CHECK-NOT: tt.load + // CHECK: tt.num_stages = 1 : i32 %2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr>) : i32 { %3 = tt.load %arg5 : tensor<256x!tt.ptr> %4 = arith.addf %arg4, %3 : tensor<256xf32>