Skip to content

Commit d4d2891

Browse files
authored
[mlir][vector] Add distribution pattern for vector.create_mask (llvm#71619)
This is the last step needed for basic support for distributing masked vector code. The lane id gets delinearized based on the distributed mask shape and then compared against the original mask sizes to compute the bounds for the distributed mask. Note that the distribution of masks is implicit on the shape specified by the warp op. As a result, it is the responsibility of the consumer of the mask to ensure the distributed mask will match its own distribution semantics.
1 parent 22c6851 commit d4d2891

File tree

2 files changed

+116
-4
lines changed

2 files changed

+116
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,83 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
10781078
}
10791079
};
10801080

1081+
/// Sink out vector.create_mask op feeding into a warp op yield.
1082+
/// ```
1083+
/// %0 = ...
1084+
/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1085+
/// ...
1086+
/// %mask = vector.create_mask %0 : vector<32xi1>
1087+
/// vector.yield %mask : vector<32xi1>
1088+
/// }
1089+
/// ```
1090+
/// To
1091+
/// ```
1092+
/// %0 = ...
1093+
/// vector.warp_execute_on_lane_0(%arg0) {
1094+
/// ...
1095+
/// }
1096+
/// %cmp = arith.cmpi ult, %laneid, %0
1097+
/// %ub = arith.select %cmp, %c0, %c1
1098+
/// %1 = vector.create_mask %ub : vector<1xi1>
1099+
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
1100+
using OpRewritePattern::OpRewritePattern;
1101+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1102+
PatternRewriter &rewriter) const override {
1103+
OpOperand *yieldOperand = getWarpResult(
1104+
warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
1105+
if (!yieldOperand)
1106+
return failure();
1107+
1108+
auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1109+
1110+
// Early exit if any values needed for calculating the new mask indices
1111+
// are defined inside the warp op.
1112+
if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1113+
return warpOp.isDefinedOutsideOfRegion(value);
1114+
}))
1115+
return failure();
1116+
1117+
Location loc = mask.getLoc();
1118+
unsigned operandIndex = yieldOperand->getOperandNumber();
1119+
1120+
auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1121+
VectorType seqType = mask.getVectorType();
1122+
ArrayRef<int64_t> seqShape = seqType.getShape();
1123+
ArrayRef<int64_t> distShape = distType.getShape();
1124+
1125+
rewriter.setInsertionPointAfter(warpOp);
1126+
1127+
// Delinearize the lane ID for constructing the distributed mask sizes.
1128+
SmallVector<Value> delinearizedIds;
1129+
if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1130+
warpOp.getWarpSize(), warpOp.getLaneid(),
1131+
delinearizedIds))
1132+
return rewriter.notifyMatchFailure(
1133+
mask, "cannot delinearize lane ID for distribution");
1134+
assert(!delinearizedIds.empty());
1135+
1136+
AffineExpr s0, s1;
1137+
bindSymbols(rewriter.getContext(), s0, s1);
1138+
SmallVector<Value> newOperands;
1139+
for (int i = 0, e = distShape.size(); i < e; ++i) {
1140+
// Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1141+
// find the distance from the largest mask index owned by this lane to the
1142+
// original mask size. `vector.create_mask` implicitly clamps mask
1143+
// operands to the range [0, mask_vector_size[i]], or in other words, the
1144+
// mask sizes are always in the range [0, mask_vector_size[i]).
1145+
Value maskDimIdx = affine::makeComposedAffineApply(
1146+
rewriter, loc, s1 - s0 * distShape[i],
1147+
{delinearizedIds[i], mask.getOperand(i)});
1148+
newOperands.push_back(maskDimIdx);
1149+
}
1150+
1151+
auto newMask =
1152+
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1153+
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1154+
return success();
1155+
}
1156+
};
1157+
10811158
/// Pattern to move out vector.extract of single element vector. Those don't
10821159
/// need to be distributed and can just be propagated outside of the region.
10831160
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1731,10 +1808,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
17311808
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
17321809
PatternBenefit readBenefit) {
17331810
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1734-
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1735-
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1736-
WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
1737-
patterns.getContext(), benefit);
1811+
patterns
1812+
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1813+
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1814+
WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1815+
patterns.getContext(), benefit);
17381816
patterns.add<WarpOpExtractElement>(patterns.getContext(),
17391817
warpShuffleFromIdxFn, benefit);
17401818
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,3 +1369,37 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
13691369
// CHECK-DIST-AND-PROP: %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
13701370
// CHECK-DIST-AND-PROP: vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
13711371
// CHECK-DIST-AND-PROP: vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
1372+
1373+
// -----
1374+
1375+
// 1-D case: a vector<32xi1> mask bounded by %m0 is yielded from a warp op
// declared with [32] whose result type is vector<1xi1>. The CHECK-PROP lines
// that follow expect an affine.apply of (-s0 + s1) on (%laneid, %m0) feeding
// a vector<1xi1> create_mask.
func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
1376+
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
1377+
%1 = vector.create_mask %m0 : vector<32xi1>
1378+
vector.yield %1 : vector<32xi1>
1379+
}
1380+
return %r : vector<1xi1>
1381+
}
1382+
1383+
// CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
1384+
// CHECK-PROP-LABEL: func @warp_propagate_create_mask
1385+
// CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
1386+
// CHECK-PROP: %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
1387+
// CHECK-PROP: vector.create_mask %[[MDIST]] : vector<1xi1>
1388+
1389+
// -----
1390+
1391+
// 3-D case: a vector<16x4x4xi1> mask is yielded from a warp op declared with
// [32] whose result type is vector<1x2x4xi1>, so only the first two
// dimensions change shape. The CHECK-PROP lines that follow expect
// affine.apply ops for %m0 and %m1, while %m2 is passed to the new
// create_mask unchanged.
func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
1392+
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
1393+
%1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
1394+
vector.yield %1 : vector<16x4x4xi1>
1395+
}
1396+
return %r : vector<1x2x4xi1>
1397+
}
1398+
1399+
// CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
1400+
// CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
1401+
// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
1402+
// CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
1403+
// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
1404+
// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
1405+
// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>

0 commit comments

Comments
 (0)