Skip to content

Commit e7328a9

Browse files
[mlir][linalg] Fold duplicate and unused inputs in linalg.generic
If an input bbArg is not used, its corresponding input operand is removed. If there are duplicate input operands or input operands that are also used as output operands, the duplicate input operands are removed. Output operands are never modified. Differential Revision: https://reviews.llvm.org/D139709
1 parent e96ddad commit e7328a9

File tree

4 files changed

+113
-6
lines changed

4 files changed

+113
-6
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns);
8888
/// This is effectively DCE for a linalg op.
8989
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
9090

91+
/// Patterns to promote inputs to outputs and remove unused inputs of
92+
/// `linalg.generic` ops.
93+
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
94+
9195
/// Function type to control generic op dimension collapsing. It is expected
9296
/// to return an array of `ReassociationIndices` representing dimensions that
9397
/// should be merged.

mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ namespace {
5656

5757
struct DeduplicateAndRemoveDeadOperandsAndResults
5858
: public OpRewritePattern<GenericOp> {
59-
using OpRewritePattern<GenericOp>::OpRewritePattern;
59+
DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
60+
bool removeOutputs)
61+
: OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
6062

6163
LogicalResult matchAndRewrite(GenericOp genericOp,
6264
PatternRewriter &rewriter) const override {
@@ -120,6 +122,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
120122
}
121123

122124
private:
125+
/// If unset, outputs are not modified by this pattern.
126+
bool removeOutputs;
127+
123128
// Deduplicate input operands, and return the
124129
// - Mapping from operand position in the original op, to operand position in
125130
// the canonicalized op.
@@ -176,9 +181,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
176181
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
177182
llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
178183
dedupedOutpts;
179-
// If the op doesnt have tensor semantics, keep all the outputs as
180-
// preserved.
181-
if (!genericOp.hasTensorSemantics()) {
184+
// If the op doesn't have tensor semantics or outputs should not be removed,
185+
// keep all the outputs as preserved.
186+
if (!genericOp.hasTensorSemantics() || !removeOutputs) {
182187
for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
183188
origToNewPos[en.index()] = newOutputOperands.size();
184189
newOutputOperands.push_back(en.value()->get());
@@ -353,10 +358,69 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
353358
return failure();
354359
}
355360
};
361+
362+
/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
363+
/// ```
364+
/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
365+
/// ^bb0(%in0, %in1, %in2, %in3, %out1)
366+
/// ```
367+
/// Assuming that all %a and %b have the same index map:
368+
/// * All uses of %in0 and %in2 are replaced with %out1
369+
/// * All uses of %in1 are replaced with %in3
370+
/// This pattern can enable additional canonicalizations: In the above example,
371+
/// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
372+
/// can be folded away. This pattern does not modify uses of output block args.
373+
struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
374+
using OpRewritePattern<GenericOp>::OpRewritePattern;
375+
376+
LogicalResult matchAndRewrite(GenericOp genericOp,
377+
PatternRewriter &rewriter) const override {
378+
// Find replacement bbArgs for all input bbArg.
379+
DenseMap<int, int> replacements;
380+
for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
381+
// Skip bbArgs that have no uses.
382+
if (genericOp.getBody()->getArgument(i).getUses().empty())
383+
continue;
384+
// Find replacement bbArg. This can be an input or an output bbArg.
385+
for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
386+
if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
387+
genericOp.getIndexingMapsArray()[i] ==
388+
genericOp.getIndexingMapsArray()[j]) {
389+
replacements[i] = j;
390+
break;
391+
}
392+
}
393+
}
394+
395+
// Stop here if no replacements were found.
396+
if (replacements.empty())
397+
return failure();
398+
399+
// Rewrite the op.
400+
rewriter.updateRootInPlace(genericOp, [&]() {
401+
for (auto [before, after] : replacements) {
402+
BlockArgument bbArg = genericOp.getBody()->getArgument(before);
403+
BlockArgument replacement = genericOp.getBody()->getArgument(after);
404+
rewriter.replaceAllUsesWith(bbArg, replacement);
405+
}
406+
});
407+
408+
return success();
409+
}
410+
};
411+
356412
} // namespace
357413

358414
void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
359415
RewritePatternSet &patterns) {
360-
patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults,
361-
RemoveUnusedCycleInGenericOp>(patterns.getContext());
416+
patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
417+
patterns.getContext(), /*removeOutputs=*/true);
418+
patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
419+
}
420+
421+
/// Adds the patterns that fold duplicate input bbArgs and erase unused or
/// duplicate inputs of `linalg.generic` ops to `patterns`. The operand
/// deduplication pattern is registered with `removeOutputs` disabled.
void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
      ctx, /*removeOutputs=*/false);
  patterns.insert<FoldDuplicateInputBbArgs>(ctx);
}

mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unused-operands-and-results | FileCheck %s
2+
// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unnecessary-inputs | FileCheck %s --check-prefix=CHECK-INPUT
23

34
// CHECK-LABEL: func @remove_deadargs_generic_basic
45
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
@@ -493,3 +494,29 @@ func.func @drop_only_the_cycles_not_used_by_others(%arg0 : tensor<?x?x?xf32>) ->
493494
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
494495
// CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
495496
// CHECK: return %[[GENERIC]]#0
497+
498+
499+
// -----
500+
501+
// Input %b is also the output operand and both use the same indexing map, so uses of its bbArg %in_2 are expected to be redirected to %out and the duplicate input dropped; only %a remains in `ins`.
// CHECK-INPUT-LABEL: func @remove_unnecessary_input(
502+
// CHECK-INPUT-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>
503+
#map = affine_map<(d0) -> (d0)>
504+
func.func @remove_unnecessary_input(%a: tensor<?xf32>, %b: tensor<?xf32>)
505+
-> tensor<?xf32>
506+
{
507+
// CHECK-INPUT: %[[result:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel"]}
508+
// CHECK-INPUT-SAME: ins(%[[a]] : tensor<?xf32>) outs(%[[b]] : tensor<?xf32>) {
509+
// CHECK-INPUT: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
510+
// CHECK-INPUT: %[[add:.*]] = arith.addf %[[in]], %[[out]]
511+
// CHECK-INPUT: linalg.yield %[[add]]
512+
// CHECK-INPUT: } -> tensor<?xf32>
513+
// CHECK-INPUT: return %[[result]]
514+
%0 = linalg.generic
515+
{indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
516+
ins(%a, %b : tensor<?xf32>, tensor<?xf32>) outs(%b : tensor<?xf32>) {
517+
^bb0(%in: f32, %in_2: f32, %out: f32):
518+
%16 = arith.addf %in, %in_2 : f32
519+
linalg.yield %16 : f32
520+
} -> tensor<?xf32>
521+
return %0 : tensor<?xf32>
522+
}

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ struct TestLinalgTransforms
113113
*this, "test-erase-unused-operands-and-results",
114114
llvm::cl::desc("Test patterns to erase unused operands and results"),
115115
llvm::cl::init(false)};
116+
Option<bool> testEraseUnnecessaryInputs{
117+
*this, "test-erase-unnecessary-inputs",
118+
llvm::cl::desc("Test patterns to erase unnecessary inputs"),
119+
llvm::cl::init(false)};
116120
};
117121
} // namespace
118122

@@ -185,6 +189,12 @@ static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
185189
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
186190
}
187191

192+
/// Runs the erase-unnecessary-inputs patterns greedily on `funcOp`.
static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
  RewritePatternSet inputFoldingPatterns(funcOp.getContext());
  populateEraseUnnecessaryInputsPatterns(inputFoldingPatterns);
  // The driver's result is deliberately discarded.
  (void)applyPatternsAndFoldGreedily(funcOp,
                                     std::move(inputFoldingPatterns));
}
197+
188198
/// Apply transformations specified as patterns.
189199
void TestLinalgTransforms::runOnOperation() {
190200
if (testPatterns)
@@ -205,6 +215,8 @@ void TestLinalgTransforms::runOnOperation() {
205215
return applySwapExtractSliceWithFillPattern(getOperation());
206216
if (testEraseUnusedOperandsAndResults)
207217
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
218+
if (testEraseUnnecessaryInputs)
219+
return applyEraseUnnecessaryInputs(getOperation());
208220
}
209221

210222
namespace mlir {

0 commit comments

Comments
 (0)