|
15 | 15 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
|
16 | 16 | #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
|
17 | 17 | #include "mlir/IR/AffineExpr.h"
|
| 18 | +#include "mlir/IR/Attributes.h" |
| 19 | +#include "mlir/IR/BuiltinTypes.h" |
18 | 20 | #include "mlir/Interfaces/SideEffectInterfaces.h"
|
19 | 21 | #include "mlir/Transforms/RegionUtils.h"
|
20 | 22 | #include "llvm/ADT/SetVector.h"
|
| 23 | +#include "llvm/ADT/SmallVectorExtras.h" |
21 | 24 | #include "llvm/Support/FormatVariadic.h"
|
22 | 25 | #include <utility>
|
23 | 26 |
|
@@ -52,6 +55,25 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
|
52 | 55 | return map;
|
53 | 56 | }
|
54 | 57 |
|
| 58 | +/// Given a sequential and distributed vector type, returns the distributed |
| 59 | +/// dimension. This function expects that only a single dimension is |
| 60 | +/// distributed. |
| 61 | +static int getDistributedDim(VectorType sequentialType, |
| 62 | + VectorType distributedType) { |
| 63 | + assert(sequentialType.getRank() == distributedType.getRank() && |
| 64 | + "sequential and distributed vector types must have the same rank"); |
| 65 | + int64_t distributedDim = -1; |
| 66 | + for (int64_t i = 0; i < sequentialType.getRank(); ++i) { |
| 67 | + if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) { |
| 68 | + // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
| 69 | + // support distributing multiple dimensions in the future. |
| 70 | + assert(distributedDim == -1 && "found multiple distributed dims"); |
| 71 | + distributedDim = i; |
| 72 | + } |
| 73 | + } |
| 74 | + return distributedDim; |
| 75 | +} |
| 76 | + |
55 | 77 | namespace {
|
56 | 78 |
|
57 | 79 | /// Helper struct to create the load / store operations that permit transit
|
@@ -1076,6 +1098,196 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
|
1076 | 1098 | }
|
1077 | 1099 | };
|
1078 | 1100 |
|
/// Sink out insert_strided_slice op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
///   ...
///   %src = ... : vector<4x32xf32>
///   %dest = ... : vector<8x32xf32>
///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
///     strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
///   gpu.yield %insert : vector<8x32xf32>
/// }
/// ```
/// To
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
/// vector<8x1xf32>) {
///   ...
///   %src = ... : vector<4x32xf32>
///   %dest = ... : vector<8x32xf32>
///   gpu.yield %src, %dest : vector<4x1xf32>, vector<8x1xf32>
/// }
/// %insert = vector.insert_strided_slice %0#0, %0#1,
///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
/// ```
/// NOTE: Current support assumes that both src and dest vectors are distributed
/// to lanes and sinking the insert op does not require any cross lane
/// communication.
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp-op result whose yielded value is produced by an
    // insert_strided_slice; bail if there is none.
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");
    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // Currently we require that both source (kD) and dest (nD) vectors are
    // distributed. This requires that distributedDim (d) is contained in the
    // last k dims of the dest vector (d >= n - k).
    // TODO: Add support for case where source vector is not distributed.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");
    // Distributed dimension must be fully inserted: a partial insert along the
    // distributed dim would need data from other lanes (cross-lane
    // communication), which this pattern does not support.
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    // Shrink the source's distributed dim to the per-lane size taken from the
    // distributed dest type; all other dims keep their sequential sizes.
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    // Rebuild the warp op so that it additionally yields the insert's source
    // and dest operands with their distributed types; newRetIndices maps to
    // the appended results.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Create a new insert strided slice op that inserts distributed source into
    // distributed dest. Offsets and strides carry over unchanged: the
    // distributed dim is fully inserted (offset 0 along it is implied by the
    // full-insertion check above), and the other dims are not distributed.
    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
        insertOp.getLoc(), distributedDest.getType(), distributedSource,
        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
    return success();
  }
};
| 1192 | + |
/// Sink out extract_strided_slice op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
///   ...
///   %src = ... : vector<64x32xf32>
///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
///     strides = [1] : vector<64x32xf32> to vector<16x32xf32>
///   gpu.yield %extract : vector<16x32xf32>
/// }
/// ```
/// To
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
///   ...
///   %src = ... : vector<64x32xf32>
///   gpu.yield %src : vector<64x32xf32>
/// }
/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
///   strides = [1] : vector<64x1xf32> to vector<16x1xf32>
/// ```
/// NOTE: Current support assumes that the extraction happens only on non
/// distributed dimensions (does not require cross lane communication).
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp-op result whose yielded value is produced by an
    // extract_strided_slice; bail if there is none.
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    // Find the distributed dimension. There should be exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    // extract_strided_slice may specify offsets/sizes/strides for only a
    // leading prefix of the dims; trailing dims are implicitly taken whole.
    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is included in the extracted dims, then we make
    // sure distributed dim is fully extracted. If distributed dim is not
    // included in extracted dims, it is guaranteed to be fully extracted (i.e.
    // distributed dim comes after all the extracted dims)
    // TODO: Partial extraction from distributed dimension require cross lane
    // communication.
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      // Full extraction means offset 0 and size equal to the (sequential)
      // yielded dim size; anything else would need data owned by other lanes.
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    // New yielded type: the sequential source shape with the distributed dim
    // shrunk to its per-lane size.
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    // Rebuild the warp op so that it additionally yields the extract's source
    // vector with the distributed type; newRetIndices maps to the appended
    // result.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    // Update the distributed sizes to match the distributed type. Only needed
    // when the sizes array explicitly covers the distributed dim; by the check
    // above that size was the full sequential extent, so it becomes the
    // per-lane extent. Offsets/strides stay valid as-is (offset along the
    // distributed dim is known to be 0).
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));

    // Create a new extract strided slice op that extracts from the
    // distributed vector.
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
| 1290 | + |
1079 | 1291 | /// Pattern to move out vector.extract of single element vector. Those don't
|
1080 | 1292 | /// need to be distributed and can just be propagated outside of the region.
|
1081 | 1293 | struct WarpOpExtract : public WarpDistributionPattern {
|
@@ -1122,15 +1334,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
|
1122 | 1334 | auto distributedType =
|
1123 | 1335 | cast<VectorType>(warpOp.getResult(operandNumber).getType());
|
1124 | 1336 | auto yieldedType = cast<VectorType>(operand->get().getType());
|
1125 |
| - int64_t distributedDim = -1; |
1126 |
| - for (int64_t i = 0; i < yieldedType.getRank(); ++i) { |
1127 |
| - if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { |
1128 |
| - // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
1129 |
| - // support distributing multiple dimensions in the future. |
1130 |
| - assert(distributedDim == -1 && "found multiple distributed dims"); |
1131 |
| - distributedDim = i; |
1132 |
| - } |
1133 |
| - } |
| 1337 | + int64_t distributedDim = getDistributedDim(yieldedType, distributedType); |
1134 | 1338 | assert(distributedDim != -1 && "could not find distributed dimension");
|
1135 | 1339 | (void)distributedDim;
|
1136 | 1340 |
|
@@ -1764,7 +1968,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
|
1764 | 1968 | patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
|
1765 | 1969 | WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
|
1766 | 1970 | WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
|
1767 |
| - WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>( |
| 1971 | + WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask, |
| 1972 | + WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>( |
1768 | 1973 | patterns.getContext(), benefit);
|
1769 | 1974 | patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
|
1770 | 1975 | benefit);
|
|
0 commit comments