
Commit cdf49bf

[BACKEND] Allow ranked reduced descriptor load (triton-lang#5880)
This allows us to implement efficient batch matmul with TMAs. Note that this could also be solved by allowing reshapes on the shared descriptor and propagating the layout, but for simplicity we currently handle it with a pattern match.
1 parent ae1a8f1 commit cdf49bf
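To make the change concrete, here is a minimal before/after sketch of the IR this enables, modeled on the lit test added in test/Triton/combine.mlir below (the %desc and %c0 value names are illustrative, not taken from the patch):

// Before: the descriptor load has to produce the descriptor's full block
// shape, and the unit leading dimension is dropped with an explicit reshape.
%l = tt.experimental_descriptor_load %desc[%c0, %c0, %c0] : !tt.tensordesc<tensor<1x128x64xf16>> -> tensor<1x128x64xf16>
%r = tt.reshape %l : tensor<1x128x64xf16> -> tensor<128x64xf16>

// After: the verifier accepts a rank-reduced result type as long as the
// dropped leading dimensions of the block are all 1, and the combine pass
// folds the reshape into the load itself.
%r = tt.experimental_descriptor_load %desc[%c0, %c0, %c0] : !tt.tensordesc<tensor<1x128x64xf16>> -> tensor<128x64xf16>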

File tree

4 files changed, +56 -2 lines changed

lib/Dialect/Triton/IR/Ops.cpp
lib/Dialect/Triton/Transforms/Combine.cpp
test/Triton/combine.mlir
test/Triton/invalid.mlir


lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 13 additions & 1 deletion
@@ -1237,7 +1237,19 @@ static LogicalResult verifyDesciptorLoadStoreType(Operation *op,
                                                   TensorDescType desc,
                                                   RankedTensorType tensor) {
   RankedTensorType block = desc.getBlockType();
-  if (block.getShape() == tensor.getShape() &&
+  ArrayRef<int64_t> blockShape = block.getShape();
+  ArrayRef<int64_t> tensorShape = tensor.getShape();
+  if (blockShape.size() > tensorShape.size()) {
+    // Allow ranked reduced load if the leading dimensions are all 1s.
+    for (int i = 0; i < blockShape.size() - tensorShape.size(); ++i) {
+      if (blockShape[i] != 1)
+        return op->emitOpError(
+            "ranked reduce load only allowed for unit dimension leading dim.");
+    }
+    blockShape = blockShape.take_back(tensorShape.size());
+  }
+
+  if (blockShape == tensorShape &&
       block.getElementType() == tensor.getElementType())
     return success();
   return op->emitOpError("tensor desciptor block and tensor types must match");

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 32 additions & 0 deletions
@@ -210,6 +210,37 @@ class CombineReshapeReducePatterns : public mlir::OpRewritePattern<ReshapeOp> {
   }
 };
 
+class RankedReduceDescriptorLoads : public mlir::OpRewritePattern<ReshapeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(triton::ReshapeOp reshapeOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loadDef = reshapeOp.getSrc()
+                       .getDefiningOp<triton::ExperimentalDescriptorLoadOp>();
+    if (!loadDef || !loadDef->hasOneUse())
+      return failure();
+    int loadRank = loadDef.getType().getRank();
+    int reshapeRank = reshapeOp.getType().getRank();
+    if (!(reshapeRank < loadRank))
+      return failure();
+    ArrayRef<int64_t> loadShape = loadDef.getType().getShape();
+    ArrayRef<int64_t> reshapeShape = reshapeOp.getType().getShape();
+    for (int i = 0; i < loadRank - reshapeRank; ++i) {
+      // Only rank reduce unit dims.
+      if (loadShape[i] != 1)
+        return failure();
+    }
+    if (loadShape.take_back(reshapeRank) != reshapeShape)
+      return failure();
+    rewriter.modifyOpInPlace(
+        loadDef, [&]() { loadDef.getResult().setType(reshapeOp.getType()); });
+    rewriter.replaceOp(reshapeOp, loadDef.getResult());
+    return success();
+  }
+};
+
 class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
 public:
   void runOnOperation() override {
@@ -227,6 +258,7 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
     patterns.add<CombineAddPtrPattern>(context);
     patterns.add<CombineBroadcastMulReducePattern>(context);
     patterns.add<CombineReshapeReducePatterns>(context);
+    patterns.add<RankedReduceDescriptorLoads>(context);
 
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();

test/Triton/combine.mlir

Lines changed: 10 additions & 0 deletions
@@ -380,3 +380,13 @@ tt.func @test_reshape_reduce(%0: tensor<32x4x2xi32>) -> (i32, tensor<16xi32>) {
   %3 = tt.histogram %1 : tensor<256xi32> -> tensor<16xi32>
   tt.return %2, %3 : i32, tensor<16xi32>
 }
+
+// CHECK-LABEL: test_rank_reduce_desc_load
+tt.func @test_rank_reduce_desc_load(%0: !tt.tensordesc<tensor<1x128x64xf16>>) -> (tensor<128x64xf16>) {
+  %c0 = arith.constant 0 : i32
+  // CHECK: %[[R:.+]] = tt.experimental_descriptor_load {{.*}} : !tt.tensordesc<tensor<1x128x64xf16>> -> tensor<128x64xf16>
+  // CHECK: tt.return %[[R]]
+  %l = tt.experimental_descriptor_load %0[%c0, %c0, %c0] : !tt.tensordesc<tensor<1x128x64xf16>> -> tensor<1x128x64xf16>
+  %r = tt.reshape %l : tensor<1x128x64xf16> -> tensor<128x64xf16>
+  tt.return %r : tensor<128x64xf16>
+}

test/Triton/invalid.mlir

Lines changed: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) {
 
 tt.func @invalid_desc_load(%arg0: !tt.tensordesc<tensor<16x16xf32>>) {
   %c = arith.constant 0 : i32
-  // expected-error @below {{tensor desciptor block and tensor types must match}}
+  // expected-error @below {{ranked reduce load only allowed for unit dimension leading dim}}
   tt.experimental_descriptor_load %arg0[%c, %c] : !tt.tensordesc<tensor<16x16xf32>> -> tensor<16xf32>
   tt.return
 }
