Skip to content

Commit d144cf5

Browse files
committed
[MLIR][MemRef] Nested allocation scope inlining
If a stack allocation is within a nested allocation scope don't count that as an allocation of the outer allocation scope that would prevent inlining. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121981
1 parent ec10ac7 commit d144cf5

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -311,25 +311,28 @@ struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
311311

312312
LogicalResult matchAndRewrite(AllocaScopeOp op,
313313
PatternRewriter &rewriter) const override {
314-
if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
315-
bool hasPotentialAlloca =
316-
op->walk([&](Operation *alloc) {
317-
if (alloc == op)
318-
return WalkResult::advance();
319-
if (isOpItselfPotentialAutomaticAllocation(alloc))
320-
return WalkResult::interrupt();
314+
bool hasPotentialAlloca =
315+
op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
316+
if (alloc == op)
321317
return WalkResult::advance();
322-
}).wasInterrupted();
323-
if (hasPotentialAlloca)
318+
if (isOpItselfPotentialAutomaticAllocation(alloc))
319+
return WalkResult::interrupt();
320+
if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
321+
return WalkResult::skip();
322+
return WalkResult::advance();
323+
}).wasInterrupted();
324+
325+
// If this contains no potential allocation, it is always legal to
326+
// inline. Otherwise, consider two conditions:
327+
if (hasPotentialAlloca) {
328+
// If the parent isn't an allocation scope, or we are not the last
329+
// non-terminator op in the parent, we will extend the lifetime.
330+
if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
331+
return failure();
332+
if (!lastNonTerminatorInRegion(op))
324333
return failure();
325334
}
326335

327-
// Only apply to if this is this last non-terminator
328-
// op in the block (lest lifetime be extended) of a one
329-
// block region
330-
if (!lastNonTerminatorInRegion(op))
331-
return failure();
332-
333336
Block *block = &op.getRegion().front();
334337
Operation *terminator = block->getTerminator();
335338
ValueRange results = terminator->getOperands();

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,32 @@ func @scopeMerge4() {
644644
// CHECK: return
645645
// CHECK: }
646646

647+
func @scopeMerge5() {
648+
"test.region"() ({
649+
memref.alloca_scope {
650+
affine.parallel (%arg) = (0) to (64) {
651+
%a = memref.alloca(%arg) : memref<?xi64>
652+
"test.use"(%a) : (memref<?xi64>) -> ()
653+
}
654+
}
655+
"test.op"() : () -> ()
656+
"test.terminator"() : () -> ()
657+
}) : () -> ()
658+
return
659+
}
660+
661+
// CHECK: func @scopeMerge5() {
662+
// CHECK: "test.region"() ({
663+
// CHECK: affine.parallel (%[[cnt:.+]]) = (0) to (64) {
664+
// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
665+
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
666+
// CHECK: }
667+
// CHECK: "test.op"() : () -> ()
668+
// CHECK: "test.terminator"() : () -> ()
669+
// CHECK: }) : () -> ()
670+
// CHECK: return
671+
// CHECK: }
672+
647673
func @scopeInline(%arg : memref<index>) {
648674
%cnt = "test.count"() : () -> index
649675
"test.region"() ({

0 commit comments

Comments
 (0)