@@ -1078,6 +1078,83 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1078
1078
}
1079
1079
};
1080
1080
1081
+ // / Sink out vector.create_mask op feeding into a warp op yield.
1082
+ // / ```
1083
+ // / %0 = ...
1084
+ // / %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1085
+ // / ...
1086
+ // / %mask = vector.create_mask %0 : vector<32xi1>
1087
+ // / vector.yield %mask : vector<32xi1>
1088
+ // / }
1089
+ // / ```
1090
+ // / To
1091
+ // / ```
1092
+ // / %0 = ...
1093
+ // / vector.warp_execute_on_lane_0(%arg0) {
1094
+ // / ...
1095
+ // / }
1096
+ // / %cmp = arith.cmpi ult, %laneid, %0
1097
+ // / %ub = arith.select %cmp, %c0, %c1
1098
+ // / %1 = vector.create_mask %ub : vector<1xi1>
1099
+ struct WarpOpCreateMask : public OpRewritePattern <WarpExecuteOnLane0Op> {
1100
+ using OpRewritePattern::OpRewritePattern;
1101
+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1102
+ PatternRewriter &rewriter) const override {
1103
+ OpOperand *yieldOperand = getWarpResult (
1104
+ warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
1105
+ if (!yieldOperand)
1106
+ return failure ();
1107
+
1108
+ auto mask = yieldOperand->get ().getDefiningOp <vector::CreateMaskOp>();
1109
+
1110
+ // Early exit if any values needed for calculating the new mask indices
1111
+ // are defined inside the warp op.
1112
+ if (!llvm::all_of (mask->getOperands (), [&](Value value) {
1113
+ return warpOp.isDefinedOutsideOfRegion (value);
1114
+ }))
1115
+ return failure ();
1116
+
1117
+ Location loc = mask.getLoc ();
1118
+ unsigned operandIndex = yieldOperand->getOperandNumber ();
1119
+
1120
+ auto distType = cast<VectorType>(warpOp.getResult (operandIndex).getType ());
1121
+ VectorType seqType = mask.getVectorType ();
1122
+ ArrayRef<int64_t > seqShape = seqType.getShape ();
1123
+ ArrayRef<int64_t > distShape = distType.getShape ();
1124
+
1125
+ rewriter.setInsertionPointAfter (warpOp);
1126
+
1127
+ // Delinearize the lane ID for constructing the distributed mask sizes.
1128
+ SmallVector<Value> delinearizedIds;
1129
+ if (!delinearizeLaneId (rewriter, loc, seqShape, distShape,
1130
+ warpOp.getWarpSize (), warpOp.getLaneid (),
1131
+ delinearizedIds))
1132
+ return rewriter.notifyMatchFailure (
1133
+ mask, " cannot delinearize lane ID for distribution" );
1134
+ assert (!delinearizedIds.empty ());
1135
+
1136
+ AffineExpr s0, s1;
1137
+ bindSymbols (rewriter.getContext (), s0, s1);
1138
+ SmallVector<Value> newOperands;
1139
+ for (int i = 0 , e = distShape.size (); i < e; ++i) {
1140
+ // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1141
+ // find the distance from the largest mask index owned by this lane to the
1142
+ // original mask size. `vector.create_mask` implicitly clamps mask
1143
+ // operands to the range [0, mask_vector_size[i]], or in other words, the
1144
+ // mask sizes are always in the range [0, mask_vector_size[i]).
1145
+ Value maskDimIdx = affine::makeComposedAffineApply (
1146
+ rewriter, loc, s1 - s0 * distShape[i],
1147
+ {delinearizedIds[i], mask.getOperand (i)});
1148
+ newOperands.push_back (maskDimIdx);
1149
+ }
1150
+
1151
+ auto newMask =
1152
+ rewriter.create <vector::CreateMaskOp>(loc, distType, newOperands);
1153
+ rewriter.replaceAllUsesWith (warpOp.getResult (operandIndex), newMask);
1154
+ return success ();
1155
+ }
1156
+ };
1157
+
1081
1158
// / Pattern to move out vector.extract of single element vector. Those don't
1082
1159
// / need to be distributed and can just be propagated outside of the region.
1083
1160
struct WarpOpExtract : public OpRewritePattern <WarpExecuteOnLane0Op> {
@@ -1731,10 +1808,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1731
1808
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1732
1809
PatternBenefit readBenefit) {
1733
1810
patterns.add <WarpOpTransferRead>(patterns.getContext (), readBenefit);
1734
- patterns.add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1735
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1736
- WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
1737
- patterns.getContext (), benefit);
1811
+ patterns
1812
+ .add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1813
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1814
+ WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1815
+ patterns.getContext (), benefit);
1738
1816
patterns.add <WarpOpExtractElement>(patterns.getContext (),
1739
1817
warpShuffleFromIdxFn, benefit);
1740
1818
patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
0 commit comments