Skip to content

Commit 996639d

Browse files
[MLIR][BufferResultsToOutParamsPass] Add Option to Modify Public Function's Signature (llvm#167248)
Since llvm#162441, `buffer-results-to-out-params` transforms `private` functions only. But, as mentioned in llvm#162441 (comment), this is a breaking change for pipelines handling C code. Our pipeline @EfficientComputer is also affected by this breaking change. Therefore, this PR adds an opt-in flag to allow `public` functions to be transformed by `BufferResultsToOutParamsPass`.
1 parent 0d786b9 commit 996639d

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ struct BufferResultsToOutParamsOpts {
171171
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
172172
/// memref is allocated in the current function and has dynamic shape.
173173
bool hoistDynamicAllocs = false;
174+
175+
/// If true, the pass modifies the function signatures of public functions.
176+
bool modifyPublicFunctions = false;
174177
};
175178

176179
/// Replace buffers that are returned from a function with an out parameter.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,9 @@ def BufferResultsToOutParamsPass
258258
/*default=*/"false", "Hoist static allocations to call sites.">,
259259
Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool",
260260
/*default=*/"false", "Hoist dynamic allocations to call sites.">,
261+
Option<"modifyPublicFunctions", "modify-public-functions", "bool",
262+
/*default=*/"false", "Modify function signatures of public "
263+
"functions.">,
261264
];
262265
let dependentDialects = ["memref::MemRefDialect"];
263266
}

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
217217
}
218218
if (!options.filterFn(&callee))
219219
return;
220-
if (callee.isExternal() || callee.isPublic())
220+
if (callee.isPublic() && !options.modifyPublicFunctions)
221+
return;
222+
if (callee.isExternal())
221223
return;
222224

223225
SmallVector<Value, 6> replaceWithNewCallResults;
@@ -295,7 +297,9 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
295297
// function.
296298
AllocDynamicSizesMap map;
297299
for (auto func : module.getOps<func::FuncOp>()) {
298-
if (func.isExternal() || func.isPublic())
300+
if (func.isPublic() && !options.modifyPublicFunctions)
301+
continue;
302+
if (func.isExternal())
299303
continue;
300304
if (!options.filterFn(&func))
301305
continue;
@@ -326,6 +330,8 @@ struct BufferResultsToOutParamsPass
326330
options.hoistStaticAllocs = true;
327331
if (hoistDynamicAllocs)
328332
options.hoistDynamicAllocs = true;
333+
if (modifyPublicFunctions)
334+
options.modifyPublicFunctions = true;
329335

330336
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
331337
options)))
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{modify-public-functions})' %s | FileCheck %s
2+
3+
// Test if `public` functions' return values are transformed into out parameters
4+
// when `buffer-results-to-out-params` is invoked with `modifyPublicFunctions`.
5+
6+
// CHECK-LABEL: func.func @basic(
7+
// CHECK-SAME: %[[ARG0:.*]]: memref<f32>) {
8+
// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<f32>
9+
// CHECK: memref.copy %[[VAL_0]], %[[ARG0]] : memref<f32> to memref<f32>
10+
// CHECK: return
11+
// CHECK: }
12+
func.func @basic() -> (memref<f32>) {
13+
%0 = "test.source"() : () -> (memref<f32>)
14+
return %0 : memref<f32>
15+
}
16+
17+
// CHECK-LABEL: func.func @presence_of_existing_arguments(
18+
// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
19+
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
20+
// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<2xf32>
21+
// CHECK: memref.copy %[[VAL_0]], %[[ARG1]] : memref<2xf32> to memref<2xf32>
22+
// CHECK: return
23+
// CHECK: }
24+
func.func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
25+
%0 = "test.source"() : () -> (memref<2xf32>)
26+
return %0 : memref<2xf32>
27+
}
28+
29+
// CHECK-LABEL: func.func @multiple_results(
30+
// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
31+
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
32+
// CHECK: %[[VAL_0:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
33+
// CHECK: memref.copy %[[VAL_0]]#0, %[[ARG0]] : memref<1xf32> to memref<1xf32>
34+
// CHECK: memref.copy %[[VAL_0]]#1, %[[ARG1]] : memref<2xf32> to memref<2xf32>
35+
// CHECK: return
36+
// CHECK: }
37+
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
38+
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
39+
return %0, %1 : memref<1xf32>, memref<2xf32>
40+
}

0 commit comments

Comments
 (0)