Skip to content

Commit 8a8f0a0

Browse files
[mlir][Linalg] Relax PadTensor tiling constraints and expose it to strategies.
Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D117334
1 parent d96a504 commit 8a8f0a0

File tree

4 files changed

+16
-1
lines changed

4 files changed

+16
-1
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
4646
//===----------------------------------------------------------------------===//
4747
using LinalgLoops = SmallVector<Operation *, 4>;
4848

49+
void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
50+
const LinalgTilingOptions &options);
51+
4952
/// Populate patterns for vectorizing low-D convolution ops. This is a step in
5053
/// progressive lowering for convolution ops, it assume high-D convolution ops
5154
/// were decomposed previously.

mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ struct LinalgStrategyTilePass
100100
filter);
101101
else
102102
tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
103+
if (anchorOpName == linalg::PadTensorOp::getOperationName())
104+
populatePadTensorTilingPatterns(tilingPattern, options);
103105
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
104106
}
105107

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,9 @@ static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op,
354354
int64_t rank = op.getResultType().getRank();
355355
SmallVector<Value> tileSizes =
356356
options.tileSizeComputationFunction(builder, op);
357-
assert(static_cast<int64_t>(tileSizes.size()) == rank);
357+
// Normalize untiled padding dimensions to 0.
358+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
359+
tileSizes.append(rank - tileSizes.size(), zero);
358360
// Compute lower and upper bounds of the loop nest.
359361
SmallVector<Range> ranges = op.getIterationDomain(builder);
360362
SmallVector<Value> lbs, dims, allDims, steps;
@@ -490,6 +492,12 @@ static void insertTilingPatterns(RewritePatternSet &patterns,
490492
patterns.add<PadTensorOpTilingPattern>(ctx, options);
491493
}
492494

495+
void mlir::linalg::populatePadTensorTilingPatterns(
496+
RewritePatternSet &patterns, const LinalgTilingOptions &options) {
497+
auto *ctx = patterns.getContext();
498+
patterns.add<PadTensorOpTilingPattern>(ctx, options);
499+
}
500+
493501
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
494502
MLIRContext *ctx = funcOp.getContext();
495503
RewritePatternSet patterns(ctx);

mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// RUN: FileCheck %s -check-prefix=TILE2
33
// RUN: mlir-opt %s -linalg-tile="tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
44
// RUN: FileCheck %s -check-prefix=TILE1
5+
// This test only checks that tiling does not crash.
6+
// RUN: mlir-opt %s -linalg-tile="tile-sizes=2" -resolve-shaped-type-result-dims -cse -split-input-file
57

68
// TILE2-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
79
// TILE2-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>

0 commit comments

Comments
 (0)