Skip to content

Commit 7d97678

Browse files
committed
[mlir][linalg] Break up linalg vectorization pre-condition
Break up the vectorization pre-condition into the part checking for static shape and the rest checking if the linalg op is supported by vectorization. This allows checking if an op could be vectorized if it had static shapes. Differential Revision: https://reviews.llvm.org/D115754
1 parent 9c7fbc3 commit 7d97678

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,15 @@ LogicalResult generalizeNamedOpPrecondition(Operation *op);
401401
LogicalResult promoteSubviewsPrecondition(Operation *op,
402402
LinalgPromotionOptions options);
403403

404-
/// Rewrite a linalg.generic into a suitable vector.contraction op.
404+
/// Return success if the operation can be vectorized.
405405
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
406406

407+
/// Return success if `op` can be vectorized assuming it is static. This allows
408+
/// checking if an op will be vectorizable once all the dimensions are folded to
409+
/// static values.
410+
/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
411+
LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
412+
407413
//===----------------------------------------------------------------------===//
408414
// Transformations exposed as rewrite patterns.
409415
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -599,34 +599,39 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
599599
return success();
600600
}
601601

602-
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
603-
auto linalgOp = cast<linalg::LinalgOp>(op);
604-
// All types must be static shape to go to vector.
605-
if (linalgOp.hasDynamicShape()) {
606-
LDBG("precondition failed: dynamic shape");
607-
return failure();
608-
}
602+
LogicalResult
603+
mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
609604
if (isElementwise(op))
610605
return success();
611606
// TODO: isaConvolutionOpInterface that can also infer from generic features.
612607
// But we will still need stride/dilation attributes that will be annoying to
613608
// reverse-engineer...
614-
if (isa<ConvolutionOpInterface>(op))
609+
if (isa<ConvolutionOpInterface>(op.getOperation()))
615610
return success();
616611
// TODO: the common vector shape is equal to the static loop sizes only when
617612
// all indexing maps are projected permutations. For convs and stencils the
618613
// logic will need to evolve.
619-
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
614+
if (!allIndexingsAreProjectedPermutation(op)) {
620615
LDBG("precondition failed: not projected permutations");
621616
return failure();
622617
}
623-
if (failed(reductionPreconditions(linalgOp))) {
618+
if (failed(reductionPreconditions(op))) {
624619
LDBG("precondition failed: reduction preconditions");
625620
return failure();
626621
}
627622
return success();
628623
}
629624

625+
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
626+
auto linalgOp = cast<linalg::LinalgOp>(op);
627+
// All types must be static shape to go to vector.
628+
if (linalgOp.hasDynamicShape()) {
629+
LDBG("precondition failed: dynamic shape");
630+
return failure();
631+
}
632+
return vectorizeStaticLinalgOpPrecondition(linalgOp);
633+
}
634+
630635
LogicalResult
631636
mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
632637
SmallVectorImpl<Value> &newResults) {

0 commit comments

Comments
 (0)