@@ -56,7 +56,9 @@ namespace {
5656
5757struct DeduplicateAndRemoveDeadOperandsAndResults
5858 : public OpRewritePattern<GenericOp> {
59- using OpRewritePattern<GenericOp>::OpRewritePattern;
59+ DeduplicateAndRemoveDeadOperandsAndResults (MLIRContext *ctx,
60+ bool removeOutputs)
61+ : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
6062
6163 LogicalResult matchAndRewrite (GenericOp genericOp,
6264 PatternRewriter &rewriter) const override {
@@ -120,6 +122,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
120122 }
121123
122124private:
125+ /// If false, outputs are not modified by this pattern.
126+ bool removeOutputs;
127+
123128 // Deduplicate input operands, and return the
124129 // - Mapping from operand position in the original op, to operand position in
125130 // the canonicalized op.
@@ -176,9 +181,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
176181 llvm::SmallDenseMap<unsigned , unsigned > origToNewPos;
177182 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned >
178183 dedupedOutpts;
179- // If the op doesnt have tensor semantics, keep all the outputs as
180- // preserved.
181- if (!genericOp.hasTensorSemantics ()) {
184+ // If the op doesn't have tensor semantics or outputs should not be removed,
185+ // keep all the outputs as preserved.
186+ if (!genericOp.hasTensorSemantics () || !removeOutputs ) {
182187 for (const auto &en : llvm::enumerate (genericOp.getDpsInitOperands ())) {
183188 origToNewPos[en.index ()] = newOutputOperands.size ();
184189 newOutputOperands.push_back (en.value ()->get ());
@@ -353,10 +358,69 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
353358 return failure ();
354359 }
355360};
361+
362+ // / Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
363+ // / ```
364+ // / linalg.generic ins(%a, %b, %a, %b) outs(%a)
365+ // / ^bb0(%in0, %in1, %in2, %in3, %out1)
366+ // / ```
367+ // / Assuming that all %a and %b have the same index map:
368+ // / * All uses of %in0 and %in2 are replaced with %out1
369+ // / * All uses of %in1 are replaced with %in3
370+ // / This pattern can enable additional canonicalizations: In the above example,
371+ // / %in0, %in1 and %in3 have no uses anymore and their corresponding operands
372+ // / can be folded away. This pattern does not modify uses of output block args.
373+ struct FoldDuplicateInputBbArgs : public OpRewritePattern <GenericOp> {
374+ using OpRewritePattern<GenericOp>::OpRewritePattern;
375+
376+ LogicalResult matchAndRewrite (GenericOp genericOp,
377+ PatternRewriter &rewriter) const override {
378+ // Find replacement bbArgs for all input bbArg.
379+ DenseMap<int , int > replacements;
380+ for (int i = 0 ; i < genericOp.getNumDpsInputs (); ++i) {
381+ // Skip bbArgs that have no uses.
382+ if (genericOp.getBody ()->getArgument (i).getUses ().empty ())
383+ continue ;
384+ // Find replacement bbArg. This can be an input or an output bbArg.
385+ for (int j = genericOp->getNumOperands () - 1 ; j > i; --j) {
386+ if (genericOp->getOperand (i) == genericOp->getOperand (j) &&
387+ genericOp.getIndexingMapsArray ()[i] ==
388+ genericOp.getIndexingMapsArray ()[j]) {
389+ replacements[i] = j;
390+ break ;
391+ }
392+ }
393+ }
394+
395+ // Stop here if no replacements were found.
396+ if (replacements.empty ())
397+ return failure ();
398+
399+ // Rewrite the op.
400+ rewriter.updateRootInPlace (genericOp, [&]() {
401+ for (auto [before, after] : replacements) {
402+ BlockArgument bbArg = genericOp.getBody ()->getArgument (before);
403+ BlockArgument replacement = genericOp.getBody ()->getArgument (after);
404+ rewriter.replaceAllUsesWith (bbArg, replacement);
405+ }
406+ });
407+
408+ return success ();
409+ }
410+ };
411+
356412} // namespace
357413
358414void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns (
359415 RewritePatternSet &patterns) {
360- patterns.insert <DeduplicateAndRemoveDeadOperandsAndResults,
361- RemoveUnusedCycleInGenericOp>(patterns.getContext ());
416+ patterns.insert <DeduplicateAndRemoveDeadOperandsAndResults>(
417+ patterns.getContext (), /* removeOutputs=*/ true );
418+ patterns.insert <RemoveUnusedCycleInGenericOp>(patterns.getContext ());
419+ }
420+
421+ void mlir::linalg::populateEraseUnnecessaryInputsPatterns (
422+ RewritePatternSet &patterns) {
423+ patterns.insert <DeduplicateAndRemoveDeadOperandsAndResults>(
424+ patterns.getContext (), /* removeOutputs=*/ false );
425+ patterns.insert <FoldDuplicateInputBbArgs>(patterns.getContext ());
362426}
0 commit comments