Skip to content

Commit 0788dba

Browse files
author
git apple-llvm automerger
committed
Merge commit 'f53b6249c240' from llvm.org/main into next
2 parents eb0d6c3 + f53b624 commit 0788dba

File tree

3 files changed

+55
-36
lines changed

3 files changed

+55
-36
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
5757
/// tile factors.
5858
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
5959

60-
/// Return the tile sizes as OpFoldResult.
60+
// TODO: Return the folded result.
61+
/// Return the tile sizes as OpFoldResult. Will return the Value
62+
/// of the constant Op, not the constant Attribute.
63+
/// E.g., for %size = arith.constant 1 : i32 will return %size,
64+
/// not 1.
6165
SmallVector<OpFoldResult> getMixedTiles();
6266

6367
/// Return the tile sizes as `int64_t`. If a tile size is dynamic

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11461146
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11471147
Location loc = packOp.getLoc();
11481148

1149-
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1150-
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1151-
packOp.getDimAndTileMapping();
11521149
int64_t srcRank = packOp.getSourceRank();
11531150
int64_t destRank = packOp.getDestRank();
1154-
int64_t numTiles = destRank - srcRank;
1151+
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1152+
int64_t numberOfTiles = innerDimsPos.size();
11551153

1156-
// 1. Extract the inner tile sizes.
1157-
// Where possible, values are replaced with constant attributes (to match the
1158-
// behaviour of `getPackOpSourceOrPaddedSource`).
1159-
SmallVector<OpFoldResult> tileSizes;
1160-
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1161-
if (dimAndTileMapping.count(i)) {
1162-
// Rather than taking the tile size as is, extact the actual constant
1163-
// value Attribute where possible, e.g.:
1164-
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1165-
auto [_, tileSize] =
1166-
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1167-
tileSizes.push_back(tileSize);
1168-
}
1169-
}
1154+
// 1. Get the input that is going to be packed. If the input requires padding,
1155+
// add a padding operation and return that as the input.
1156+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
11701157

11711158
// 2. Transpose the input to match the inner tile order:
11721159
// %init = tensor.empty()
11731160
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11741161
// outs(%init)
11751162
// Assumptions made:
1176-
// 1. All outer dims are 1 - the corresponding transposition order doesn't
1163+
// - All outer dims are 1 - the corresponding transposition order doesn't
11771164
// matter, but requires all dim indices to be present.
1165+
1166+
// 2.1 Get the permutation for linalg.transpose
11781167
SmallVector<int64_t> srcPermForTranspose;
1179-
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
11801168
for (int64_t i = 0; i < srcRank; i++) {
11811169
// We assume the `k` dimensions of the inner dim position, where `k` is the
11821170
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11851173
// rank of the source tensor. For example if we have a source tensor with
11861174
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
11871175
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1188-
if (llvm::is_contained(innerDimPos, i))
1176+
if (llvm::is_contained(innerDimsPos, i))
11891177
continue;
11901178
srcPermForTranspose.push_back(i);
11911179
}
1192-
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
1180+
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1181+
1182+
// 2.2 Create the init tensor for linalg.transpose with the correct shape
1183+
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
1184+
oneIdxAttr);
1185+
shapeForEmptyOp.append(packOp.getMixedTiles());
1186+
1187+
// getMixedTiles() may contain Values pointing to constant ops, not the
1188+
// constant attributes. Replace them with a true OpFoldResult.
1189+
llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1190+
[&](OpFoldResult ofr) {
1191+
if (auto val = llvm::dyn_cast<Value>(ofr))
1192+
return getAsOpFoldResult(val);
1193+
return ofr;
1194+
});
11931195

11941196
LDBG() << "Pack permutation: " << packOp;
11951197
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1198+
LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
11961199

1197-
// 2.1 Create tensor.empty (init value for TransposeOp)
1198-
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
1199-
oneIdxAttr);
1200-
transShapeForEmptyOp.append(tileSizes);
1201-
1202-
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1203-
srcPermForTranspose);
1204-
Value empty =
1205-
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
1206-
packOp.getSourceType().getElementType());
1200+
Value empty = tensor::EmptyOp::create(
1201+
rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
12071202

1208-
// 2.2 Create linalg.transpose
1203+
// 2.3 Create linalg.transpose
12091204
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
12101205
srcPermForTranspose);
12111206

@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12141209
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12151210
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
12161211
// Outer dims are all 1s!
1217-
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1218-
oneIdxAttr);
1212+
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
12191213
SmallVector<int64_t> writeShape;
12201214

12211215
for (auto tileSize : packOp.getMixedTiles()) {

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,24 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
274274
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
275275
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
276276
// CHECK: return %[[INSERT]]
277+
278+
// -----
279+
280+
// The following example shows a pack operation where the inner dims
281+
// positions are non-adjacent and non-permuted.
282+
func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
283+
%pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
284+
return %pack : tensor<1x1x1x1x8x1xf32>
285+
}
286+
287+
// CHECK-LABEL: func.func @pack_with_non_adjacent_and_non_permuted_inner_dims
288+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
289+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
290+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
291+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
292+
// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
293+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
294+
// CHECK-SAME: permutation = [1, 2, 0, 3]
295+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
296+
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
297+
// CHECK: return %[[INSERT]]

0 commit comments

Comments
 (0)