Commit c539ec0

[mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (llvm#145421)
This PR adds initial support for `vector.extract_strided_slice` and `vector.insert_strided_slice` ops in vector distribution.
1 parent 2b2bd51 commit c539ec0
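
In short, the new patterns sink these slice ops out of `gpu.warp_execute_on_lane_0` regions. A minimal sketch of the extract case (types are illustrative, assuming a warp of 32 lanes distributing the innermost dimension; `"some_def"` stands for an arbitrary producer):

```
// Before: the slice is computed inside the warp region, so the sliced
// sequential result is what gets yielded and distributed.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  %e = vector.extract_strided_slice %src { offsets = [0], sizes = [16], strides = [1] }
    : vector<64x32xf32> to vector<16x32xf32>
  gpu.yield %e : vector<16x32xf32>
}

// After: the warp op yields the full source and each lane slices its own
// fragment; no cross-lane communication is needed because the distributed
// (innermost) dimension is untouched by the slice.
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  gpu.yield %src : vector<64x32xf32>
}
%r = vector.extract_strided_slice %w { offsets = [0], sizes = [16], strides = [1] }
  : vector<64x1xf32> to vector<16x1xf32>
```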

2 files changed: +295 −10 lines


mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 215 additions & 10 deletions
@@ -15,9 +15,12 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
 
@@ -52,6 +55,25 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }
 
+/// Given a sequential and distributed vector type, returns the distributed
+/// dimension. This function expects that only a single dimension is
+/// distributed.
+static int getDistributedDim(VectorType sequentialType,
+                             VectorType distributedType) {
+  assert(sequentialType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  int64_t distributedDim = -1;
+  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
+      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+      // support distributing multiple dimensions in the future.
+      assert(distributedDim == -1 && "found multiple distributed dims");
+      distributedDim = i;
+    }
+  }
+  return distributedDim;
+}
+
 namespace {
 
 /// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1098,196 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
+///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+///     strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
+///   gpu.yield %insert : vector<8x32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
+///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both src and dest vectors are
+/// distributed to lanes and sinking the insert op does not require any
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for the case where the source vector is not
+    // distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "distributed dimension must be in the last k dims of dest vector");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    VectorType newDestTy = distributedType;
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert strided slice op that inserts the distributed
+    // source into the distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<64x32xf32>
+///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
+///     strides = [1] : vector<64x32xf32> to vector<16x32xf32>
+///   gpu.yield %extract : vector<16x32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
+///   ...
+///   %src = ... : vector<64x32xf32>
+///   gpu.yield %src : vector<64x32xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
+///   strides = [1] : vector<64x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (i.e. it does not require cross-lane
+/// communication).
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp =
+        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+
+    int64_t numOfExtractedDims =
+        static_cast<int64_t>(extractOp.getSizes().size());
+    // If the distributed dim is included in the extracted dims, make sure it
+    // is fully extracted. If the distributed dim is not included in the
+    // extracted dims, it is guaranteed to be fully extracted (i.e. the
+    // distributed dim comes after all the extracted dims).
+    // TODO: Partial extraction from the distributed dimension requires
+    // cross-lane communication.
+    if (distributedDim < numOfExtractedDims) {
+      int64_t distributedDimOffset =
+          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+              .getInt();
+      int64_t distributedDimSize =
+          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+              .getInt();
+      if (distributedDimOffset != 0 ||
+          distributedDimSize != yieldedType.getDimSize(distributedDim))
+        return rewriter.notifyMatchFailure(
+            extractOp, "distributed dimension must be fully extracted");
+    }
+    SmallVector<int64_t> newDistributedShape(
+        extractOp.getSourceVectorType().getShape());
+    newDistributedShape[distributedDim] =
+        distributedType.getDimSize(distributedDim);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+
+    // Create a new extract strided slice op that extracts from the
+    // distributed vector.
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), distributedType, distributedVec,
+        extractOp.getOffsets(),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public WarpDistributionPattern {
@@ -1122,15 +1334,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     auto yieldedType = cast<VectorType>(operand->get().getType());
-    int64_t distributedDim = -1;
-    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
-      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
-        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
-        // support distributing multiple dimensions in the future.
-        assert(distributedDim == -1 && "found multiple distributed dims");
-        distributedDim = i;
-      }
-    }
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
     (void)distributedDim;
 
@@ -1764,7 +1968,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
                WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
                WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
       patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
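
For contrast with the doc-comment examples above, here is a hypothetical input (not part of this patch) that `WarpOpExtractStridedSlice` deliberately rejects. Dim 0 is distributed (32 yielded rows map to 1 row per lane across 32 lanes), but the extract starts at offset 16, so each lane would need rows owned by other lanes; the pattern fails to match with "distributed dimension must be fully extracted" and the op stays inside the warp region:

```
// Sketch only: the offsets = [16] slice straddles lane ownership of the
// distributed dim 0, so no sinking happens.
func.func @partial_extract_of_distributed_dim(%laneid: index) -> (vector<1x32xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x32xf32>) {
    %0 = "some_def"() : () -> (vector<64x32xf32>)
    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1] }
      : vector<64x32xf32> to vector<32x32xf32>
    gpu.yield %1 : vector<32x32xf32>
  }
  return %r : vector<1x32xf32>
}
```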

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 80 additions & 0 deletions
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
   return %r : vector<4x96xf32>
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
+// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+// CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
+// CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
+func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<24x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
+    %0 = "some_def"() : () -> (vector<64x32xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
+      : vector<64x32xf32> to vector<24x32xf32>
+    gpu.yield %1 : vector<24x32xf32>
+  }
+  return %r : vector<24x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
+// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+// CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
+// CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
+func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<1x8xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
+      : vector<32x64xf32> to vector<32x8xf32>
+    gpu.yield %1 : vector<32x8xf32>
+  }
+  return %r : vector<1x8xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
+// CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
+      : vector<32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
1359+
// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
1360+
// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
1361+
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
1362+
// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
1363+
// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
1364+
// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
1365+
// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
1366+
// CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
1367+
// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
1368+
func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
1369+
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
1370+
%0 = "some_def"() : () -> (vector<16x32xf32>)
1371+
%1 = "some_def"() : () -> (vector<64x32xf32>)
1372+
%2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0], strides = [1, 1]}
1373+
: vector<16x32xf32> into vector<64x32xf32>
1374+
gpu.yield %2 : vector<64x32xf32>
1375+
}
1376+
return %r : vector<64x1xf32>
1377+
}
1378+
 // -----
 
 // Make sure that all operands of the transfer_read op are properly propagated.

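The new tests run under the `CHECK-PROP` prefix, which the test file's existing RUN line drives through the propagation patterns; a sketch of that invocation (flag spelling taken from the file's header, so verify it against your checkout):

```
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
```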