@@ -43,8 +43,9 @@ using namespace mlir::linalg;
43
43
#define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
44
44
#define LDBG (X ) LLVM_DEBUG(DBGS() << X)
45
45
46
- static FailureOr<Operation *>
47
- vectorizeConvolution (OpBuilder &b, ConvolutionOpInterface convOp);
46
+ // / Try to vectorize `convOp` as a convolution.
47
+ static FailureOr<Operation *> vectorizeConvolution (OpBuilder &b,
48
+ LinalgOp convOp);
48
49
49
50
// / Return the unique instance of OpType in `block` if it is indeed unique.
50
51
// / Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
636
637
SmallVector<Value> results;
637
638
// TODO: isaConvolutionOpInterface that can also infer from generic
638
639
// features. Will require stride/dilation attributes inference.
639
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation ())) {
640
- LDBG (" Vectorize as a conv: " << linalgOp);
641
- FailureOr<Operation *> convOr = vectorizeConvolution (rewriter, convOp);
642
- if (failed (convOr))
643
- return failure ();
640
+ FailureOr<Operation *> convOr = vectorizeConvolution (rewriter, linalgOp);
641
+ if (succeeded (convOr)) {
644
642
llvm::append_range (results, (*convOr)->getResults ());
645
643
} else {
644
+ if (failed (vectorizeLinalgOpPrecondition (linalgOp)))
645
+ return failure ();
646
646
LDBG (" Vectorize generic by broadcasting to a common shape: " << linalgOp);
647
647
if (failed (vectorizeAsLinalgGeneric (rewriter, linalgOp, results)))
648
648
return failure ();
@@ -1640,40 +1640,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
1640
1640
};
1641
1641
} // namespace
1642
1642
1643
- // / Helper function to vectorize a `linalgOp` with convolution semantics.
1643
+ // / Helper function to vectorize a LinalgOp with convolution semantics.
1644
1644
// TODO: extend the generic vectorization to support windows and drop this.
1645
- static FailureOr<Operation *>
1646
- vectorizeConvolution (OpBuilder &b, ConvolutionOpInterface convOp) {
1647
- // TODO: these are legitimately part of ConvolutionOpInterface.
1648
- auto strides = convOp->getAttrOfType <DenseIntElementsAttr>(" strides" );
1649
- auto dilations = convOp->getAttrOfType <DenseIntElementsAttr>(" dilations" );
1645
+ static FailureOr<Operation *> vectorizeConvolution (OpBuilder &b, LinalgOp op) {
1646
+ // The ConvolutionOpInterface gives us guarantees of existence for
1647
+ // strides/dilations. However, we do not need to rely on those, we can simply
1648
+ // use them if present, otherwise use the default and let the generic conv.
1649
+ // matcher in the ConvGenerator succeed or fail.
1650
+ auto strides = op->getAttrOfType <DenseIntElementsAttr>(" strides" );
1651
+ auto dilations = op->getAttrOfType <DenseIntElementsAttr>(" dilations" );
1650
1652
auto stride = strides ? *strides.getValues <uint64_t >().begin () : 1 ;
1651
1653
auto dilation = dilations ? *dilations.getValues <uint64_t >().begin () : 1 ;
1652
- LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation ());
1653
- Conv1DNwcGenerator e (b, linalgOp, stride, dilation);
1654
+ Conv1DNwcGenerator e (b, op, stride, dilation);
1654
1655
auto res = e.generateConv ();
1655
1656
if (succeeded (res))
1656
1657
return res;
1657
1658
return e.generateDilatedConv ();
1658
1659
}
1659
1660
1660
- struct VectorizeConvolution
1661
- : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
1661
+ struct VectorizeConvolution : public OpInterfaceRewritePattern <LinalgOp> {
1662
1662
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
1663
1663
1664
- LogicalResult matchAndRewrite (ConvolutionOpInterface convOp ,
1664
+ LogicalResult matchAndRewrite (LinalgOp op ,
1665
1665
PatternRewriter &rewriter) const override {
1666
- FailureOr<Operation *> resultOrFail =
1667
- vectorizeConvolution (rewriter, convOp);
1666
+ FailureOr<Operation *> resultOrFail = vectorizeConvolution (rewriter, op);
1668
1667
if (failed (resultOrFail))
1669
1668
return failure ();
1670
1669
Operation *newOp = *resultOrFail;
1671
1670
if (newOp->getNumResults () == 0 ) {
1672
- rewriter.eraseOp (convOp .getOperation ());
1671
+ rewriter.eraseOp (op .getOperation ());
1673
1672
return success ();
1674
1673
}
1675
1674
assert (newOp->getNumResults () == 1 && " expected single result" );
1676
- rewriter.replaceOp (convOp .getOperation (), newOp->getResult (0 ));
1675
+ rewriter.replaceOp (op .getOperation (), newOp->getResult (0 ));
1677
1676
return success ();
1678
1677
}
1679
1678
};
0 commit comments