Skip to content

Commit df49a97

Browse files
authored
[mlir][vector] Root the transfer write distribution pattern on the warp op (llvm#71868)
Currently, when a mix of transfer read ops and transfer write ops needs to be distributed, it is hard to guarantee that a write gets distributed after a read when the two aren't directly connected by SSA, because the pattern for write distribution is rooted on the transfer write. This change roots the write distribution pattern on the warp op instead. This is likely still relatively unsafe when there are undistributable ops, but structurally these patterns are a bit difficult to work with. For now, pattern benefits give fairly good guarantees for happy paths.
1 parent d96ea27 commit df49a97

File tree

4 files changed

+63
-21
lines changed

4 files changed

+63
-21
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,15 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
5959
/// vector.yield %v : vector<32xf32>
6060
/// }
6161
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
62+
///
63+
/// When applied at the same time as the vector propagation patterns,
64+
/// distribution of `vector.transfer_write` is expected to have the highest
65+
/// priority (pattern benefit). Giving propagation of `vector.transfer_read`
66+
/// the lowest priority makes it the last vector operation to
67+
/// distribute, meaning writes should propagate first.
6268
void populateDistributeTransferWriteOpPatterns(
6369
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
64-
PatternBenefit benefit = 1);
70+
PatternBenefit benefit = 2);
6571

6672
/// Move scalar operations with no dependency on the warp op outside of the
6773
/// region.
@@ -75,10 +81,19 @@ using WarpShuffleFromIdxFn =
7581
/// Collect patterns to propagate warp distribution. `distributionMapFn` is used
7682
/// to decide how a value should be distributed when this cannot be inferred
7783
/// from its uses.
84+
///
85+
/// The separate control over the `vector.transfer_read` op pattern benefit
86+
/// is given to ensure the order of reads/writes before and after distribution
87+
/// is consistent. As noted above, writes are expected to have the highest
88+
/// priority for distribution, but are only ever distributed if adjacent to the
89+
/// yield. Making reads the lowest-priority pattern ensures they are the last
90+
/// vector operations to distribute, meaning writes should propagate first. This
91+
/// is relatively brittle when ops fail to distribute, but that is a limitation
92+
/// of these propagation patterns when there is a dependency not modeled by SSA.
7893
void populatePropagateWarpVectorDistributionPatterns(
7994
RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
8095
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
81-
PatternBenefit benefit = 1);
96+
PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
8297

8398
/// Lambda signature to compute a reduction of a distributed value for the given
8499
/// reduction kind and size.

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -474,10 +474,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
474474
/// vector.yield %v : vector<32xf32>
475475
/// }
476476
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
477-
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
477+
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
478478
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
479479
PatternBenefit b = 1)
480-
: OpRewritePattern<vector::TransferWriteOp>(ctx, b),
480+
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
481481
distributionMapFn(std::move(fn)) {}
482482

483483
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
@@ -584,18 +584,15 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
584584
return success();
585585
}
586586

587-
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
587+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
588588
PatternRewriter &rewriter) const override {
589-
auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
590-
if (!warpOp)
589+
auto yield = cast<vector::YieldOp>(
590+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
591+
Operation *lastNode = yield->getPrevNode();
592+
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
593+
if (!writeOp)
591594
return failure();
592595

593-
// There must be no op with a side effect after writeOp.
594-
Operation *nextOp = writeOp.getOperation();
595-
while ((nextOp = nextOp->getNextNode()))
596-
if (!isMemoryEffectFree(nextOp))
597-
return failure();
598-
599596
Value maybeMask = writeOp.getMask();
600597
if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
601598
return writeOp.getVector() == value ||
@@ -1731,11 +1728,13 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
17311728

17321729
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
17331730
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1734-
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
1735-
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
1736-
WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
1737-
WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
1738-
WarpOpInsert>(patterns.getContext(), benefit);
1731+
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1732+
PatternBenefit readBenefit) {
1733+
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1734+
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1735+
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1736+
WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
1737+
patterns.getContext(), benefit);
17391738
patterns.add<WarpOpExtractElement>(patterns.getContext(),
17401739
warpShuffleFromIdxFn, benefit);
17411740
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,24 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
13481348
// CHECK-PROP: vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
13491349
// CHECK-PROP: %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
13501350
// CHECK-PROP: vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>
1351+
1352+
// -----
1353+
1354+
func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref<128xf32>, %f1: f32) -> (vector<2xf32>, vector<4xf32>) {
1355+
%f0 = arith.constant 0.000000e+00 : f32
1356+
%c0 = arith.constant 0 : index
1357+
%r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, vector<4xf32>) {
1358+
%cst = arith.constant dense<2.0> : vector<128xf32>
1359+
%0 = vector.transfer_read %buffer[%c0], %f0 {in_bounds = [true]} : memref<128xf32>, vector<128xf32>
1360+
vector.transfer_write %cst, %buffer[%c0] : vector<128xf32>, memref<128xf32>
1361+
%1 = vector.broadcast %f1 : f32 to vector<64xf32>
1362+
vector.yield %1, %0 : vector<64xf32>, vector<128xf32>
1363+
}
1364+
return %r#0, %r#1 : vector<2xf32>, vector<4xf32>
1365+
}
1366+
1367+
// Verify that the write comes after the read
1368+
// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_unconnected_read_write(
1369+
// CHECK-DIST-AND-PROP: %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
1370+
// CHECK-DIST-AND-PROP: vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
1371+
// CHECK-DIST-AND-PROP: vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,12 +594,19 @@ struct TestVectorDistribution
594594
.getResult(0);
595595
return result;
596596
};
597-
if (distributeTransferWriteOps) {
597+
if (distributeTransferWriteOps && propagateDistribution) {
598+
RewritePatternSet patterns(ctx);
599+
vector::populatePropagateWarpVectorDistributionPatterns(
600+
patterns, distributionFn, shuffleFn, /*benefit=*/1,
601+
/*readBenefit=*/0);
602+
vector::populateDistributeReduction(patterns, warpReduction, 1);
603+
populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
604+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
605+
} else if (distributeTransferWriteOps) {
598606
RewritePatternSet patterns(ctx);
599607
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
600608
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
601-
}
602-
if (propagateDistribution) {
609+
} else if (propagateDistribution) {
603610
RewritePatternSet patterns(ctx);
604611
vector::populatePropagateWarpVectorDistributionPatterns(
605612
patterns, distributionFn, shuffleFn);

0 commit comments

Comments
 (0)