Skip to content

Commit f30ec8f

Browse files
[mlir][linalg][bufferize][NFC] Allow passing custom BufferizationOptions to pass
Differential Revision: https://reviews.llvm.org/D118891
1 parent ef736a1 commit f30ec8f

File tree

7 files changed

+83
-79
lines changed

7 files changed

+83
-79
lines changed

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,9 @@ namespace comprehensive_bufferize {
2727
/// Run Module Bufferization on the given module. Performs a simple function
2828
/// call analysis to determine which function arguments are inplaceable. Then
2929
/// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize.
30-
LogicalResult runComprehensiveBufferize(
31-
ModuleOp moduleOp,
32-
std::unique_ptr<bufferization::AnalysisBufferizationOptions> options);
30+
LogicalResult
31+
runModuleBufferize(ModuleOp moduleOp,
32+
bufferization::AnalysisBufferizationOptions options);
3333

3434
namespace std_ext {
3535

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
#include "mlir/Pass/Pass.h"
1919

2020
namespace mlir {
21+
namespace bufferization {
22+
struct AnalysisBufferizationOptions;
23+
} // namespace bufferization
2124

2225
std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();
2326

@@ -64,8 +67,8 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
6467
/// on SSA use-def chains starting from function operands that are annotated
6568
/// with the 'inplaceable' attribute.
6669
std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
67-
std::unique_ptr<Pass>
68-
createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy);
70+
std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
71+
const bufferization::AnalysisBufferizationOptions &options);
6972

7073
/// Create a pass to convert Linalg operations which work on tensors to use
7174
/// buffers instead.

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,6 @@ def LinalgComprehensiveModuleBufferize :
5252
Option<"useAlloca", "use-alloca", "bool",
5353
/*default=*/"false",
5454
"Use stack allocations for memrefs (for testing purposes only)">,
55-
Option<"useLinalgCopy", "use-memref.copy", "bool",
56-
/*default=*/"false",
57-
"Use a copy operation implemented as a Linalg op.">,
5855
Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
5956
/*default=*/"true",
6057
"Generate MemRef types with dynamic offset+strides by default.">,

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,11 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
135135
Value outputTensor,
136136
ArrayRef<int64_t> transposeVector);
137137

138+
/// Returns GenericOp that copies an n-D memref. Unlike the current
139+
/// implementation of memref::CopyOp, this op can further tile, lower to loops
140+
/// or vectorize.
141+
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
142+
138143
//===----------------------------------------------------------------------===//
139144
// Fusion / Tiling utilities
140145
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,10 +10,10 @@
1010
// bufferizes function boundaries. It provides `BufferizableOpInterface`
1111
// implementations for FuncOp, CallOp and ReturnOp.
1212
//
13-
// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
14-
// This function analyzed the given module and determines the order of
15-
// analysis and bufferization: Functions that are called are processed before
16-
// their respective callers.
13+
// Module Bufferization is run via `runModuleBufferize(ModuleOp, ...)`. This
14+
// function analyzes the given module and determines the order of analysis and
15+
// bufferization: Functions that are called are processed before their
16+
// respective callers.
1717
//
1818
// After analyzing a FuncOp, additional information about its bbArgs is
1919
// gathered through PostAnalysisStepFns and stored in
@@ -971,10 +971,10 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
971971
setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
972972
}
973973

974-
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
975-
ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
974+
LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
975+
ModuleOp moduleOp, AnalysisBufferizationOptions options) {
976976
IRRewriter rewriter(moduleOp.getContext());
977-
AnalysisBufferizationState state(moduleOp, *options);
977+
AnalysisBufferizationState state(moduleOp, options);
978978
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
979979
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
980980

@@ -983,8 +983,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
983983
return failure();
984984

985985
// Collect bbArg/return value information after the analysis.
986-
options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis);
987-
options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis);
986+
options.addPostAnalysisStep(equivalentFuncOpBBArgsAnalysis);
987+
options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);
988988

989989
// Analyze ops.
990990
for (FuncOp funcOp : moduleState.orderedFuncOps) {
@@ -1007,11 +1007,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
10071007
moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
10081008

10091009
// Add annotations to function arguments.
1010-
if (options->testAnalysisOnly)
1010+
if (options.testAnalysisOnly)
10111011
annotateOpsWithBufferizationMarkers(funcOp, state);
10121012
}
10131013

1014-
if (options->testAnalysisOnly)
1014+
if (options.testAnalysisOnly)
10151015
return success();
10161016

10171017
// Bufferize function bodies.
@@ -1031,7 +1031,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
10311031
if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
10321032
return failure();
10331033

1034-
if (!options->allowReturnMemref &&
1034+
if (!options.allowReturnMemref &&
10351035
llvm::any_of(funcOp.getType().getResults(), [](Type t) {
10361036
return t.isa<MemRefType, UnrankedMemRefType>();
10371037
})) {

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 35 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,9 @@ struct LinalgComprehensiveModuleBufferize
3838
LinalgComprehensiveModuleBufferize(
3939
const LinalgComprehensiveModuleBufferize &p) = default;
4040

41-
LinalgComprehensiveModuleBufferize(bool linalgCopy) {
42-
this->useLinalgCopy = linalgCopy;
43-
}
41+
explicit LinalgComprehensiveModuleBufferize(
42+
AnalysisBufferizationOptions options)
43+
: options(options) {}
4444

4545
void runOnOperation() override;
4646

@@ -58,6 +58,9 @@ struct LinalgComprehensiveModuleBufferize
5858
tensor::registerBufferizableOpInterfaceExternalModels(registry);
5959
vector::registerBufferizableOpInterfaceExternalModels(registry);
6060
}
61+
62+
private:
63+
llvm::Optional<AnalysisBufferizationOptions> options;
6164
};
6265
} // namespace
6366

@@ -76,71 +79,44 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
7679
return allocated;
7780
}
7881

79-
/// Create a linalg::GenericOp version of an n-D copy that can further tile,
80-
/// lower to loops or vectorize, unlike the current implementation of
81-
/// memref::CopyOp.
82-
/// Do not depend on memref::CopyOp that is getting deprecated.
83-
static LogicalResult createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
84-
Value to) {
85-
auto memrefTypeFrom = from.getType().cast<MemRefType>();
86-
auto memrefTypeTo = to.getType().cast<MemRefType>();
87-
if (!memrefTypeFrom || !memrefTypeTo ||
88-
memrefTypeFrom.getRank() != memrefTypeTo.getRank())
89-
return failure();
90-
AffineMap id =
91-
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
92-
SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
93-
getParallelIteratorTypeName());
94-
b.create<linalg::GenericOp>(loc,
95-
/*inputs=*/from,
96-
/*outputs=*/to,
97-
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
98-
/*iteratorTypes=*/iteratorTypes,
99-
[](OpBuilder &b, Location loc, ValueRange args) {
100-
b.create<linalg::YieldOp>(loc, args.front());
101-
});
102-
return success();
103-
}
104-
10582
void LinalgComprehensiveModuleBufferize::runOnOperation() {
106-
auto options = std::make_unique<AnalysisBufferizationOptions>();
107-
if (useAlloca) {
108-
options->allocationFn = allocationFnUsingAlloca;
109-
options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
110-
return success();
111-
};
112-
}
113-
// TODO: atm memref::CopyOp can be 200x slower than linalg::GenericOp.
114-
// Once this perf bug is fixed more systematically, we can revisit.
115-
if (useLinalgCopy)
116-
options->memCpyFn = createLinalgCopyOp;
117-
118-
options->allowReturnMemref = allowReturnMemref;
119-
options->allowUnknownOps = allowUnknownOps;
120-
options->analysisFuzzerSeed = analysisFuzzerSeed;
121-
options->createDeallocs = createDeallocs;
122-
options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
123-
options->printConflicts = printConflicts;
124-
options->testAnalysisOnly = testAnalysisOnly;
125-
126-
// Enable InitTensorOp elimination.
127-
if (initTensorElimination) {
128-
options->addPostAnalysisStep(
129-
linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
83+
AnalysisBufferizationOptions opt;
84+
if (!options) {
85+
// Make new bufferization options if none were provided when creating the
86+
// pass.
87+
if (useAlloca) {
88+
opt.allocationFn = allocationFnUsingAlloca;
89+
opt.deallocationFn = [](OpBuilder &b, Location loc, Value v) {
90+
return success();
91+
};
92+
}
93+
opt.allowReturnMemref = allowReturnMemref;
94+
opt.allowUnknownOps = allowUnknownOps;
95+
opt.analysisFuzzerSeed = analysisFuzzerSeed;
96+
opt.createDeallocs = createDeallocs;
97+
opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
98+
opt.printConflicts = printConflicts;
99+
opt.testAnalysisOnly = testAnalysisOnly;
100+
if (initTensorElimination) {
101+
opt.addPostAnalysisStep(
102+
linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
103+
}
104+
} else {
105+
opt = *options;
130106
}
131107

132108
// Only certain scf.for ops are supported by the analysis.
133-
options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
109+
opt.addPostAnalysisStep(scf::assertScfForAliasingProperties);
134110

135111
ModuleOp moduleOp = getOperation();
136112
applyEnablingTransformations(moduleOp);
137113

138-
if (failed(runComprehensiveBufferize(moduleOp, std::move(options)))) {
114+
if (failed(runModuleBufferize(moduleOp, opt))) {
139115
signalPassFailure();
140116
return;
141117
}
142118

143-
if (testAnalysisOnly)
119+
if (opt.testAnalysisOnly)
144120
return;
145121

146122
OpPassManager cleanupPipeline("builtin.module");
@@ -154,7 +130,7 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
154130
return std::make_unique<LinalgComprehensiveModuleBufferize>();
155131
}
156132

157-
std::unique_ptr<Pass>
158-
mlir::createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy) {
159-
return std::make_unique<LinalgComprehensiveModuleBufferize>(useLinalgCopy);
133+
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
134+
const AnalysisBufferizationOptions &options) {
135+
return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
160136
}

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -423,6 +423,29 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
423423
return transposeOp;
424424
}
425425

426+
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
427+
auto memrefTypeTo = to.getType().cast<MemRefType>();
428+
#ifndef NDEBUG
429+
auto memrefTypeFrom = from.getType().cast<MemRefType>();
430+
assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
431+
"`from` and `to` memref must have the same rank");
432+
#endif // NDEBUG
433+
434+
AffineMap id =
435+
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
436+
SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
437+
getParallelIteratorTypeName());
438+
return b.create<linalg::GenericOp>(
439+
loc,
440+
/*inputs=*/from,
441+
/*outputs=*/to,
442+
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
443+
/*iteratorTypes=*/iteratorTypes,
444+
[](OpBuilder &b, Location loc, ValueRange args) {
445+
b.create<linalg::YieldOp>(loc, args.front());
446+
});
447+
}
448+
426449
/// Specialization to build an scf "for" nest.
427450
template <>
428451
void GenerateLoopNest<scf::ForOp>::doit(

0 commit comments

Comments
 (0)