Skip to content

Commit 9e24f0f

Browse files
[mlir][bufferize] Do not deallocate allocs that are returned from a block
Such IR is rejected by default, but can be allowed with `allow-return-memref`. In preparation of future refactorings, do not deallocate such buffers. One-Shot Analysis now gathers information about yielded tensors, so that we know during the actual bufferization whether a newly allocated buffer should be deallocated again. (Otherwise, it will leak. This will be addressed in a subsequent commit that also makes `allow-return-memref` a non-experimental flag.) As a cleanup, `allow-return-memref` is now part of OneShotBufferizationOptions. (It was previously ignored by AlwaysCopyBufferizationState.) Moreover, AlwaysCopyBufferizationState now asserts that `create-deallocs` is deactivated to prevent surprising behavior. Differential Revision: https://reviews.llvm.org/D121521
1 parent fdb41a2 commit 9e24f0f

File tree

8 files changed

+207
-40
lines changed

8 files changed

+207
-40
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ struct BufferizationOptions {
177177
Optional<DeallocationFn> deallocationFn;
178178
Optional<MemCpyFn> memCpyFn;
179179

180-
/// Specifies whether returning newly allocated memrefs should be allowed.
181-
/// Otherwise, a pass failure is triggered.
182-
bool allowReturnMemref = false;
183-
184180
/// Specifies whether not bufferizable ops are allowed in the input. If so,
185181
/// bufferization.to_memref and bufferization.to_tensor ops are inserted at
186182
/// the boundaries.
@@ -356,7 +352,14 @@ class AnalysisState {
356352
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
357353
virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
358354

359-
/// Return dialect-specific analysis state.
355+
/// Return true if the given tensor (or an aliasing tensor) is yielded from
356+
/// the containing block. Also include all aliasing tensors in the same block.
357+
///
358+
/// Note: In the absence of an analysis, an implementation may return true for
359+
/// any given tensor.
360+
virtual bool isTensorYielded(Value tensor) const = 0;
361+
362+
/// Return dialect-specific bufferization state.
360363
template <typename StateT>
361364
Optional<const StateT *> getDialectState(StringRef name) const {
362365
auto it = dialectState.find(name);
@@ -415,6 +418,10 @@ class AlwaysCopyAnalysisState : public AnalysisState {
415418

416419
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
417420
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
421+
422+
/// Return true if the given tensor (or an aliasing tensor) is yielded from
423+
/// the containing block. Also include all aliasing tensors in the same block.
424+
bool isTensorYielded(Value tensor) const override;
418425
};
419426

420427
/// BufferizationState provides helper functions for performing bufferization
@@ -423,14 +430,20 @@ struct BufferizationState {
423430
BufferizationState(const AnalysisState &analysisState)
424431
: analysisState(analysisState) {}
425432

426-
/// Creates a memref allocation with the given type and dynamic extents.
427-
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
428-
ValueRange dynShape);
429-
430-
/// Creates a memref allocation for the given shaped value. This function may
431-
/// perform additional optimizations such as buffer allocation hoisting.
432-
// TODO: Allocation hoisting should be a cleanup pass.
433-
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
433+
/// Creates a memref allocation for the given shaped value. `dealloc`
434+
/// indicates whether the buffer should be deallocated or not. When `dealloc`
435+
/// is `false`, this would create a memory leak, unless the buffer is
436+
/// deallocated through some other mechanism.
437+
///
438+
/// `dealloc` is optional. By default, this function will figure out by itself
439+
/// if it is safe to deallocate the buffer. In essence, when returning the
440+
/// buffer from a block, it is not safe to deallocate the buffer. This
441+
/// information is queried via `AnalysisState::isTensorYielded`.
442+
///
443+
/// Note: `shapedValue` is typically a tensor value. However, if it is a
444+
/// memref value, `dealloc` is no longer optional and must be specified.
445+
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
446+
Optional<bool> dealloc = None);
434447

435448
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
436449
/// a new buffer and copy over data from the existing buffer if out-of-place

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
4343

4444
/// Registered post analysis steps.
4545
PostAnalysisStepList postAnalysisSteps;
46+
47+
/// Specifies whether returning newly allocated memrefs should be allowed.
48+
/// Otherwise, a pass failure is triggered.
49+
bool allowReturnMemref = false;
4650
};
4751

4852
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
@@ -153,10 +157,22 @@ class OneShotAnalysisState : public AnalysisState {
153157
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
154158
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
155159

160+
/// Return true if the given tensor (or an aliasing tensor) is yielded from
161+
/// the containing block. Also include all aliasing tensors in the same block.
162+
bool isTensorYielded(Value tensor) const override;
163+
164+
/// Find all tensors that are yielded/returned from a block and store them in
165+
/// `yieldedTensors`. Also include all aliasing tensors in the same block.
166+
void gatherYieldedTensors(Operation *op);
167+
156168
private:
157169
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
158170
/// functions and `runOneShotBufferize` may access this object.
159171
BufferizationAliasInfo aliasInfo;
172+
173+
/// A set of all tensors (and maybe aliasing tensors) that are yielded from a
174+
/// block.
175+
DenseSet<Value> yieldedTensors;
160176
};
161177

162178
/// Analyze `op` and its nested ops. Bufferization decisions are stored in

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,12 @@ constexpr const ::llvm::StringLiteral
4242
constexpr const ::llvm::StringLiteral
4343
bufferization::BufferizableOpInterface::kInplaceableAttrName;
4444

45+
/// Attribute name used to mark allocs that are created by the bufferization.
4546
static const char *kBufferAllocationAttr = "bufferization.allocation";
4647

48+
/// Attribute name used to mark allocs that should not be deallocated.
49+
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
50+
4751
//===----------------------------------------------------------------------===//
4852
// BufferizationOptions
4953
//===----------------------------------------------------------------------===//
@@ -253,6 +257,8 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
253257
OpBuilder::InsertionGuard guard(rewriter);
254258
Operation *op = opOperand.getOwner();
255259
Location loc = op->getLoc();
260+
SmallVector<OpResult> aliasingOpResults =
261+
analysisState.getAliasingOpResult(opOperand);
256262
Value operand = opOperand.get();
257263
Value operandBuffer = lookupBuffer(rewriter, operand, options);
258264

@@ -263,8 +269,13 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
263269
// Move insertion point right after `operandBuffer`. That is where the
264270
// allocation should be inserted (in the absence of allocation hoisting).
265271
setInsertionPointAfter(rewriter, operandBuffer);
266-
// Allocate the result buffer.
267-
FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
272+
// Allocate the result buffer. The buffer should be deallocated if the tensor
273+
// is not yielded and deallocs are enabled in general.
274+
bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
275+
return getAnalysisState().isTensorYielded(v);
276+
});
277+
FailureOr<Value> resultBuffer = createAlloc(
278+
rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
268279
if (failed(resultBuffer))
269280
return failure();
270281
// Do not copy if the last preceding writes of `operand` are ops that do
@@ -281,8 +292,6 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
281292
}))
282293
return resultBuffer;
283294
// Do not copy if the copied data is never read.
284-
SmallVector<OpResult> aliasingOpResults =
285-
analysisState.getAliasingOpResult(opOperand);
286295
if (!aliasingOpResults.empty() &&
287296
!analysisState.bufferizesToMemoryRead(opOperand) &&
288297
llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
@@ -339,7 +348,12 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
339348

340349
AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
341350
const BufferizationOptions &options)
342-
: AnalysisState(options) {}
351+
: AnalysisState(options) {
352+
// Note: Allocations must be deallocated with a subsequent run of the buffer
353+
// deallocation pass.
354+
assert(!options.createDeallocs &&
355+
"cannot create deallocs with AlwaysCopyBufferizationState");
356+
}
343357

344358
/// Return `true` if the given OpOperand has been decided to bufferize inplace.
345359
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
@@ -356,6 +370,13 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
356370
return false;
357371
}
358372

373+
/// Return true if the given tensor (or an aliasing tensor) is yielded from
374+
/// the containing block. Also include all aliasing tensors in the same block.
375+
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
376+
// There is no analysis, so conservatively answer "true".
377+
return true;
378+
}
379+
359380
//===----------------------------------------------------------------------===//
360381
// Bufferization-specific scoped alloc/dealloc insertion support.
361382
//===----------------------------------------------------------------------===//
@@ -426,37 +447,54 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
426447
}
427448

428449
static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
429-
ValueRange dynShape) {
450+
ValueRange dynShape, bool skipDealloc) {
430451
auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
431452
allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
453+
if (skipDealloc)
454+
allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
432455
return allocaOp.getResult();
433456
}
434457

435458
/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
436459
/// block in case of a bbArg).
437460
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
438-
Value shapedValue) {
461+
Value shapedValue,
462+
Optional<bool> dealloc) {
439463
// Take a guard before anything else.
440464
OpBuilder::InsertionGuard g(b);
465+
466+
// Compute allocation memref type.
441467
assert(shapedValue.getType().isa<ShapedType>());
442468
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
443469
SmallVector<Value> dynShape;
444470
MemRefType allocMemRefType =
445471
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
446-
Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
472+
473+
// Should the buffer be deallocated again or should we let it leak?
474+
bool skipDealloc;
475+
if (dealloc) {
476+
skipDealloc = !dealloc.getValue();
477+
} else {
478+
assert(shapedValue.getType().isa<TensorType>() &&
479+
"must specify `dealloc` if non-tensor value is passed");
480+
// Buffer should not be deallocated if deallocs are generally deactivated
481+
// or if the tensor is yielded from a block.
482+
skipDealloc = !getOptions().createDeallocs ||
483+
getAnalysisState().isTensorYielded(shapedValue);
484+
}
485+
486+
// Create the buffer allocation.
487+
Value alloc =
488+
createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
489+
490+
// Insert a cast if a different type was requested.
447491
if (memRefType && memRefType != allocMemRefType) {
448-
assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
492+
assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
449493
"createAlloc: cast incompatible");
450494
alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
451495
}
452-
return alloc;
453-
}
454496

455-
/// Create a memref allocation with the given type and dynamic extents.
456-
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
457-
MemRefType type,
458-
ValueRange dynShape) {
459-
return createBufferAllocation(b, loc, type, dynShape);
497+
return alloc;
460498
}
461499

462500
/// Create a memory copy between two memref buffers.
@@ -480,7 +518,9 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
480518
// Ignore memref.alloca ops that were not created by the bufferization.
481519
if (!allocaOp->hasAttr(kBufferAllocationAttr))
482520
return WalkResult::skip();
521+
bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
483522

523+
// Create alloc.
484524
Block *block = allocaOp->getBlock();
485525
rewriter.setInsertionPoint(allocaOp);
486526
FailureOr<Value> alloc =
@@ -490,10 +530,11 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
490530
return WalkResult::interrupt();
491531
rewriter.replaceOp(allocaOp, *alloc);
492532

493-
// Stop here if deallocations are deactivated.
494-
if (!options.createDeallocs)
533+
// Stop here if the buffer should not be deallocated.
534+
if (skipDealloc)
495535
return WalkResult::advance();
496536

537+
// Create dealloc.
497538
rewriter.setInsertionPoint(block->getTerminator());
498539
if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
499540
return WalkResult::interrupt();

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
379379

380380
BufferizationOptions bufferization::getPartialBufferizationOptions() {
381381
BufferizationOptions options;
382-
options.allowReturnMemref = true;
383382
options.allowUnknownOps = true;
384383
options.createDeallocs = false;
385384
options.fullyDynamicLayoutMaps = false;

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,43 @@ bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
215215
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
216216
}
217217

218+
// Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
219+
// to ensure that such information is available during bufferization time.
220+
// Alias information can no longer be queried through BufferizationAliasInfo
221+
// once we have started modifying the IR.
222+
void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
223+
op->walk([&](Operation *returnOp) {
224+
if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
225+
return WalkResult::advance();
226+
227+
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
228+
Value returnVal = returnValOperand.get();
229+
// Skip non-tensor values.
230+
if (!returnVal.getType().isa<TensorType>())
231+
continue;
232+
233+
// Add all aliases of the returned value. But only the ones that are in
234+
// the same block.
235+
aliasInfo.applyOnAliases(returnVal, [&](Value v) {
236+
if (auto bbArg = v.dyn_cast<BlockArgument>()) {
237+
if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
238+
yieldedTensors.insert(bbArg);
239+
return;
240+
}
241+
Operation *definingOp = v.getDefiningOp();
242+
if (definingOp->getParentOp() == returnOp->getParentOp())
243+
yieldedTensors.insert(v);
244+
});
245+
}
246+
247+
return WalkResult::advance();
248+
});
249+
}
250+
251+
bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
252+
return yieldedTensors.contains(tensor);
253+
}
254+
218255
//===----------------------------------------------------------------------===//
219256
// Bufferization-specific alias analysis.
220257
//===----------------------------------------------------------------------===//
@@ -780,6 +817,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
780817
failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
781818
}
782819

820+
// Gather all yielded tensors.
821+
state.gatherYieldedTensors(op);
822+
783823
// Analysis verification: After setting up alias/equivalence sets, each op
784824
// can check for expected invariants/limitations and fail the analysis if
785825
// necessary.

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,8 @@ struct FromElementsOpInterface
335335
Location loc = op->getLoc();
336336
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
337337
auto shape = tensorType.getShape();
338-
MemRefType resultType = getContiguousMemRefType(tensorType);
339338
FailureOr<Value> maybeBuffer =
340-
state.createAlloc(rewriter, loc, resultType, {});
339+
state.createAlloc(rewriter, loc, fromElementsOp.result());
341340
if (failed(maybeBuffer))
342341
return failure();
343342
Value buffer = *maybeBuffer;
@@ -386,8 +385,8 @@ struct GenerateOpInterface
386385
Location loc = op->getLoc();
387386
MemRefType memrefType =
388387
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
389-
FailureOr<Value> maybeResult = state.createAlloc(
390-
rewriter, loc, memrefType, generateOp.dynamicExtents());
388+
FailureOr<Value> maybeResult =
389+
state.createAlloc(rewriter, loc, generateOp.result());
391390
if (failed(maybeResult))
392391
return failure();
393392
Value result = *maybeResult;

0 commit comments

Comments
 (0)