
Commit ab34c3a

[AMD] Support multi-cta and multicast for TDM operations (#8790)
Adds support for multi-CTA TDM loads and stores and sets the multicast mask based on the `CGALayout`. As with `tt.load` and `ttg.async_copy_global_to_local`, multicast is automatically enabled if the `CGALayout` contains broadcasting bases.
1 parent 9c2cefd commit ab34c3a
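
The mask derivation described above can be illustrated with a short standalone sketch (plain C++; `ctaMulticastMask` and `broadcastBits` are illustrative names, not the actual `LLVM::AMD::emitCtaMulticastMask` signature, which operates on MLIR values). Zero bases in the `CGALayout` mark CTA-id bits that do not change which tile a CTA reads, so all CTAs that differ only in those bits can share one multicast load:

#include <cstdint>
#include <vector>

// Hypothetical sketch of the multicast-mask arithmetic. `broadcastBits`
// lists the CTA-id bit positions that correspond to zero ("broadcasting")
// bases in the CGALayout; these bits do not affect which tile a CTA loads.
uint32_t ctaMulticastMask(uint32_t ctaId,
                          const std::vector<unsigned> &broadcastBits) {
  // Enumerate all CTA ids reachable from id 0 by toggling broadcast bits.
  uint32_t groupMask = 0;
  for (uint32_t combo = 0; combo < (1u << broadcastBits.size()); ++combo) {
    uint32_t id = 0;
    for (size_t i = 0; i < broadcastBits.size(); ++i)
      if (combo & (1u << i))
        id |= 1u << broadcastBits[i];
    groupMask |= 1u << id;
  }
  // Clear the broadcast bits of this CTA's id to get the group's base CTA,
  // then shift the group mask into place.
  uint32_t freeBits = 0;
  for (unsigned bit : broadcastBits)
    freeBits |= 1u << bit;
  return groupMask << (ctaId & ~freeBits);
}

The lit test added in this commit exercises exactly this arithmetic (see the constants checked in `tdm_load_multicast` below).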

File tree

7 files changed (+254 -33 lines)

lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@ struct MakeRangeOpConversion
     auto elemTy = ty.getElementType();
     assert(elemTy.isInteger(32));
     Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart());
-    auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true);
+    auto numCTAs = triton::gpu::getNumCTAs(layout);
+    auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, numCTAs > 1);
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
     // TODO: slice layout has more elements than expected.

test/Conversion/amd/tritongpu_tdm_to_llvm.mlir

Lines changed: 97 additions & 0 deletions
@@ -49,3 +49,100 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+// Check that CTA offsets are computed and applied to the base pointer for multi-cta layouts
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 0]]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: tdm_load_multi_cta
+  tt.func public @tdm_load_multi_cta(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c_shape = arith.constant 128 : i32
+    %c_stride0 = arith.constant 128 : i64
+    %c_stride1 = arith.constant 1 : i64
+    %c_offset = arith.constant 0 : i32
+    %c_pred = arith.constant true
+
+    // CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(128 : i64) : i64
+    // CHECK-DAG: %[[STRIDE1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-DAG: llvm.call_intrinsic "llvm.amdgcn.cluster.workgroup.id.x"
+    // CHECK-DAG: %[[STRIDE0_TRUNC:.*]] = llvm.trunc %[[STRIDE0]] : i64 to i32
+    // CHECK: %[[OFFSET_DIM0:.*]] = llvm.mul{{.*}}%[[STRIDE0_TRUNC]]
+    // CHECK: %[[OFFSET_TMP1:.*]] = llvm.add{{.*}}%[[OFFSET_DIM0]]
+    // CHECK: %[[OFFSET_DIM1:.*]] = llvm.mul{{.*}}%[[STRIDE1]]
+    // CHECK: %[[TOTAL_OFFSET:.*]] = llvm.add %[[OFFSET_TMP1]], %[[OFFSET_DIM1]]
+    // CHECK: %[[ADJUSTED_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[TOTAL_OFFSET]]]
+    %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+
+    // CHECK: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
+    %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16, #shared>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+// Check that CTA offsets are computed and applied to the base pointer for multi-cta layouts (store)
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1]]}>
+#blocked_store = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: tdm_store_multi_cta
+  tt.func public @tdm_store_multi_cta(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c_shape = arith.constant 128 : i32
+    %c_stride0 = arith.constant 128 : i64
+    %c_stride1 = arith.constant 1 : i64
+    %c_offset = arith.constant 0 : i32
+
+    // CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(128 : i64) : i64
+    // CHECK-DAG: %[[STRIDE1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-DAG: llvm.call_intrinsic "llvm.amdgcn.cluster.workgroup.id.x"
+    // CHECK-DAG: %[[STRIDE0_TRUNC:.*]] = llvm.trunc %[[STRIDE0]] : i64 to i32
+    // CHECK: %[[OFFSET_DIM0:.*]] = llvm.mul{{.*}}%[[STRIDE0_TRUNC]]
+    // CHECK: %[[OFFSET_TMP1:.*]] = llvm.add{{.*}}%[[OFFSET_DIM0]]
+    // CHECK: %[[OFFSET_DIM1:.*]] = llvm.mul{{.*}}%[[STRIDE1]]
+    // CHECK: %[[TOTAL_OFFSET:.*]] = llvm.add %[[OFFSET_TMP1]], %[[OFFSET_DIM1]]
+    // CHECK: %[[ADJUSTED_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[TOTAL_OFFSET]]]
+    %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    // CHECK: llvm.amdgcn.tensor.store.from.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
+    amdg.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<64x64xf16, #shared>>
+    tt.return
+  }
+}
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1], [0, 2], [0, 0], [0, 0]]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: tdm_load_multicast
+  tt.func public @tdm_load_multicast(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c_shape = arith.constant 128 : i32
+    %c_stride0 = arith.constant 128 : i64
+    %c_stride1 = arith.constant 1 : i64
+    %c_offset = arith.constant 0 : i32
+    %c_pred = arith.constant true
+
+    // Check we compute the multicast mask and use it in the second group of SGPRs (vector<8xi32>)
+    // CHECK-DAG: %[[GROUP_MASK:.*]] = llvm.mlir.constant(4369 : i32) : i32
+    // CHECK-DAG: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-13 : i32) : i32
+    // CHECK-DAG: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
+    // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
+    // CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
+    // Combine with other values
+    // CHECK: %[[TMP:.*]] = llvm.or %{{.*}}, %[[CTA_MASK]]
+    // CHECK: %[[TMP2:.*]] = llvm.and %[[TMP]]
+    // CHECK-NOT: llvm.insertelement{{.*}} : vector<8xi32>
+    // CHECK: llvm.insertelement %[[TMP2]]
+    %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+
+    // CHECK: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
+    %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16, #shared>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
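
For reference, the constants in the multicast test follow directly from the layout: CGALayout = [[0, 1], [0, 2], [0, 0], [0, 0]] has its two zero (broadcasting) bases at CTA-id bits 2 and 3, so toggling those bits enumerates the group {0, 4, 8, 12}, i.e. the group mask 0b0001000100010001 = 4369. And-ing the CTA id with ~0b1100 = -13 clears the broadcast bits to recover the group's base CTA, and the shift places the group mask at that base.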

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 43 additions & 16 deletions
@@ -482,7 +482,8 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
   LogicalResult lowerDirectToLDSLoad(
       RewriterBase &rewriter, Location loc, RankedTensorType srcTy,
       MemDescType dstTy, SmallVector<Value> loadVals, Value llDst,
-      Type resElemTy, unsigned vec, triton::AMD::ISAFamily isaFamily,
+      Type resElemTy, unsigned vec, int numCTAs,
+      triton::AMD::ISAFamily isaFamily,
      std::function<SmallVector<Value>(RewriterBase &, Location,
                                       ArrayRef<Value>, Value, int, VectorType,
                                       Value)>
@@ -514,7 +515,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
         {str_attr("offset")});

     Value ctaMulticastMask;
-    if (isaFamily == ISAFamily::GFX1250) {
+    if (numCTAs > 1 && isaFamily == ISAFamily::GFX1250) {
       ctaMulticastMask = LLVM::AMD::emitCtaMulticastMask(
           rewriter, loc, targetInfo.getClusterCTAId(rewriter, loc), srcLayout);
     }
@@ -909,9 +910,10 @@ struct BufferLoadToLocalOpConversion
       return {};
     };

+    int numCTAs = TritonGPUDialect::getNumCTAs(op->getParentOfType<ModuleOp>());
     auto res = lowerDirectToLDSLoad(
         rewriter, loc, ptrType, flatDstTy, loadVals, llDst, resElemTy, vec,
-        targetInfo.getISAFamily(), emitBufferLoadLds);
+        numCTAs, targetInfo.getISAFamily(), emitBufferLoadLds);
     if (failed(res)) {
       return failure();
     }
@@ -1047,9 +1049,10 @@ struct AsyncCopyGlobalToLocalOpConversion
       return {};
     };

+    int numCTAs = TritonGPUDialect::getNumCTAs(op->getParentOfType<ModuleOp>());
     auto res = lowerDirectToLDSLoad(
         rewriter, loc, srcTy, flatDstTy, loadVals, llDst, resElemTy, vec,
-        targetInfo.getISAFamily(), emitGlobalLoadLds);
+        numCTAs, targetInfo.getISAFamily(), emitGlobalLoadLds);
     if (failed(res)) {
       return failure();
     }
@@ -1123,20 +1126,26 @@ struct AsyncTDMCopyGlobalToLocalOpConversion
     auto paddedEnc =
         llvm::dyn_cast<PaddedSharedEncodingAttr>(smemTy.getEncoding());
     Type elementType = getTypeConverter()->convertType(smemTy.getElementType());
+    int numCTAs = TritonGPUDialect::getNumCTAs(op->getParentOfType<ModuleOp>());

+    triton::LinearLayout sharedLayout;
     unsigned padInterval = 0;
     unsigned padAmount = 0;
     if (paddedEnc) {
       assert(paddedEnc.getIntervals().size() == 1 &&
              paddedEnc.getPaddings().size() == 1);
+      sharedLayout = paddedEnc.getLinearComponent();
       padInterval = paddedEnc.getIntervals()[0];
       padAmount = paddedEnc.getPaddings()[0];
+    } else {
+      sharedLayout = triton::gpu::toLinearLayout(smemTy);
+    }
+    Value multicastMask;
+    if (numCTAs > 1) {
+      multicastMask = LLVM::AMD::emitCtaMulticastMask(
+          rewriter, loc, targetInfo.getClusterCTAId(rewriter, loc),
+          sharedLayout);
     }
-
-    auto mod = op->getParentOfType<ModuleOp>();
-    int numCTAs = TritonGPUDialect::getNumCTAs(mod);
-    if (numCTAs > 1)
-      return rewriter.notifyMatchFailure(op, "NYI: Support multicast.");

     SmallVector<Value> desc =
         unpackLLElements(loc, adaptor.getDesc(), rewriter);
@@ -1165,10 +1174,17 @@ struct AsyncTDMCopyGlobalToLocalOpConversion
       barrierPtr = smemObj.getBase();
     }

-    mlir::LLVM::AMD::emitTDMOperation(rewriter, loc, getTypeConverter(), desc,
-                                      blockShape, numWarps, padInterval,
-                                      padAmount, offset, dstPtr, op.getPred(),
-                                      elementType, barrierPtr, /*isLoad=*/true);
+    auto kBlock = rewriter.getStringAttr("block");
+    auto cgaLayout = sharedLayout.sublayout(
+        {kBlock}, to_vector(sharedLayout.getOutDimNames()));
+    auto ctaId =
+        numCTAs > 1 ? targetInfo.getClusterCTAId(rewriter, loc) : b.i32_val(0);
+
+    auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
+    mlir::LLVM::AMD::emitTDMOperation(
+        rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps,
+        padInterval, padAmount, offset, dstPtr, op.getPred(), multicastMask,
+        elementType, barrierPtr, /*isLoad=*/true, cgaLayout, ctaId);

     rewriter.eraseOp(op);
     return success();
@@ -1196,6 +1212,7 @@ struct AsyncTDMCopyLocalToGlobalOpConversion
     auto tensorDescTy = op.getDesc().getType();
     auto smemTy = op.getSrc().getType();
     Type elementType = getTypeConverter()->convertType(smemTy.getElementType());
+    int numCTAs = TritonGPUDialect::getNumCTAs(op->getParentOfType<ModuleOp>());

     SmallVector<Value> desc =
         unpackLLElements(loc, adaptor.getDesc(), rewriter);
@@ -1214,11 +1231,21 @@ struct AsyncTDMCopyLocalToGlobalOpConversion
     SmallVector<Value> offset = adaptor.getIndices();
     int numWarps = triton::gpu::lookupNumWarps(op);

+    // Verifier ensures smem is not using a PaddedSharedEncodingAttr
+    auto sharedLayout = triton::gpu::toLinearLayout(smemTy);
+    auto kBlock = rewriter.getStringAttr("block");
+    auto cgaLayout = sharedLayout.sublayout(
+        {kBlock}, to_vector(sharedLayout.getOutDimNames()));
+    auto ctaId =
+        numCTAs > 1 ? targetInfo.getClusterCTAId(rewriter, loc) : b.i32_val(0);
+
+    auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
     mlir::LLVM::AMD::emitTDMOperation(
-        rewriter, loc, getTypeConverter(), desc, blockShape, numWarps,
+        rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps,
         /*padInterval=*/0, /*padAmount=*/0, offset, dstPtr, b.true_val(),
-        elementType, /*barrierPtr=*/nullptr,
-        /*isLoad=*/false);
+        /*multicastMask=*/{}, elementType,
+        /*barrierPtr=*/nullptr,
+        /*isLoad=*/false, cgaLayout, ctaId);

     rewriter.eraseOp(op);
     return success();
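
Design note: in both conversions the CGA portion of the shared-memory layout is isolated with `sublayout({kBlock}, ...)` so that `applyLinearLayout` can translate a cluster CTA id into per-dimension tile offsets, and `ctaId` is pinned to 0 when `numCTAs == 1` so that translation degenerates to a no-op in the single-CTA case.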

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp

Lines changed: 29 additions & 10 deletions
@@ -1,5 +1,6 @@
 #include "TDMUtility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Tools/LayoutUtils.h"
 #include <optional>

 namespace mlir::LLVM::AMD {
@@ -365,7 +366,8 @@ void fillTDMDescriptor(
     unsigned padAmount, SmallVector<Value> &group0, SmallVector<Value> &group1,
     std::optional<std::reference_wrapper<SmallVector<Value>>> group2,
     std::optional<std::reference_wrapper<SmallVector<Value>>> group3,
-    SmallVector<Value> offset, Value dstPtr, Value pred, Value barrierPtr) {
+    SmallVector<Value> offset, Value dstPtr, Value pred, Value multicastMask,
+    Value barrierPtr, const triton::LinearLayout &cgaLayout, Value ctaId) {
   size_t numDims = offset.size();
   assert(numDims >= 1 && numDims <= 5 && "TDM supports 1D to 5D tensors.");

@@ -408,6 +410,19 @@
     offset[i] = b.add(offset[i], globalOffset[i]);
   }

+  // We need to adjust the outer strides based on our CTAId and the block layout
+  auto kBlock = str_attr("block");
+  auto cgaOffsets =
+      applyLinearLayout(loc, rewriter, cgaLayout, {{kBlock, ctaId}});
+  // Apply CTA offsets to the base pointer.
+  // Compute the global address offset: sum(ctaOffsets[i] * tensorStride[i])
+  Value cgaBaseOffset = b.i32_val(0);
+  for (size_t i = 0; i < numDims; ++i) {
+    Value dimOffset = b.mul(cgaOffsets[i].second, tensorStride[i]);
+    cgaBaseOffset = b.add(cgaBaseOffset, dimOffset);
+  }
+  srcPtr = b.gep(globalPtrTy, elementType, srcPtr, cgaBaseOffset);
+
   // Calculate the full global address offset based on all dimensions
   Value baseOffset = b.i32_val(0);
   for (size_t i = 0; i < numDims; ++i) {
@@ -453,6 +468,8 @@
   group0[3] =
       b.or_(group0[3], b.trunc(i32_ty, b.lshr(globalAddr, b.i64_val(32))));

+  if (multicastMask)
+    group1[0] = b.or_(group1[0], multicastMask);
   // Update groups with adjusted tensor shapes
   group1[1] = b.shl(tensorShape[numDims - 1], b.i32_val(16));
   group1[2] = b.lshr(tensorShape[numDims - 1], b.i32_val(16));
@@ -501,7 +518,9 @@ void emitTDMOperation(RewriterBase &rewriter, Location loc,
                       ArrayRef<Value> desc, ArrayRef<int64_t> blockShape,
                       int numWarps, unsigned padInterval, unsigned padAmount,
                       ArrayRef<Value> offset, Value dstPtr, Value pred,
-                      Type elementType, Value barrierPtr, bool isLoad) {
+                      Value multicastMask, Type elementType, Value barrierPtr,
+                      bool isLoad, const triton::LinearLayout &cgaLayout,
+                      Value ctaId) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);

   assert(blockShape.size() <= 5);
@@ -514,10 +533,10 @@ void emitTDMOperation(RewriterBase &rewriter, Location loc,
   auto group3Vec = SmallVector<Value>(desc.begin() + 16, desc.end());

   fillTDMDescriptor(rewriter, loc, typeConverter, elementType,
-                    SmallVector<int64_t>(blockShape), numWarps, padInterval,
-                    padAmount, group0Vec, group1Vec, std::ref(group2Vec),
-                    std::ref(group3Vec), SmallVector<Value>(offset), dstPtr,
-                    pred, barrierPtr);
+                    to_vector(blockShape), numWarps, padInterval, padAmount,
+                    group0Vec, group1Vec, std::ref(group2Vec),
+                    std::ref(group3Vec), to_vector(offset), dstPtr, pred,
+                    multicastMask, barrierPtr, cgaLayout, ctaId);

   auto group0 = packLLVector(loc, group0Vec, rewriter);
   auto group1 = packLLVector(loc, group1Vec, rewriter);
@@ -535,10 +554,10 @@ void emitTDMOperation(RewriterBase &rewriter, Location loc,
   auto group1Vec = SmallVector<Value>(desc.begin() + 4, desc.end());

   fillTDMDescriptor(rewriter, loc, typeConverter, elementType,
-                    SmallVector<int64_t>(blockShape), numWarps, padInterval,
-                    padAmount, group0Vec, group1Vec, std::nullopt,
-                    std::nullopt, SmallVector<Value>(offset), dstPtr, pred,
-                    barrierPtr);
+                    to_vector(blockShape), numWarps, padInterval, padAmount,
+                    group0Vec, group1Vec, std::nullopt, std::nullopt,
+                    to_vector(offset), dstPtr, pred, multicastMask,
+                    barrierPtr, cgaLayout, ctaId);

   auto group0 = packLLVector(loc, group0Vec, rewriter);
   auto group1 = packLLVector(loc, group1Vec, rewriter);
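
The base-pointer adjustment added in `fillTDMDescriptor` reduces to a small dot product; a minimal sketch with plain integers (hypothetical helper name; the real code builds LLVM IR through `TritonLLVMOpBuilder`):

#include <cstdint>
#include <vector>

// Hypothetical sketch of the per-CTA base adjustment: the CGA layout maps a
// CTA id to a per-dimension logical offset, and the global base pointer then
// advances by sum(ctaOffset[i] * tensorStride[i]) elements.
int64_t cgaBaseOffset(const std::vector<int64_t> &ctaOffsets,
                      const std::vector<int64_t> &tensorStrides) {
  int64_t offset = 0;
  for (size_t i = 0; i < ctaOffsets.size(); ++i)
    offset += ctaOffsets[i] * tensorStrides[i];
  return offset; // applied to the source pointer with a GEP
}

This is the `llvm.mul`/`llvm.add`/`llvm.getelementptr` sequence the lit tests above check for.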

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ void emitTDMOperation(RewriterBase &rewriter, Location loc,
5151
ArrayRef<Value> desc, ArrayRef<int64_t> blockShape,
5252
int numWarps, unsigned padInterval, unsigned padAmount,
5353
ArrayRef<Value> offset, Value dstPtr, Value pred,
54-
Type elementType, Value barrierPtr, bool isLoad);
54+
Value multicastMask, Type elementType, Value barrierPtr,
55+
bool isLoad, const triton::LinearLayout &cgaLayout,
56+
Value ctaId);
5557

5658
} // namespace mlir::LLVM::AMD
5759

third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp

Lines changed: 6 additions & 5 deletions
@@ -27,11 +27,11 @@ struct MakeTensorDescOpConversion

     auto tensorDescTy = result.getType();
     auto blockTy = tensorDescTy.getBlockType();
-    auto enc = blockTy.getEncoding();
-    if (!enc) {
+    auto sharedEnc = blockTy.getEncoding();
+    if (!sharedEnc) {
       return rewriter.notifyMatchFailure(op, "Descriptor has no layout.");
     }
-    auto paddedEnc = llvm::dyn_cast<PaddedSharedEncodingAttr>(enc);
+    auto paddedEnc = llvm::dyn_cast<PaddedSharedEncodingAttr>(sharedEnc);

     unsigned padInterval = 0;
     unsigned padAmount = 0;
@@ -46,12 +46,13 @@ struct MakeTensorDescOpConversion

     Type elementType =
         getTypeConverter()->convertType(blockTy.getElementType());
-    SmallVector<int64_t> blockShape = llvm::to_vector(blockTy.getShape());
+    SmallVector<int64_t> blockShape = to_vector(blockTy.getShape());
     int numWarps = lookupNumWarps(op);
+    auto shapePerCTA = triton::gpu::getShapePerCTA(sharedEnc, blockShape);

     // Create TDM descriptor for 2D-5D tensors
     auto tdmDesc = LLVM::AMD::createTDMDescriptor(
-        rewriter, loc, getTypeConverter(), elementType, blockShape, numWarps,
+        rewriter, loc, getTypeConverter(), elementType, shapePerCTA, numWarps,
         padInterval, padAmount, tensorShape, tensorStride, basePtr);

     SmallVector<Value> groups = tdmDesc.getAllGroups();
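
The descriptor is now created with the per-CTA shape instead of the full block shape. Assuming the usual `getShapePerCTA` semantics (each dimension divided by the number of CTAs the CGALayout splits it across, clamped to the dimension size), a hedged sketch:

#include <algorithm>
#include <cstdint>
#include <vector>

// Hedged sketch of getShapePerCTA semantics: each dimension of the block
// shape is divided by its CTA split factor (clamped so a split larger than
// the dimension does not produce a zero-sized shape).
std::vector<int64_t> shapePerCTA(const std::vector<int64_t> &shape,
                                 const std::vector<unsigned> &ctaSplit) {
  std::vector<int64_t> result(shape.size());
  for (size_t d = 0; d < shape.size(); ++d)
    result[d] = shape[d] / std::min<int64_t>(shape[d], ctaSplit[d]);
  return result; // e.g. {64, 64} with splits {1, 2} -> {64, 32}
}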

Comments (0)