Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit da4098b

Browse files
River707tensorflower-gardener
authored andcommitted
Change the notifyRootUpdated API to be transaction based.
This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation). PiperOrigin-RevId: 286933674
1 parent 16b138a commit da4098b

File tree

10 files changed

+199
-87
lines changed

10 files changed

+199
-87
lines changed

include/mlir/IR/BlockSupport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class SuccessorRange final
6161
public:
6262
using RangeBaseT::RangeBaseT;
6363
SuccessorRange(Block *block);
64+
SuccessorRange(Operation *term);
6465

6566
private:
6667
/// See `detail::indexed_accessor_range_base` for details.

include/mlir/IR/Operation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,12 @@ class Operation final
385385
return {getTrailingObjects<BlockOperand>(), numSuccs};
386386
}
387387

388+
// Successor iteration.
389+
using succ_iterator = SuccessorRange::iterator;
390+
succ_iterator successor_begin() { return getSuccessors().begin(); }
391+
succ_iterator successor_end() { return getSuccessors().end(); }
392+
SuccessorRange getSuccessors() { return SuccessorRange(this); }
393+
388394
/// Return the operands of this operation that are *not* successor arguments.
389395
operand_range getNonSuccessorOperands();
390396

include/mlir/IR/PatternMatch.h

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -361,15 +361,31 @@ class PatternRewriter : public OpBuilder {
361361
/// block into a new block, and return it.
362362
virtual Block *splitBlock(Block *block, Block::iterator before);
363363

364-
/// This method is used as the final notification hook for patterns that end
365-
/// up modifying the pattern root in place, by changing its operands. This is
366-
/// a minor efficiency win (it avoids creating a new operation and removing
367-
/// the old one) but also often allows simpler code in the client.
368-
///
369-
/// The valuesToRemoveIfDead list is an optional list of values that the
370-
/// rewriter should remove if they are dead at this point.
371-
///
372-
void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {});
364+
/// This method is used to notify the rewriter that an in-place operation
365+
/// modification is about to happen. A call to this function *must* be
366+
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
367+
/// This is a minor efficiency win (it avoids creating a new operation and
368+
/// removing the old one) but also often allows simpler code in the client.
369+
virtual void startRootUpdate(Operation *op) {}
370+
371+
/// This method is used to signal the end of a root update on the given
372+
/// operation. This can only be called on operations that were provided to a
373+
/// call to `startRootUpdate`.
374+
virtual void finalizeRootUpdate(Operation *op) {}
375+
376+
/// This method cancels a pending root update. This can only be called on
377+
/// operations that were provided to a call to `startRootUpdate`.
378+
virtual void cancelRootUpdate(Operation *op) {}
379+
380+
/// This method is a utility wrapper around a root update of an operation. It
381+
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
382+
/// callable.
383+
template <typename CallableT>
384+
void updateRootInPlace(Operation *root, CallableT &&callable) {
385+
startRootUpdate(root);
386+
callable();
387+
finalizeRootUpdate(root);
388+
}
373389

374390
protected:
375391
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
@@ -378,10 +394,6 @@ class PatternRewriter : public OpBuilder {
378394
// These are the callback methods that subclasses can choose to implement if
379395
// they would like to be notified about certain types of mutations.
380396

381-
/// Notify the pattern rewriter that the specified operation has been mutated
382-
/// in place. This is called after the mutation is done.
383-
virtual void notifyRootUpdated(Operation *op) {}
384-
385397
/// Notify the pattern rewriter that the specified operation is about to be
386398
/// replaced with another set of operations. This is called before the uses
387399
/// of the operation have been changed.

include/mlir/Transforms/DialectConversion.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,16 @@ class ConversionPatternRewriter final : public PatternRewriter {
365365
Operation *insert(Operation *op) override;
366366

367367
/// PatternRewriter hook for updating the root operation in-place.
368-
void notifyRootUpdated(Operation *op) override;
368+
/// Note: These methods only track updates to the top-level operation itself,
369+
/// and not nested regions. Updates to regions will still require notification
370+
/// through other more specific hooks above.
371+
void startRootUpdate(Operation *op) override;
372+
373+
/// PatternRewriter hook for updating the root operation in-place.
374+
void finalizeRootUpdate(Operation *op) override;
375+
376+
/// PatternRewriter hook for updating the root operation in-place.
377+
void cancelRootUpdate(Operation *op) override;
369378

370379
/// Return a reference to the internal implementation.
371380
detail::ConversionPatternRewriterImpl &getImpl();

lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,12 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
5454
signatureConverter.addInputs(argType.index(), convertedType);
5555
}
5656
}
57-
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
58-
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
59-
newFuncOp.end());
60-
newFuncOp.setType(rewriter.getFunctionType(
61-
signatureConverter.getConvertedTypes(), llvm::None));
62-
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
63-
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
57+
58+
rewriter.updateRootInPlace(funcOp, [&] {
59+
funcOp.setType(rewriter.getFunctionType(
60+
signatureConverter.getConvertedTypes(), llvm::None));
61+
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
62+
});
6463
return matchSuccess();
6564
}
6665

lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -472,39 +472,38 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
472472

473473
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
474474
PatternRewriter &rewriter) const override {
475-
auto origInsertionPoint = rewriter.saveInsertionPoint();
475+
rewriter.startRootUpdate(launchOp);
476+
PatternRewriter::InsertionGuard guard(rewriter);
476477
rewriter.setInsertionPointToStart(&launchOp.body().front());
477478

478479
// Traverse operands passed to kernel and check if some of them are known
479480
// constants. If so, clone the constant operation inside the kernel region
480481
// and use it instead of passing the value from the parent region. Perform
481482
// the traversal in the inverse order to simplify index arithmetics when
482483
// dropping arguments.
483-
SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(),
484-
launchOp.getKernelOperandValues().end());
485-
SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(),
486-
launchOp.getKernelArguments().end());
484+
auto operands = launchOp.getKernelOperandValues();
485+
auto kernelArgs = launchOp.getKernelArguments();
487486
bool found = false;
488487
for (unsigned i = operands.size(); i > 0; --i) {
489488
unsigned index = i - 1;
490-
ValuePtr operand = operands[index];
491-
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
489+
Value operand = operands[index];
490+
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
492491
continue;
493-
}
494492

495493
found = true;
496-
ValuePtr internalConstant =
494+
Value internalConstant =
497495
rewriter.clone(*operand->getDefiningOp())->getResult(0);
498-
ValuePtr kernelArg = kernelArgs[index];
496+
Value kernelArg = *std::next(kernelArgs.begin(), index);
499497
kernelArg->replaceAllUsesWith(internalConstant);
500498
launchOp.eraseKernelArgument(index);
501499
}
502-
rewriter.restoreInsertionPoint(origInsertionPoint);
503500

504-
if (!found)
501+
if (!found) {
502+
rewriter.cancelRootUpdate(launchOp);
505503
return matchFailure();
504+
}
506505

507-
rewriter.updatedRootInPlace(launchOp);
506+
rewriter.finalizeRootUpdate(launchOp);
508507
return matchSuccess();
509508
}
510509
};

lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,11 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
197197
}
198198

199199
// Creates a new function with the update signature.
200-
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
201-
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
202-
newFuncOp.end());
203-
newFuncOp.setType(rewriter.getFunctionType(
204-
signatureConverter.getConvertedTypes(), llvm::None));
205-
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
206-
rewriter.eraseOp(funcOp.getOperation());
200+
rewriter.updateRootInPlace(funcOp, [&] {
201+
funcOp.setType(rewriter.getFunctionType(
202+
signatureConverter.getConvertedTypes(), llvm::None));
203+
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
204+
});
207205
return matchSuccess();
208206
}
209207

lib/IR/Block.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,8 @@ SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
267267
if ((count = term->getNumSuccessors()))
268268
base = term->getBlockOperands().data();
269269
}
270+
271+
SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
272+
if ((count = term->getNumSuccessors()))
273+
base = term->getBlockOperands().data();
274+
}

lib/IR/PatternMatch.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -170,23 +170,6 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
170170
cloneRegionBefore(region, *before->getParent(), before->getIterator());
171171
}
172172

173-
/// This method is used as the final notification hook for patterns that end
174-
/// up modifying the pattern root in place, by changing its operands. This is
175-
/// a minor efficiency win (it avoids creating a new operation and removing
176-
/// the old one) but also often allows simpler code in the client.
177-
///
178-
/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
179-
/// should remove if they are dead at this point.
180-
///
181-
void PatternRewriter::updatedRootInPlace(Operation *op,
182-
ValueRange valuesToRemoveIfDead) {
183-
// Notify the rewriter subclass that we're about to replace this root.
184-
notifyRootUpdated(op);
185-
186-
// TODO: Process the valuesToRemoveIfDead list, removing things and calling
187-
// the notifyOperationRemoved hook in the process.
188-
}
189-
190173
//===----------------------------------------------------------------------===//
191174
// PatternMatcher implementation
192175
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)