
Commit 93bbcff

[mlir][Transform] Make FuseIntoContainingOp support rank-reducing extract slices
This fixes an issue where rank-reducing extract slices and fusion would not
interoperate properly.

Differential Revision: https://reviews.llvm.org/D139844
1 parent: cde2cc9

File tree: 4 files changed, +93 −7 lines

  mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
  mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 9 additions & 0 deletions
@@ -435,6 +435,15 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     /// Return the dimensions of the source that are dropped in the
     /// result when the result is rank-reduced.
     llvm::SmallBitVector getDroppedDims();
+
+    /// Given a `value`, asserted to be of RankedTensorType, build an
+    /// ExtractSliceOp that results in a rank-reducing extract to the desired
+    /// tensor shape and return the new value created.
+    /// If the shape of `value` is already the `desiredShape`, just return
+    /// `value`.
+    /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
+    static FailureOr<Value> rankReduceIfNeeded(
+        OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
   }];

   let hasCanonicalizer = 1;
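
For orientation, a minimal sketch of how a caller is expected to use the new
helper, following the documented contract above; `rewriter`, `loc`, and
`newValue` are hypothetical placeholder names, not identifiers from this patch:

// Try to view `newValue` as a tensor of the desired (possibly lower-rank)
// shape; `{}` here is the zero-rank shape of a tensor<f32>.
FailureOr<Value> maybeReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
    rewriter, loc, newValue, /*desiredShape=*/{});
if (failed(maybeReduced))
  return failure(); // `newValue` cannot be rank-reduced to the desired shape.
// Otherwise `*maybeReduced` is either `newValue` itself (shapes already
// match) or the result of a freshly built rank-reducing tensor.extract_slice.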

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 17 additions & 2 deletions
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -299,7 +300,14 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,

   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      sliceOpToTile->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape());
+  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
   return fusedOp;
 }

@@ -399,7 +407,14 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(

   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      sliceOpToTile->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape());
+  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

   // Replace the use in containingOp.
   rewriter.updateRootInPlace(containingOp, [&]() {
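
To make the two new call sites concrete, consider the shapes exercised by the
test added below; the SSA names in this sketch are illustrative, not variables
from the patch:

// fusedOp->getResult(resultNumber)      : tensor<1xf32>  (tiled producer)
// sliceOpToTile->getResult(0).getType() : tensor<f32>    (desired shape {})
//
// rankReduceIfNeeded bridges the mismatch by inserting
//   %r = tensor.extract_slice %fused[0] [1] [1] : tensor<1xf32> to tensor<f32>
// so consumers of the original rank-reduced slice still see a tensor<f32>.

Previously, rewriter.replaceOp would have substituted a tensor<1xf32> value for
a tensor<f32> use, producing type-inconsistent IR. The assert encodes the
expectation that the desired shape is always reachable here, since it comes
from an extract_slice of the producer's result.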

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 21 additions & 5 deletions
@@ -17,7 +17,9 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Support/MathExtras.h"
@@ -1754,6 +1756,23 @@ llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
   return droppedDims;
 }

+FailureOr<Value>
+ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
+                                   ArrayRef<int64_t> desiredShape) {
+  auto sourceTensorType = value.getType().dyn_cast<RankedTensorType>();
+  assert(sourceTensorType && "not a ranked tensor type");
+  auto sourceShape = sourceTensorType.getShape();
+  if (sourceShape.equals(desiredShape))
+    return value;
+  auto maybeRankReductionMask =
+      mlir::computeRankReductionMask(sourceShape, desiredShape);
+  if (!maybeRankReductionMask)
+    return failure();
+  return createCanonicalRankReducingExtractSliceOp(
+      b, loc, value,
+      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
+}
+
 LogicalResult ExtractSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1);
@@ -2375,7 +2394,6 @@ struct InsertSliceOpSourceCastInserter final
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
         insertSliceOp.getMixedStrides());
-    cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
     return success();
   }
 };
@@ -2475,8 +2493,7 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,

   SmallVector<int64_t, 4> inferredShape;
   for (auto i : llvm::seq<unsigned>(0, rank)) {
-    if (sourceType.isDynamicDim(i) ||
-        staticLow[i] == ShapedType::kDynamic ||
+    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
         staticHigh[i] == ShapedType::kDynamic) {
       inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                   : resultShape[i]);
@@ -2525,8 +2542,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   // This will grow staticLow and staticHigh with 1 value. If the config is
   // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
   // value as well.
-  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
-                             ShapedType::kDynamic);
+  dispatchIndexOpFoldResults(low, dynamicLow, staticLow, ShapedType::kDynamic);
   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
                              ShapedType::kDynamic);
   if (!resultType) {
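
The helper delegates the shape reasoning to mlir::computeRankReductionMask,
which identifies which unit dimensions of the source shape must be dropped to
reach the desired shape, and fails when no such mapping exists. A small sketch
of its behavior under that contract (the shapes are illustrative; `auto`
sidesteps the exact optional return type):

// Which dims of {1, 4, 1, 8} must be dropped to obtain {4, 8}?
auto mask = mlir::computeRankReductionMask({1, 4, 1, 8}, {4, 8});
// -> a set containing dims 0 and 2, the dropped unit dims.
// For a pair like ({4, 8}, {8, 4}) no mask exists (a transpose is not a
// rank reduction), so rankReduceIfNeeded surfaces that case as failure().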

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 46 additions & 0 deletions
@@ -96,6 +96,52 @@ module {

 // -----

+module {
+  func.func @foo(%0: tensor<f32>) -> tensor<f32> {
+    return %0: tensor<f32>
+  }
+
+  // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+    // CHECK: scf.foreach_thread {{.*}} -> (tensor<?xf32>) {
+    %2 = scf.foreach_thread (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
+
+      // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
+      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
+      // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
+      // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
+      %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>
+
+      scf.foreach_thread.perform_concurrently {
+        // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+    // linalg.fill is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+  }
+}
+
+// -----
+
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
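
Assuming the file keeps its standard RUN line (not shown in this diff), the new
case can be exercised locally with the usual lit-style invocation, e.g.
mlir-opt --test-transform-dialect-interpreter --split-input-file on this test
file, piped into FileCheck against the same file.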
