Skip to content

Commit 4002eaa

Browse files
[mlir][bufferize] Improve analysis of external functions
External functions have no body, so they cannot be analyzed. Assume conservatively that each tensor bbArg may be aliasing with each tensor result. Furthermore, assume that each function arg is read and written-to after bufferization. This default behavior can be controlled with `bufferization.access` (similar to `bufferization.memory_layout`) in test cases. Also fix a bug in the dialect attribute verifier, which did not run for region argument attributes. Differential Revision: https://reviews.llvm.org/D139517
1 parent 66692c8 commit 4002eaa

File tree

7 files changed

+186
-37
lines changed

7 files changed

+186
-37
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,29 @@ def Bufferization_Dialect : Dialect {
3030
];
3131

3232
let extraClassDeclaration = [{
33+
/// Verify an attribute from this dialect on the argument at 'argIndex' for
34+
/// the region at 'regionIndex' on the given operation. Returns failure if
35+
/// the verification failed, success otherwise. This hook may optionally be
36+
/// invoked from any operation containing a region.
37+
LogicalResult verifyRegionArgAttribute(Operation *,
38+
unsigned regionIndex,
39+
unsigned argIndex,
40+
NamedAttribute) override;
41+
3342
/// An attribute that can override writability of buffers of tensor function
3443
/// arguments during One-Shot Module Bufferize.
3544
constexpr const static ::llvm::StringLiteral
3645
kWritableAttrName = "bufferization.writable";
3746

47+
/// An attribute for function arguments that describes how the function
48+
/// accesses the buffer. Can be one "none", "read", "write" or "read-write".
49+
///
50+
/// When no attribute is specified, the analysis tries to infer the access
51+
/// behavior from its body. In case of external functions, for which no
52+
/// function body is available, "read-write" is assumed by default.
53+
constexpr const static ::llvm::StringLiteral
54+
kBufferAccessAttrName = "bufferization.access";
55+
3856
/// Attribute name used to mark the bufferization layout for region
3957
/// arguments during One-Shot Module Bufferize.
4058
constexpr const static ::llvm::StringLiteral

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,34 @@ void mlir::bufferization::BufferizationDialect::initialize() {
5959
addInterfaces<BufferizationInlinerInterface>();
6060
}
6161

62-
LogicalResult
63-
BufferizationDialect::verifyOperationAttribute(Operation *op,
64-
NamedAttribute attr) {
65-
using bufferization::BufferizableOpInterface;
66-
62+
LogicalResult BufferizationDialect::verifyRegionArgAttribute(
63+
Operation *op, unsigned /*regionIndex*/, unsigned argIndex,
64+
NamedAttribute attr) {
6765
if (attr.getName() == kWritableAttrName) {
6866
if (!attr.getValue().isa<BoolAttr>()) {
6967
return op->emitError() << "'" << kWritableAttrName
7068
<< "' is expected to be a boolean attribute";
7169
}
7270
if (!isa<FunctionOpInterface>(op))
73-
return op->emitError() << "expected " << attr.getName()
74-
<< " to be used on function-like operations";
71+
return op->emitError() << "expected '" << kWritableAttrName
72+
<< "' to be used on function-like operations";
73+
if (cast<FunctionOpInterface>(op).isExternal())
74+
return op->emitError() << "'" << kWritableAttrName
75+
<< "' is invalid on external functions";
76+
return success();
77+
}
78+
if (attr.getName() == kBufferAccessAttrName) {
79+
if (!attr.getValue().isa<StringAttr>()) {
80+
return op->emitError() << "'" << kBufferAccessAttrName
81+
<< "' is expected to be a string attribute";
82+
}
83+
StringRef str = attr.getValue().cast<StringAttr>().getValue();
84+
if (str != "none" && str != "read" && str != "write" && str != "read-write")
85+
return op->emitError()
86+
<< "invalid value for '" << kBufferAccessAttrName << "'";
87+
if (!isa<FunctionOpInterface>(op))
88+
return op->emitError() << "expected '" << kBufferAccessAttrName
89+
<< "' to be used on function-like operations";
7590
return success();
7691
}
7792
if (attr.getName() == kBufferLayoutAttrName) {
@@ -80,10 +95,20 @@ BufferizationDialect::verifyOperationAttribute(Operation *op,
8095
<< "' is expected to be a affine map attribute";
8196
}
8297
if (!isa<FunctionOpInterface>(op))
83-
return op->emitError() << "expected " << attr.getName()
84-
<< " to be used on function-like operations";
98+
return op->emitError() << "expected '" << kBufferLayoutAttrName
99+
<< "' to be used on function-like operations";
85100
return success();
86101
}
102+
return op->emitError() << "attribute '" << kBufferLayoutAttrName
103+
<< "' not supported as a region arg attribute by the "
104+
"bufferization dialect";
105+
}
106+
107+
LogicalResult
108+
BufferizationDialect::verifyOperationAttribute(Operation *op,
109+
NamedAttribute attr) {
110+
using bufferization::BufferizableOpInterface;
111+
87112
if (attr.getName() == kEscapeAttrName) {
88113
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
89114
if (!arrayAttr)
@@ -116,6 +141,7 @@ BufferizationDialect::verifyOperationAttribute(Operation *op,
116141
return success();
117142
}
118143

119-
return op->emitError() << "attribute '" << attr.getName()
120-
<< "' not supported by the bufferization dialect";
144+
return op->emitError()
145+
<< "attribute '" << attr.getName()
146+
<< "' not supported as an op attribute by the bufferization dialect";
121147
}

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,25 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
127127
static LogicalResult
128128
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
129129
FuncAnalysisState &funcState) {
130+
if (funcOp.getBody().empty()) {
131+
// No function body available. Conservatively assume that every tensor
132+
// return value may alias with any tensor bbArg.
133+
FunctionType type = funcOp.getFunctionType();
134+
for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
135+
if (!inputIt.value().isa<TensorType>())
136+
continue;
137+
for (const auto &resultIt : llvm::enumerate(type.getResults())) {
138+
if (!resultIt.value().isa<TensorType>())
139+
continue;
140+
int64_t returnIdx = resultIt.index();
141+
int64_t bbArgIdx = inputIt.index();
142+
funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
143+
funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
144+
}
145+
}
146+
return success();
147+
}
148+
130149
// Support only single return-terminated block in the function.
131150
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
132151
assert(returnOp && "expected func with single return op");
@@ -151,8 +170,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
151170
return success();
152171
}
153172

154-
static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
155-
bool isRead, bool isWritten) {
173+
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
174+
bool isWritten) {
156175
OpBuilder b(funcOp.getContext());
157176
Attribute accessType;
158177
if (isRead && isWritten) {
@@ -164,7 +183,8 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
164183
} else {
165184
accessType = b.getStringAttr("none");
166185
}
167-
funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
186+
funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
187+
accessType);
168188
}
169189

170190
/// Determine which FuncOp bbArgs are read and which are written. When run on a
@@ -173,28 +193,37 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
173193
static LogicalResult
174194
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
175195
FuncAnalysisState &funcState) {
176-
// If the function has no body, conservatively assume that all args are
177-
// read + written.
178-
if (funcOp.getBody().empty()) {
179-
for (BlockArgument bbArg : funcOp.getArguments()) {
180-
funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
181-
funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
196+
for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
197+
++idx) {
198+
// Skip non-tensor arguments.
199+
if (!funcOp.getFunctionType().getInput(idx).isa<TensorType>())
200+
continue;
201+
bool isRead;
202+
bool isWritten;
203+
if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
204+
idx, BufferizationDialect::kBufferAccessAttrName)) {
205+
// Buffer access behavior is specified on the function. Skip the analysis.
206+
StringRef str = accessAttr.getValue();
207+
isRead = str == "read" || str == "read-write";
208+
isWritten = str == "write" || str == "read-write";
209+
} else if (funcOp.getBody().empty()) {
210+
// If the function has no body, conservatively assume that all args are
211+
// read + written.
212+
isRead = true;
213+
isWritten = true;
214+
} else {
215+
// Analyze the body of the function.
216+
BlockArgument bbArg = funcOp.getArgument(idx);
217+
isRead = state.isValueRead(bbArg);
218+
isWritten = state.isValueWritten(bbArg);
182219
}
183220

184-
return success();
185-
}
186-
187-
for (BlockArgument bbArg : funcOp.getArguments()) {
188-
if (!bbArg.getType().isa<TensorType>())
189-
continue;
190-
bool isRead = state.isValueRead(bbArg);
191-
bool isWritten = state.isValueWritten(bbArg);
192221
if (state.getOptions().testAnalysisOnly)
193-
annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
222+
annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
194223
if (isRead)
195-
funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
224+
funcState.readBbArgs[funcOp].insert(idx);
196225
if (isWritten)
197-
funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
226+
funcState.writtenBbArgs[funcOp].insert(idx);
198227
}
199228

200229
return success();
@@ -351,10 +380,6 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
351380

352381
// Analyze ops.
353382
for (func::FuncOp funcOp : orderedFuncOps) {
354-
// No body => no analysis.
355-
if (funcOp.getBody().empty())
356-
continue;
357-
358383
// Now analyzing function.
359384
funcState.startFunctionAnalysis(funcOp);
360385

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,3 +1280,66 @@ func.func @write_to_same_alloc_tensor_out_of_place(
12801280

12811281
return %r0 : tensor<?xf32>
12821282
}
1283+
1284+
// -----
1285+
1286+
// CHECK-LABEL: func.func private @ext_func(tensor<*xf32> {bufferization.access = "read-write"})
1287+
func.func private @ext_func(%t: tensor<*xf32>)
1288+
1289+
// CHECK: func.func @private_func_read_write(%{{.*}}: tensor<5xf32> {bufferization.access = "read"})
1290+
func.func @private_func_read_write(%t: tensor<5xf32>) -> f32 {
1291+
%c0 = arith.constant 0 : index
1292+
// Bufferizes out-of-place because `ext_func` may modify the buffer.
1293+
// CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["false"]}
1294+
%0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32>
1295+
func.call @ext_func(%0) : (tensor<*xf32>) -> ()
1296+
%1 = tensor.extract %t[%c0] : tensor<5xf32>
1297+
return %1 : f32
1298+
}
1299+
1300+
// -----
1301+
1302+
// CHECK-LABEL: func.func private @print_buffer(tensor<*xf32> {bufferization.access = "read"})
1303+
func.func private @print_buffer(%t: tensor<*xf32> {bufferization.access = "read"})
1304+
1305+
// CHECK: func.func @private_func_read(%{{.*}}: tensor<5xf32> {bufferization.access = "read"})
1306+
func.func @private_func_read(%t: tensor<5xf32>) -> f32 {
1307+
%c0 = arith.constant 0 : index
1308+
// Bufferizes in-place because `print_buffer` is read-only.
1309+
// CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["true"]}
1310+
%0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32>
1311+
// CHECK: call @print_buffer(%cast) {__inplace_operands_attr__ = ["true"]}
1312+
func.call @print_buffer(%0) : (tensor<*xf32>) -> ()
1313+
%1 = tensor.extract %t[%c0] : tensor<5xf32>
1314+
return %1 : f32
1315+
}
1316+
1317+
// -----
1318+
1319+
// CHECK-LABEL: func.func private @ext_func(tensor<?xf32> {bufferization.access = "read-write"}, tensor<?xf32> {bufferization.access = "read-write"})
1320+
func.func private @ext_func(%t1: tensor<?xf32>, %t2: tensor<?xf32>)
1321+
1322+
// CHECK: func.func @private_func_two_params_writing(%{{.*}}: tensor<?xf32> {bufferization.access = "read"})
1323+
func.func @private_func_two_params_writing(%t: tensor<?xf32>) {
1324+
// Both operands bufferize out-of-place because both bufferize to a memory
1325+
// write.
1326+
// CHECK: call @ext_func(%{{.*}}, %{{.*}}) {__inplace_operands_attr__ = ["false", "false"]}
1327+
func.call @ext_func(%t, %t) : (tensor<?xf32>, tensor<?xf32>) -> ()
1328+
return
1329+
}
1330+
1331+
// -----
1332+
1333+
// CHECK-LABEL: func.func private @ext_func(tensor<?xf32> {bufferization.access = "read-write"}) -> (tensor<5xf32>, tensor<6xf32>)
1334+
func.func private @ext_func(%t: tensor<?xf32>) -> (tensor<5xf32>, tensor<6xf32>)
1335+
1336+
// CHECK: func.func @private_func_aliasing(%{{.*}}: tensor<?xf32> {bufferization.access = "read"})
1337+
func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
1338+
%c0 = arith.constant 0 : index
1339+
// Bufferizes out-of-place because either one of the two reuslts may alias
1340+
// with the argument and one of the results is read afterwards.
1341+
// CHECK: call @ext_func(%{{.*}}) {__inplace_operands_attr__ = ["false"]} : (tensor<?xf32>) -> (tensor<5xf32>, tensor<6xf32>)
1342+
%0, %1 = func.call @ext_func(%t) : (tensor<?xf32>) -> (tensor<5xf32>, tensor<6xf32>)
1343+
%2 = tensor.extract %1[%c0] : tensor<6xf32>
1344+
return %2 : f32
1345+
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func.func @scf_while_non_equiv_yield(%arg0: tensor<5xi1>,
158158

159159
// -----
160160

161-
func.func private @fun_with_side_effects(%A: tensor<?xf32> {bufferization.writable = true})
161+
func.func private @fun_with_side_effects(%A: tensor<?xf32>)
162162

163163
func.func @foo(%A: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>) {
164164
call @fun_with_side_effects(%A) : (tensor<?xf32>) -> ()

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,20 @@ func.func @sparse_alloc_call() {
7878
call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> ()
7979
return
8080
}
81+
82+
// -----
83+
84+
// expected-error @+1{{invalid value for 'bufferization.access'}}
85+
func.func private @invalid_buffer_access_type(tensor<*xf32> {bufferization.access = "foo"})
86+
87+
// -----
88+
89+
// expected-error @+1{{'bufferization.writable' is invalid on external functions}}
90+
func.func private @invalid_writable_attribute(tensor<*xf32> {bufferization.writable = false})
91+
92+
// -----
93+
94+
func.func @invalid_writable_on_op() {
95+
// expected-error @+1{{attribute '"bufferization.writable"' not supported as an op attribute by the bufferization dialect}}
96+
arith.constant {bufferization.writable = true} 0 : index
97+
}

mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func.func @scf_for_with_tensor.insert_slice(
129129
// CHECK-LABEL: func @execute_region_with_conflict(
130130
// CHECK-SAME: %[[m1:.*]]: memref<?xf32
131131
func.func @execute_region_with_conflict(
132-
%t1 : tensor<?xf32> {bufferization.writable = "true"})
132+
%t1 : tensor<?xf32> {bufferization.writable = true})
133133
-> (f32, tensor<?xf32>, f32)
134134
{
135135
%f1 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)