Skip to content

Commit d3a031f

Browse files
banach-space丹治秀樹
authored andcommitted
[mlir][tensor] Add new builders for insert_slice/extract_slice Ops (nfc) (llvm#169533)
Adds new builders for `tensor.insert_slice` and `tensor.extract_slice` Ops for which the _offsets_ and the _strides_ are all 0s and 1s, respecitvely. This allows us to write: ```cpp // No offsets and no strides - implicitly set to 0s and 1s, // respectively. tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeSizes); ``` instead of: ```cpp // Strides are initialised explicitly to 1s Attribute oneIdxAttr = rewriter.getIndexAttr(1); SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); // Offsets are initialised explicitly to 0s Attribute zeroIdxAttr = rewriter.getIndexAttr(0); SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); tensor::InsertSliceOp::create(rewriter, loc, src, dest, writeOffsets, writeSizes, writeStrides); ```
1 parent 9f7dd4a commit d3a031f

File tree

3 files changed

+41
-27
lines changed

3 files changed

+41
-27
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
471471
// a Range vector.
472472
OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
473473
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
474+
// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
475+
// result type, offsets set to 0 and strides set to 1.
476+
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
477+
"ArrayRef<OpFoldResult>":$sizes,
478+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
474479
];
475480

476481
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -930,7 +935,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
930935
// a Range vector and inferred result type.
931936
OpBuilder<(ins "Value":$source, "Value":$dest,
932937
"ArrayRef<Range>":$ranges,
933-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
938+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
939+
// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
940+
// to 0, strides set to 1 and inferred result type.
941+
OpBuilder<(ins "Value":$source, "Value":$dest,
942+
"ArrayRef<OpFoldResult>":$sizes,
943+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
934944
];
935945

936946
let extraClassDeclaration = extraBaseClassDeclaration # [{

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11671167
"this is not supported ATM!");
11681168
}
11691169

1170-
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1171-
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11721170
Location loc = packOp.getLoc();
11731171

11741172
int64_t srcRank = packOp.getSourceRank();
1175-
int64_t destRank = packOp.getDestRank();
11761173

11771174
// 1. Get the input that is going to be packed. If the input requires padding,
11781175
// add a padding operation and return that as the input.
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12621259
writeSizes.push_back(tileSizeOfr);
12631260
}
12641261

1265-
// TODO: Add a constructor for tensor.insert_slice that doesn't require
1266-
// strides nor offsets.
1267-
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1268-
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1269-
12701262
auto insert = tensor::InsertSliceOp::create(
1271-
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
1272-
writeOffsets, writeSizes, writeStrides);
1263+
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
12731264

12741265
// 4. Replace tensor.packOp with tensor.insert_slice created above
12751266
rewriter.replaceOp(packOp, insert.getResult());
@@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12791270

12801271
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12811272
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1282-
int64_t srcRank = unpackOp.getSourceRank();
12831273
int64_t destRank = unpackOp.getDestRank();
12841274
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
12851275
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
@@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12961286
Value source = unpackOp.getSource();
12971287
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
12981288
unpackOp.getDimAndTileMapping();
1299-
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
13001289
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
13011290

13021291
// The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13071296
// outer-tiled-dims being all 1), this will be
13081297
// [ outer-untiled-dims, tile-sizes ]
13091298
SmallVector<OpFoldResult> extractSliceSizes;
1310-
// The offset and strides attributes for ExtractSliceOp.
1311-
SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1312-
SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
13131299

13141300
// Shape for EmptyOp that's used as the init value for TransposeOp below.
13151301
// This should be:
@@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13641350
Type elemType = unpackOp.getSourceType().getElementType();
13651351
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
13661352
Value innerTile = tensor::ExtractSliceOp::create(
1367-
rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
1368-
extractSliceSizes, extractSliceStrides);
1353+
rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
13691354

13701355
// 2. Transpose the tile to match the outer corresponding tile order.
13711356
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
@@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13811366

13821367
// 3. Handle in-complete tiles if needed. It truncates trailing data from the
13831368
// transposed tile.
1384-
int numLoops = shapeForEmptyOp.size();
1385-
SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1386-
SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
13871369
SmallVector<OpFoldResult> tileSizes;
13881370
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
13891371
for (auto i : llvm::seq<unsigned>(0, destRank)) {
@@ -1393,22 +1375,19 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
13931375
}
13941376

13951377
auto partialTile =
1396-
tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
1397-
tileOffsets, tileSizes, tileStrides);
1378+
tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
1379+
transposedOp.getResult()[0], tileSizes);
13981380

13991381
// 4. Insert the result to the destination tensor.
14001382
SmallVector<OpFoldResult> writeSizes;
1401-
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1402-
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
14031383
for (int i = 0, idx = 0; i < destRank; ++i) {
14041384
if (dimAndTileMapping.count(i) || destShape[i] != 1)
14051385
writeSizes.push_back(tileSizes[idx++]);
14061386
else
14071387
writeSizes.push_back(oneIdxAttr);
14081388
}
14091389
auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1410-
unpackOp.getDest(), writeOffsets,
1411-
writeSizes, writeStrides);
1390+
unpackOp.getDest(), writeSizes);
14121391
rewriter.replaceOp(unpackOp, insert.getResult());
14131392

14141393
return success();

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,6 +2445,19 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
24452445
}
24462446
}
24472447

2448+
/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2449+
/// result type, offsets set to 0 and strides set to 1.
2450+
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2451+
RankedTensorType resultType, Value source,
2452+
ArrayRef<OpFoldResult> sizes,
2453+
ArrayRef<NamedAttribute> attrs) {
2454+
Attribute zeroIdxAttr = b.getIndexAttr(0);
2455+
Attribute oneIdxAttr = b.getIndexAttr(1);
2456+
SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2457+
SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2458+
build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2459+
}
2460+
24482461
/// Verifier for ExtractSliceOp.
24492462
LogicalResult ExtractSliceOp::verify() {
24502463
RankedTensorType sourceType = getSourceType();
@@ -3889,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
38893902
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
38903903
}
38913904

3905+
// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3906+
// to 0, strides set to 1 and inferred result type.
3907+
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3908+
Value dest, ArrayRef<OpFoldResult> sizes,
3909+
ArrayRef<NamedAttribute> attrs) {
3910+
Attribute zeroIdxAttr = b.getIndexAttr(0);
3911+
Attribute oneIdxAttr = b.getIndexAttr(1);
3912+
SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3913+
SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3914+
build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3915+
}
3916+
38923917
LogicalResult ParallelInsertSliceOp::verify() {
38933918
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
38943919
return this->emitError("expected InParallelOpInterface parent, got:")

0 commit comments

Comments
 (0)