Skip to content

Commit 9ca8bd3

Browse files
authored
[LAYOUTS] Get warp number and thread number from Module (#6068)
As per title
1 parent c2fd8e1 commit 9ca8bd3

File tree

13 files changed

+48
-71
lines changed

13 files changed

+48
-71
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ int lookupNumWarps(Operation *op);
5151
// verifiers.
5252
std::optional<int> maybeLookupNumWarps(Operation *op);
5353

54+
// FIXME: Make this API and that of maybeLookupNumWarps consistent!
55+
// Utility to find the number of threads per warp
56+
int lookupThreadsPerWarp(OpBuilder &rewriter);
57+
5458
class LinearLayoutCache {
5559
public:
5660
std::optional<LinearLayout> get(const CacheKey &key) {
@@ -97,8 +101,6 @@ SmallVector<unsigned> getElemsPerThread(Type type);
97101
// getThreadsPerWarpWithUniqueData.
98102
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
99103

100-
unsigned getWarpSize(Attribute layout);
101-
102104
// Returns the number of warps per CTA that may have access to replicated
103105
// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
104106
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
@@ -196,8 +198,6 @@ SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
196198
ArrayRef<int64_t> shape);
197199
SmallVector<int64_t> getAllocationShapePerCTA(Type type);
198200

199-
unsigned getNumWarpsPerCTA(Attribute layout);
200-
201201
unsigned getNumCTAs(Attribute layout);
202202

203203
// Return the order that represents that the batch is in row-major or

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,26 +194,17 @@ Value getThreadId(OpBuilder &rewriter, Location loc) {
194194
return tid;
195195
}
196196

197-
static int lookupThreadsPerWarp(OpBuilder &rewriter) {
198-
assert(rewriter.getInsertionBlock() && "expected an insertion point");
199-
Operation *op = rewriter.getInsertionBlock()->getParentOp();
200-
while (op && !isa<ModuleOp>(op))
201-
op = op->getParentOp();
202-
assert(op && "cannot create thread ID outside of module");
203-
return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast<ModuleOp>(op));
204-
}
205-
206197
Value getLaneId(OpBuilder &rewriter, Location loc) {
207198
TritonLLVMOpBuilder b(loc, rewriter);
208199
Value tid = getThreadId(rewriter, loc);
209-
int threadsPerWarp = lookupThreadsPerWarp(rewriter);
200+
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
210201
return b.urem(tid, b.i32_val(threadsPerWarp));
211202
}
212203

213204
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc) {
214205
TritonLLVMOpBuilder b(loc, rewriter);
215206
Value tid = getThreadId(rewriter, loc);
216-
int threadsPerWarp = lookupThreadsPerWarp(rewriter);
207+
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
217208
Value warpSizeVal = b.i32_val(threadsPerWarp);
218209

219210
Value laneId = b.urem(tid, warpSizeVal);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,6 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
7575
}
7676
}
7777

78-
unsigned getWarpSize(Attribute layout) {
79-
unsigned size = 1;
80-
auto threadsPerWarp = getThreadsPerWarp(layout);
81-
for (auto e : threadsPerWarp) {
82-
size *= e;
83-
}
84-
return size;
85-
}
86-
8778
SmallVector<unsigned>
8879
getThreadsPerWarpWithUniqueData(Attribute layout,
8980
ArrayRef<int64_t> tensorShape) {
@@ -377,28 +368,6 @@ SmallVector<int64_t> getAllocationShapePerCTA(Type type) {
377368
tensorType.getShape());
378369
}
379370

380-
unsigned getNumWarpsPerCTA(Attribute layout) {
381-
SmallVector<unsigned> warpsPerCTA;
382-
if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout))
383-
warpsPerCTA = blockedLayout.getWarpsPerCTA();
384-
else if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout))
385-
return getNumWarpsPerCTA(sliceLayout.getParent());
386-
else if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
387-
// Use the distributed layout interface to get the number of warps per
388-
// CTA.
389-
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
390-
warpsPerCTA = distributedLayout.getWarpsPerCTA();
391-
} else if (auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout))
392-
warpsPerCTA = mfmaLayout.getWarpsPerCTA();
393-
else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout))
394-
warpsPerCTA = wmmaLayout.getWarpsPerCTA();
395-
else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
396-
warpsPerCTA = dotLayout.getWarpsPerCTA();
397-
else
398-
llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA");
399-
return product<unsigned>(warpsPerCTA);
400-
}
401-
402371
unsigned getNumCTAs(Attribute layout) {
403372
return product<unsigned>(getCTAsPerCGA(layout));
404373
}
@@ -3496,3 +3465,12 @@ int triton::gpu::lookupNumWarps(Operation *op) {
34963465
}
34973466
return *numWarps;
34983467
}
3468+
3469+
int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) {
3470+
assert(rewriter.getInsertionBlock() && "expected an insertion point");
3471+
Operation *op = rewriter.getInsertionBlock()->getParentOp();
3472+
while (op && !isa<ModuleOp>(op))
3473+
op = op->getParentOp();
3474+
assert(op && "cannot create thread ID outside of module");
3475+
return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast<ModuleOp>(op));
3476+
}

lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,18 @@ Type replaceLayout(const Type &type, const Attribute &newLayout) {
5959

6060
ttg::DistributedEncodingTrait
6161
replaceCTALayout(ttg::DistributedEncodingTrait layout,
62-
llvm::ArrayRef<int64_t> shape,
62+
llvm::ArrayRef<int64_t> shape, int numWarps,
6363
const ttg::CTALayoutAttr &newCTALayout) {
6464
if (auto blockedLayout = mlir::dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
6565
return ttg::BlockedEncodingAttr::get(
6666
layout.getContext(), shape, blockedLayout.getSizePerThread(),
67-
blockedLayout.getOrder(), ttg::getNumWarpsPerCTA(layout), 32,
68-
newCTALayout);
67+
blockedLayout.getOrder(), numWarps, 32, newCTALayout);
6968
} else if (auto sliceLayout =
7069
mlir::dyn_cast<ttg::SliceEncodingAttr>(layout)) {
7170
return ttg::SliceEncodingAttr::get(
7271
layout.getContext(), sliceLayout.getDim(),
73-
replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout));
72+
replaceCTALayout(sliceLayout.getParent(), shape, numWarps,
73+
newCTALayout));
7474
} else {
7575
// Other layouts are generated by passes after PlanCTAPass
7676
llvm::report_fatal_error("replaceCTALayout not implemented");
@@ -293,11 +293,15 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
293293
// FIXME: Should consider IR with more than one DotOps
294294
setTiling({splitM, splitN, 1});
295295

296+
OpBuilder builder(dot);
297+
auto numThreads = ttg::lookupThreadsPerWarp(builder);
298+
auto numWarps = ttg::lookupNumWarps(dot);
299+
296300
auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN},
297301
{splitM, splitN}, {1, 0});
298302
auto newDLayout = ttg::BlockedEncodingAttr::get(
299303
ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(),
300-
ttg::getNumWarpsPerCTA(dLayout), 32, newCTALayout);
304+
numWarps, numThreads, newCTALayout);
301305
auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(),
302306
newDLayout, 0);
303307
auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(),
@@ -359,12 +363,14 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
359363
if (remainingCTAs > 0)
360364
CTAsPerCGA[order[rank - 1]] *= remainingCTAs;
361365

366+
auto numWarps = ttg::lookupNumWarps(reduce);
362367
auto CTALayout =
363368
ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
364369
if (!tiled)
365370
setTiling(CTALayout.getCTAsPerCGA());
366-
auto newSrcLayout = replaceCTALayout(
367-
cast<ttg::DistributedEncodingTrait>(srcLayout), srcShape, CTALayout);
371+
auto newSrcLayout =
372+
replaceCTALayout(cast<ttg::DistributedEncodingTrait>(srcLayout),
373+
srcShape, numWarps, CTALayout);
368374
auto newResultLayout =
369375
ttg::SliceEncodingAttr::get(context, axis, newSrcLayout);
370376
unsigned numOperands = reduce.getNumOperands();
@@ -386,6 +392,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
386392
stores.push_back(op);
387393
});
388394
assert(stores.size() > 0 && "Cannot find store-like ops");
395+
auto numWarps = ttg::lookupNumWarps(funcOp);
389396

390397
ttg::CTALayoutAttr CTALayout;
391398
for (Operation *store : stores) {
@@ -398,7 +405,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
398405
}
399406
auto newLayout = replaceCTALayout(
400407
cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding()),
401-
tensorTy.getShape(), CTALayout);
408+
tensorTy.getShape(), numWarps, CTALayout);
402409
processElementwise(store, newLayout);
403410
}
404411
}
@@ -624,6 +631,7 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
624631
}
625632

626633
auto CTALayout = ttg::getCTALayout(layout);
634+
auto numWarps = ttg::lookupNumWarps(op);
627635

628636
llvm::SmallVector<Attribute> newOperandLayouts;
629637
for (unsigned i = 0; i < op->getNumOperands(); ++i) {
@@ -634,7 +642,7 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
634642
auto oldLayout =
635643
cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
636644
auto newLayout =
637-
replaceCTALayout(oldLayout, tensorTy.getShape(), CTALayout);
645+
replaceCTALayout(oldLayout, tensorTy.getShape(), numWarps, CTALayout);
638646
newOperandLayouts.push_back(newLayout);
639647
}
640648

@@ -647,7 +655,7 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
647655
auto oldLayout =
648656
cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
649657
auto newLayout =
650-
replaceCTALayout(oldLayout, tensorTy.getShape(), CTALayout);
658+
replaceCTALayout(oldLayout, tensorTy.getShape(), numWarps, CTALayout);
651659
newResultLayouts.push_back(newLayout);
652660
}
653661

test/TritonGPU/amd/mfma-double-rate.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// CHECK-LABEL:mfma_16x16x32_f16
44

55
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = false}>
6-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} {
6+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
77
tt.func public @mfma_16x16x32_f16(%arg0: tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
88
%arg1: tensor<32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
99
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
@@ -18,7 +18,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.th
1818
// CHECK-LABEL:mfma_16x16x32_bf16
1919

2020
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = false}>
21-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} {
21+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
2222
tt.func public @mfma_16x16x32_bf16(%arg0: tensor<16x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
2323
%arg1: tensor<32x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
2424
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
@@ -33,7 +33,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.th
3333
// CHECK-LABEL:mfma_32x32x16_f16
3434

3535
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = false}>
36-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} {
36+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
3737
tt.func public @mfma_32x32x16_f16(%arg0: tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
3838
%arg1: tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
3939
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
@@ -49,7 +49,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.th
4949
// CHECK-LABEL:mfma_32x32x16_bf16
5050

5151
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = false}>
52-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} {
52+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
5353
tt.func public @mfma_32x32x16_bf16(%arg0: tensor<32x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
5454
%arg1: tensor<16x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
5555
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct ConvertLayoutOpMFMAToDotOpConversion
5757
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcType.getEncoding());
5858
assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) &&
5959
"Expected MFMA size 16 or 32");
60-
assert(triton::gpu::getWarpSize(mfmaLayout) == 64 &&
60+
assert(triton::gpu::lookupThreadsPerWarp(rewriter) == 64 &&
6161
"Expected warp size 64 for MFMA");
6262

6363
auto elemTy = int_ty(8);

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
252252
numRepK = numReps[kDimIdx + 1];
253253
}
254254

255-
unsigned iWarpSize = triton::gpu::getWarpSize(mfmaLayout);
255+
unsigned iWarpSize = triton::gpu::lookupThreadsPerWarp(rewriter);
256256
assert(iWarpSize == 64);
257257
Value warpSize = tb.i32_val(iWarpSize);
258258
Value linearWarpId = tb.udiv(thread, warpSize);

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
171171
auto numRepK = numReps[opIdx == 0 ? 2 : 1];
172172
auto repB = numReps[0];
173173

174-
unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout);
174+
unsigned iWaveSize = triton::gpu::lookupThreadsPerWarp(rewriter);
175175
assert(iWaveSize == 32);
176176
Value waveSize = tb.i32_val(iWaveSize);
177177
Value linearWaveId = tb.udiv(thread, waveSize);

third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct DecomposeUnsupportedAMDConversions
8282
return;
8383
}
8484

85-
unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);
85+
unsigned numWarps = lookupNumWarps(cvtOp);
8686

8787
// Find all possible shapes of WarpsPerCTA by finding all possible
8888
// factorizations of numWarps. Pick shape for which both conversions in

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ struct DotOpMFMAConversionHelper {
313313
auto dstElemTy = dTensorTy.getElementType();
314314
auto fc = unpackLLElements(loc, loadedC, rewriter);
315315

316-
unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout);
316+
unsigned warpSize = triton::gpu::lookupThreadsPerWarp(rewriter);
317317
// compute number of output elements that each thread holds for one MFMA
318318
// instruction.
319319
const int subBlocks =
@@ -640,7 +640,7 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
640640
auto dstElemTy = dTensorTy.getElementType();
641641
auto fc = unpackLLElements(loc, loadedC, rewriter);
642642

643-
unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout);
643+
unsigned warpSize = triton::gpu::lookupThreadsPerWarp(rewriter);
644644
// compute number of output elements that each thread holds for one MFMA
645645
// instruction. subBlocks
646646
const int subBlocks =

0 commit comments

Comments (0)