@@ -42,8 +42,12 @@ constexpr const ::llvm::StringLiteral
42
42
constexpr const ::llvm::StringLiteral
43
43
bufferization::BufferizableOpInterface::kInplaceableAttrName ;
44
44
45
+ // / Attribute name used to mark allocs that are created by the bufferization.
45
46
static const char *kBufferAllocationAttr = " bufferization.allocation" ;
46
47
48
+ // / Attribute name used to mark allocs that should not be deallocated.
49
+ static const char *kSkipDeallocAttr = " bufferization.skip_dealloc" ;
50
+
47
51
// ===----------------------------------------------------------------------===//
48
52
// BufferizationOptions
49
53
// ===----------------------------------------------------------------------===//
@@ -253,6 +257,8 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
253
257
OpBuilder::InsertionGuard guard (rewriter);
254
258
Operation *op = opOperand.getOwner ();
255
259
Location loc = op->getLoc ();
260
+ SmallVector<OpResult> aliasingOpResults =
261
+ analysisState.getAliasingOpResult (opOperand);
256
262
Value operand = opOperand.get ();
257
263
Value operandBuffer = lookupBuffer (rewriter, operand, options);
258
264
@@ -263,8 +269,13 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
263
269
// Move insertion point right after `operandBuffer`. That is where the
264
270
// allocation should be inserted (in the absence of allocation hoisting).
265
271
setInsertionPointAfter (rewriter, operandBuffer);
266
- // Allocate the result buffer.
267
- FailureOr<Value> resultBuffer = createAlloc (rewriter, loc, operandBuffer);
272
+ // Allocate the result buffer. The buffer should be deallocated if the tensor
273
+ // is not yielded and deallocs are enabled in general.
274
+ bool dealloc = llvm::none_of (aliasingOpResults, [&](Value v) {
275
+ return getAnalysisState ().isTensorYielded (v);
276
+ });
277
+ FailureOr<Value> resultBuffer = createAlloc (
278
+ rewriter, loc, operandBuffer, dealloc && getOptions ().createDeallocs );
268
279
if (failed (resultBuffer))
269
280
return failure ();
270
281
// Do not copy if the last preceding writes of `operand` are ops that do
@@ -281,8 +292,6 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
281
292
}))
282
293
return resultBuffer;
283
294
// Do not copy if the copied data is never read.
284
- SmallVector<OpResult> aliasingOpResults =
285
- analysisState.getAliasingOpResult (opOperand);
286
295
if (!aliasingOpResults.empty () &&
287
296
!analysisState.bufferizesToMemoryRead (opOperand) &&
288
297
llvm::none_of (aliasingOpResults, [&](OpResult opResult) {
@@ -339,7 +348,12 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
339
348
340
349
AlwaysCopyAnalysisState::AlwaysCopyAnalysisState (
341
350
const BufferizationOptions &options)
342
- : AnalysisState(options) {}
351
+ : AnalysisState(options) {
352
+ // Note: Allocations must be deallocated with a subsequent run of the buffer
353
+ // deallocation pass.
354
+ assert (!options.createDeallocs &&
355
+ " cannot create deallocs with AlwaysCopyBufferizationState" );
356
+ }
343
357
344
358
// / Return `true` if the given OpResult has been decided to bufferize inplace.
345
359
bool AlwaysCopyAnalysisState::isInPlace (OpOperand &opOperand) const {
@@ -356,6 +370,13 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
356
370
return false ;
357
371
}
358
372
373
+ // / Return true if the given tensor (or an aliasing tensor) is yielded from
374
+ // / the containing block. Also include all aliasing tensors in the same block.
375
+ bool AlwaysCopyAnalysisState::isTensorYielded (Value tensor) const {
376
+ // There is no analysis, so conservatively answer "true".
377
+ return true ;
378
+ }
379
+
359
380
// ===----------------------------------------------------------------------===//
360
381
// Bufferization-specific scoped alloc/dealloc insertion support.
361
382
// ===----------------------------------------------------------------------===//
@@ -426,37 +447,54 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
426
447
}
427
448
428
449
static Value createBufferAllocation (OpBuilder &b, Location loc, MemRefType type,
429
- ValueRange dynShape) {
450
+ ValueRange dynShape, bool skipDealloc ) {
430
451
auto allocaOp = b.create <memref::AllocaOp>(loc, type, dynShape);
431
452
allocaOp->setAttr (kBufferAllocationAttr , b.getUnitAttr ());
453
+ if (skipDealloc)
454
+ allocaOp->setAttr (kSkipDeallocAttr , b.getUnitAttr ());
432
455
return allocaOp.getResult ();
433
456
}
434
457
435
458
// / Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
436
459
// / block in case of a bbArg).
437
460
FailureOr<Value> BufferizationState::createAlloc (OpBuilder &b, Location loc,
438
- Value shapedValue) {
461
+ Value shapedValue,
462
+ Optional<bool > dealloc) {
439
463
// Take a guard before anything else.
440
464
OpBuilder::InsertionGuard g (b);
465
+
466
+ // Compute allocation memref type.
441
467
assert (shapedValue.getType ().isa <ShapedType>());
442
468
MemRefType memRefType = shapedValue.getType ().dyn_cast <MemRefType>();
443
469
SmallVector<Value> dynShape;
444
470
MemRefType allocMemRefType =
445
471
getAllocationTypeAndShape (b, loc, shapedValue, dynShape);
446
- Value alloc = createBufferAllocation (b, loc, allocMemRefType, dynShape);
472
+
473
+ // Should be the buffer be deallocated again or should we let it leak?
474
+ bool skipDealloc;
475
+ if (dealloc) {
476
+ skipDealloc = !dealloc.getValue ();
477
+ } else {
478
+ assert (shapedValue.getType ().isa <TensorType>() &&
479
+ " must specify `dealloc` if non-tensor value is passed" );
480
+ // Buffer should be not be deallocated if deallocs are generally deactivated
481
+ // or if the tensor is yielded from a block.
482
+ skipDealloc = !getOptions ().createDeallocs ||
483
+ getAnalysisState ().isTensorYielded (shapedValue);
484
+ }
485
+
486
+ // Create the buffer allocation.
487
+ Value alloc =
488
+ createBufferAllocation (b, loc, allocMemRefType, dynShape, skipDealloc);
489
+
490
+ // Insert a cast if a different type was requested.
447
491
if (memRefType && memRefType != allocMemRefType) {
448
- assert (memref::CastOp::areCastCompatible (alloc. getType () , memRefType) &&
492
+ assert (memref::CastOp::areCastCompatible (allocMemRefType , memRefType) &&
449
493
" createAlloc: cast incompatible" );
450
494
alloc = b.create <memref::CastOp>(loc, memRefType, alloc);
451
495
}
452
- return alloc;
453
- }
454
496
455
- // / Create a memref allocation with the given type and dynamic extents.
456
- FailureOr<Value> BufferizationState::createAlloc (OpBuilder &b, Location loc,
457
- MemRefType type,
458
- ValueRange dynShape) {
459
- return createBufferAllocation (b, loc, type, dynShape);
497
+ return alloc;
460
498
}
461
499
462
500
// / Create a memory copy between two memref buffers.
@@ -480,7 +518,9 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
480
518
// Ignore memref.alloca ops that were not created by the bufferization.
481
519
if (!allocaOp->hasAttr (kBufferAllocationAttr ))
482
520
return WalkResult::skip ();
521
+ bool skipDealloc = allocaOp->hasAttr (kSkipDeallocAttr );
483
522
523
+ // Create alloc.
484
524
Block *block = allocaOp->getBlock ();
485
525
rewriter.setInsertionPoint (allocaOp);
486
526
FailureOr<Value> alloc =
@@ -490,10 +530,11 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
490
530
return WalkResult::interrupt ();
491
531
rewriter.replaceOp (allocaOp, *alloc);
492
532
493
- // Stop here if deallocations are deactivated .
494
- if (!options. createDeallocs )
533
+ // Stop here if the buffer should not be deallocated .
534
+ if (skipDealloc )
495
535
return WalkResult::advance ();
496
536
537
+ // Create dealloc.
497
538
rewriter.setInsertionPoint (block->getTerminator ());
498
539
if (failed (createDealloc (rewriter, alloc->getLoc (), *alloc, options)))
499
540
return WalkResult::interrupt ();
0 commit comments