Skip to content

Commit efdd4c1

Browse files
[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Differential revision: https://reviews.llvm.org/D117323
1 parent fd1dce3 commit efdd4c1

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ using namespace mlir::linalg;
4343
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
4444
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
4545

46-
static FailureOr<Operation *>
47-
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
46+
/// Try to vectorize `convOp` as a convolution.
47+
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
48+
LinalgOp convOp);
4849

4950
/// Return the unique instance of OpType in `block` if it is indeed unique.
5051
/// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
636637
SmallVector<Value> results;
637638
// TODO: isaConvolutionOpInterface that can also infer from generic
638639
// features. Will require stride/dilation attributes inference.
639-
if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
640-
LDBG("Vectorize as a conv: " << linalgOp);
641-
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
642-
if (failed(convOr))
643-
return failure();
640+
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
641+
if (succeeded(convOr)) {
644642
llvm::append_range(results, (*convOr)->getResults());
645643
} else {
644+
if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
645+
return failure();
646646
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
647647
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
648648
return failure();
@@ -1640,40 +1640,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
16401640
};
16411641
} // namespace
16421642

1643-
/// Helper function to vectorize a `linalgOp` with convolution semantics.
1643+
/// Helper function to vectorize a LinalgOp with convolution semantics.
16441644
// TODO: extend the generic vectorization to support windows and drop this.
1645-
static FailureOr<Operation *>
1646-
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
1647-
// TODO: these are legitimately part of ConvolutionOpInterface.
1648-
auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
1649-
auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
1645+
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
1646+
// The ConvolutionOpInterface gives us guarantees of existence for
1647+
// strides/dilations. However, we do not need to rely on those, we can simply
1648+
// use them if present, otherwise use the default and let the generic conv.
1649+
// matcher in the ConvGenerator succeed or fail.
1650+
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
1651+
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
16501652
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
16511653
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
1652-
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
1653-
Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
1654+
Conv1DNwcGenerator e(b, op, stride, dilation);
16541655
auto res = e.generateConv();
16551656
if (succeeded(res))
16561657
return res;
16571658
return e.generateDilatedConv();
16581659
}
16591660

1660-
struct VectorizeConvolution
1661-
: public OpInterfaceRewritePattern<ConvolutionOpInterface> {
1661+
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
16621662
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
16631663

1664-
LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
1664+
LogicalResult matchAndRewrite(LinalgOp op,
16651665
PatternRewriter &rewriter) const override {
1666-
FailureOr<Operation *> resultOrFail =
1667-
vectorizeConvolution(rewriter, convOp);
1666+
FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
16681667
if (failed(resultOrFail))
16691668
return failure();
16701669
Operation *newOp = *resultOrFail;
16711670
if (newOp->getNumResults() == 0) {
1672-
rewriter.eraseOp(convOp.getOperation());
1671+
rewriter.eraseOp(op.getOperation());
16731672
return success();
16741673
}
16751674
assert(newOp->getNumResults() == 1 && "expected single result");
1676-
rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
1675+
rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
16771676
return success();
16781677
}
16791678
};

0 commit comments

Comments
 (0)