Commit d0dafa5

rebase

1 parent: debf809
6 files changed (+118, -91 lines)

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp (8 additions, 4 deletions)

@@ -231,8 +231,10 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
   LogicalResult matchAndRewrite(ttn::WarpIdOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+
     Value threadId = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);
-    Value warpId = udiv(threadId, i32_val(32));
+    Value warpId = b.udiv(threadId, b.i32_val(32));
     warpId = LLVM::NVIDIA::shuffleIdx(loc, rewriter, warpId, 0);
     rewriter.replaceOp(op, warpId);
     return success();
@@ -648,6 +650,7 @@ static Value createTMAlloc(IRRewriter &rewriter, LLVM::LLVMFuncOp func,
                            size_t size, Value pred, bool twoCTAs) {
   PTXBuilder ptxBuilder;
   Location loc = func.getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   Value sharedMem = mlir::LLVM::getStackPointer(rewriter, func);
   std::string ptxString =
       "@$0 tcgen05.alloc.cta_group::" + std::to_string(twoCTAs ? 2 : 1) +
@@ -660,9 +663,9 @@ static Value createTMAlloc(IRRewriter &rewriter, LLVM::LLVMFuncOp func,
   auto voidTy = void_ty(func->getContext());
   ptxBuilder.launch(rewriter, loc, void_ty(func->getContext()));
   rewriter.create<NVVM::Barrier0Op>(loc);
-  Value address = load(i32_ty, sharedMem);
+  Value address = b.load(i32_ty, sharedMem);
   rewriter.create<NVVM::Barrier0Op>(loc);
-  address = inttoptr(ptr_ty(func.getContext(), 6), address);
+  address = b.inttoptr(ptr_ty(func.getContext(), 6), address);
   return address;
 }
 
@@ -709,6 +712,7 @@ static Value initTensorMemory(LLVM::LLVMFuncOp func) {
   rewriter.setInsertionPointToStart(&func.front());
   auto ctx = mod.getContext();
   auto loc = func.getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   // A proper error will be raised by the frontend, but to allow compilation to
   // continue we emit a trap.
   if (size > 512) {
@@ -721,7 +725,7 @@ static Value initTensorMemory(LLVM::LLVMFuncOp func) {
   // should be fine for now.
   bool useTwoCTAs = numCTAs == 2;
   Value threadId = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);
-  Value pred = icmp_ult(threadId, i32_val(32));
+  Value pred = b.icmp_ult(threadId, b.i32_val(32));
   Value alloc = createTMAlloc(rewriter, func, size, pred, useTwoCTAs);
   createRelinquishAlloc(rewriter, loc, pred, useTwoCTAs);
   // TODO: pred will have a long liverange, we need to check if this is a
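Note: every file in this commit applies the same mechanical change: free-standing helpers (udiv, i32_val, add, and_, icmp_eq, ...) that implicitly picked up loc and rewriter from the enclosing scope are replaced by methods on an explicitly constructed TritonLLVMOpBuilder. A minimal sketch of the pattern, under the assumption that the builder simply stores the location and rewriter it is given; MyOp is a hypothetical op used only for illustration:

  // Sketch only: the builder-migration pattern used throughout this commit.
  // The methods shown (i32_val, udiv) are the ones visible in the diff above.
  LogicalResult matchAndRewrite(MyOp op, PatternRewriter &rewriter) const {
    Location loc = op.getLoc();
    // Before: udiv()/i32_val() were macro-style helpers that required `loc`
    // and `rewriter` to be visible at the call site.
    // After: a builder carries both explicitly.
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    Value tid = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);
    Value warpId = b.udiv(tid, b.i32_val(32)); // threadIdx.x / warp size
    rewriter.replaceOp(op, warpId);
    return success();
  }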

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp (33 additions, 27 deletions)

@@ -33,22 +33,23 @@ mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader(
 
 Value mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad(
     int a, int b, ConversionPatternRewriter &rewriter, Location loc) {
+  auto tb = TritonLLVMOpBuilder(loc, rewriter);
   int numRows = 64;
   if (interleaved || instrShape[0] >= 128)
     numRows = 128;
   int numColPerBlock =
       ((instrShape[0] * instrShape[1]) / numRows) / numElementsPer32b;
   Value address = base;
   int blockId = a + b * numRepM;
-  address = ptrtoint(i32_ty, address);
+  address = tb.ptrtoint(i32_ty, address);
   if (!interleaved) {
-    address = add(address, i32_val(numColPerBlock * blockId));
+    address = tb.add(address, tb.i32_val(numColPerBlock * blockId));
   } else {
     int blockIdIsOdd = blockId & 1;
     int blockIdPrevEven = blockId - blockIdIsOdd;
-    Value offset =
-        i32_val(numColPerBlock * blockIdPrevEven + ((16 * blockIdIsOdd) << 16));
-    address = add(address, offset);
+    Value offset = tb.i32_val(numColPerBlock * blockIdPrevEven +
+                              ((16 * blockIdIsOdd) << 16));
+    address = tb.add(address, offset);
   }
   return address;
 }
@@ -72,6 +73,7 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
                                   triton::nvidia_gpu::TCGen5MMAOp op, int M,
                                   int N, bool transposeA, bool transposeB) {
   Location loc = op.getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   union TCGen5InstructionDescriptor {
     uint32_t descriptor;
     struct {
@@ -119,7 +121,7 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
   Type dstElType = op.getD().getType().getElementType();
   assert(dstElType.isF16() || dstElType.isF32());
   desc.dType = dstElType.isF16() ? 0 : 1;
-  return int_val(32, desc.descriptor);
+  return b.int_val(32, desc.descriptor);
 }
 
 static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter,
@@ -129,6 +131,7 @@ static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter,
                                        int scaleFactorsubIdxB,
                                        mxfpKind mxfpInstKind) {
   Location loc = op.getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   union TCGen5InstructionDescriptor {
     uint32_t descriptor;
     struct {
@@ -209,7 +212,7 @@ static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter,
     }
   }
 
-  return int_val(32, desc.descriptor);
+  return b.int_val(32, desc.descriptor);
 }
 
 static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
@@ -276,6 +279,7 @@ static void createScaledGen5MMA(ConversionPatternRewriter &rewriter,
 static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc,
                             Value barrier, Value pred, bool twoCTAs = false) {
   PTXBuilder ptxBuilder;
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   SmallVector<PTXBuilder::Operand *> ptxOperands;
   auto *predOperand = ptxBuilder.newOperand(pred, "b");
   ptxOperands.push_back(predOperand);
@@ -285,7 +289,7 @@ static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc,
   if (twoCTAs) {
     // .multicast::cluster and mask 0x3 means the completion of UTCMMA.2CTA will
     // be boardcasted into CTAid 0 and 1
-    auto *ctaMask = ptxBuilder.newOperand(int_val(16, 0x3), "h");
+    auto *ctaMask = ptxBuilder.newOperand(b.int_val(16, 0x3), "h");
     ptxOperands.push_back(ctaMask);
     opcode = "@$0 "
              "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::"
@@ -303,23 +307,23 @@ void convertDot(const LLVMTypeConverter *typeConverter,
                 triton::nvidia_gpu::TCGen5MMAOp op, Value a, Value b, Value d,
                 Value loadedA, Value loadedB, Value loadedD, Value useDFlag,
                 Value pred, Value barrier) {
-
+  auto tb = TritonLLVMOpBuilder(loc, rewriter);
   bool twoCTAs = op.getTwoCtas().has_value();
   // Only run mma on one thread. We currently use elect as ptxas is not able to
   // detect that tid.x == 0 is true only for 1 thread.
   Value warpId = rewriter.create<nvgpu::WarpIdOp>(loc);
-  Value wapr0 = icmp_eq(warpId, i32_val(0));
+  Value wapr0 = tb.icmp_eq(warpId, tb.i32_val(0));
   if (twoCTAs) {
     // TODO: we have to sync the two CTAs because we currently don't use remove
     // barriers for the copies.
     rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
     rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
 
     Value clusterId = rewriter.create<nvgpu::ClusterCTAIdOp>(loc);
-    Value cluster0 = icmp_eq(clusterId, i32_val(0));
-    pred = and_(pred, cluster0);
+    Value cluster0 = tb.icmp_eq(clusterId, tb.i32_val(0));
+    pred = tb.and_(pred, cluster0);
   }
-  pred = and_(pred, wapr0);
+  pred = tb.and_(pred, wapr0);
 
   // Wrap the whole mma code sequence within a IF block.
   auto *curBlock = rewriter.getInsertionBlock();
@@ -382,7 +386,7 @@ void convertDot(const LLVMTypeConverter *typeConverter,
   Value instDescriptor =
       createInstDescriptor(rewriter, op, twoCTAs ? mmaSizeM * 2 : mmaSizeM,
                            mmaSizeN, transA, transB);
-  Value zero = i32_val(0);
+  Value zero = tb.i32_val(0);
   SmallVector<int64_t> shapeA(triton::gpu::getShapePerCTA(aTensorTy));
   SmallVector<int64_t> shapeB(triton::gpu::getShapePerCTA(bTensorTy));
   SmallVector<unsigned> aOperandShape = {(unsigned)mmaSizeM,
@@ -411,7 +415,7 @@ void convertDot(const LLVMTypeConverter *typeConverter,
       b = bLoader.smemLoad(n, k, rewriter, loc);
       createGen5MMA(rewriter, loc, op, a, b, accAddress, pred, instDescriptor,
                     useInitAcc, aInTmem, twoCTAs);
-      useInitAcc = i1_val(1);
+      useInitAcc = tb.i1_val(1);
     }
   }
 }
@@ -475,6 +479,7 @@ struct TCGen5MMAScaledOpConversion
            "tensorcore op should have a barrier at this point.");
     auto typeConverter = getTypeConverter();
     Location loc = op.getLoc();
+    auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto aTensorTy = cast<MemDescType>(op.getA().getType());
     auto bTensorTy = cast<MemDescType>(op.getB().getType());
     auto dTensorTy = cast<MemDescType>(op.getD().getType());
@@ -508,15 +513,15 @@ struct TCGen5MMAScaledOpConversion
             loc, adaptor.getD(),
             typeConverter->convertType(dTensorTy.getElementType()), rewriter)
             .getBase();
-    baseD = ptrtoint(i32_ty, baseD);
+    baseD = tb.ptrtoint(i32_ty, baseD);
     Value baseScaleA = getSharedMemoryObjectFromStruct(loc, adaptor.getAScale(),
                                                        i8_ty, rewriter)
                            .getBase();
    Value baseScaleB = getSharedMemoryObjectFromStruct(loc, adaptor.getBScale(),
                                                        i8_ty, rewriter)
                            .getBase();
-    baseScaleA = ptrtoint(i32_ty, baseScaleA);
-    baseScaleB = ptrtoint(i32_ty, baseScaleB);
+    baseScaleA = tb.ptrtoint(i32_ty, baseScaleA);
+    baseScaleB = tb.ptrtoint(i32_ty, baseScaleB);
 
     unsigned int M = dTensorTy.getDimSize(0);
     unsigned int N = dTensorTy.getDimSize(1);
@@ -537,7 +542,7 @@ struct TCGen5MMAScaledOpConversion
     int numRepK = ceil<unsigned>(K, mmaSizeK);
     bool interleaved = (mmaSizeM == 64 && (numRepM > 1 || numRepN > 1));
 
-    Value zero = i32_val(0);
+    Value zero = tb.i32_val(0);
     SmallVector<int64_t> shapeA(aTensorTy.getShape());
     SmallVector<int64_t> shapeB(bTensorTy.getShape());
     if (opKindIsMXFP4) {
@@ -561,11 +566,12 @@ struct TCGen5MMAScaledOpConversion
                          numBitsPerElementB, rewriter, loc);
 
     // TODO: Support accumulator init optimization for scaled dot
-    Value useInitAcc = int_val(1, 1);
+    Value useInitAcc = tb.int_val(1, 1);
     // Only run mma on one thread. We currently use elect as ptxas is not able
     // to detect that tid.x == 0 is true only for 1 thread.
-    Value pred = and_(adaptor.getPred(),
-                      LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter));
+    Value pred =
+        tb.and_(adaptor.getPred(),
+                LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter));
     int numRows = 128;
     int colSizeInBits = 32;
     int numColPerBlock =
@@ -599,16 +605,16 @@ struct TCGen5MMAScaledOpConversion
         // Blocks are laid out along M first then N as described in
         // `TensorMemorySpace` definition.
         int blockId = m + n * numRepM;
-        Value accAddress = add(baseD, i32_val(numColPerBlock * blockId));
+        Value accAddress = tb.add(baseD, tb.i32_val(numColPerBlock * blockId));
        for (int k = 0; k < numRepK; k++) {
          Value a = aLoader->memLoad(m, k, rewriter, loc);
          Value b = bLoader.smemLoad(n, k, rewriter, loc);
          int subWordIdx = k % (4 / scaleFactorColsPerSet);
          int wordIdx = k / (4 / scaleFactorColsPerSet);
-          Value scaleA = add(baseScaleA, i32_val((m + wordIdx * numRepM) *
-                                                 numColPerScaleBlockA));
-          Value scaleB = add(baseScaleB, i32_val((n + wordIdx * numRepN) *
-                                                 numColPerScaleBlockB));
+          Value scaleA = tb.add(baseScaleA, tb.i32_val((m + wordIdx * numRepM) *
+                                                       numColPerScaleBlockA));
+          Value scaleB = tb.add(baseScaleB, tb.i32_val((n + wordIdx * numRepN) *
+                                                       numColPerScaleBlockB));
          Value instDescriptor = createScaleInstDescriptor(
              rewriter, op, mmaSizeM, mmaSizeN, transA, transB, subWordIdx,
              subWordIdx, mxfpInstKind);
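For reference, the interleaved branch of tmemLoad above packs what appears to be a row offset into the upper 16 bits of the tensor-memory address and a column offset into the lower 16 bits, so odd-numbered blocks reuse the columns of their even neighbour but start 16 rows lower. A hedged scalar restatement of that arithmetic, assuming that address packing; the helper name is hypothetical and not part of the patch:

  // Hypothetical scalar mirror of the interleaved offset computed in
  // DotOpMmaV5TmemLoader::tmemLoad, for illustration only.
  static uint32_t interleavedTmemOffset(int blockId, int numColPerBlock) {
    int blockIdIsOdd = blockId & 1;               // odd block sits 16 rows lower
    int blockIdPrevEven = blockId - blockIdIsOdd; // column of the even neighbour
    return numColPerBlock * blockIdPrevEven       // low 16 bits: column offset
           + ((16 * blockIdIsOdd) << 16);         // high 16 bits: row offset
  }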

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp (12 additions, 12 deletions)

@@ -93,6 +93,7 @@ int64_t getSwizzlingFromLayout(const SharedEncodingAttr &layout,
 
 static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc,
                               int64_t swizzling, uint32_t stride) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   static_assert(sizeof(SMEMDescriptor) == 8,
                 "Descriptor size should be 64 bits.");
   SMEMDescriptor desc;
@@ -144,17 +145,16 @@ Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad(
   auto tb = TritonLLVMOpBuilder(loc, rewriter);
   Value k = tb.i32_val(b * instrShape[1]);
   Value m = tb.add(tb.i32_val(a * dimWpt * instrShape[0]),
-                   tb.mul(warpId, tb.i32_val(instrShape[0])));
-  if (trans) {
-    std::swap(k, m);
-  }
-  Value leading_offset =
-      tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal),
-             tb.i32_val(shape[ord[1]] * elemsPerSwizzlingRow));
+                   tb.mul(warpId, tb.i32_val(instrShape[0])));
+  if (trans) {
+    std::swap(k, m);
+  }
+  Value leading_offset =
+      tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal),
+             tb.i32_val(shape[ord[1]] * elemsPerSwizzlingRow));
   Value stride_offset = tb.mul(m, elemsPerSwizzlingRowVal);
-  Value offset =
-      tb.add(tb.add(leading_offset, stride_offset),
-             tb.urem(k, elemsPerSwizzlingRowVal));
+  Value offset = tb.add(tb.add(leading_offset, stride_offset),
+                        tb.urem(k, elemsPerSwizzlingRowVal));
   Value off1;
   // Avoid the runtime udiv if we know the elements are byte multiples
   if (elemBits % 8) {
@@ -168,8 +168,8 @@ Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad(
   // Add the base at the end to make it easier to do loop invariant code
   // motion.
   loadDesc = tb.add(
-      loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)),
-                        tb.int_val(64, 50)));
+      loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)),
+                        tb.int_val(64, 50)));
   return loadDesc;
 }
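The smemLoad hunks above are reflow-only, but the offset they compute is worth spelling out: k is split into a swizzling-row index and a remainder, and the element offset combines that with the m coordinate. A hedged scalar restatement of the non-transposed case, mirroring the code above; the helper name is hypothetical:

  // Hypothetical scalar version of the offset arithmetic in
  // DotOpMmaV3SmemLoader::smemLoad (trans == false), for illustration only.
  static int64_t wgmmaSmemOffset(int k, int m, int elemsPerSwizzlingRow,
                                 int leadingDimSize /* shape[ord[1]] */) {
    int64_t leading_offset = (k / elemsPerSwizzlingRow) *
                             (int64_t)leadingDimSize * elemsPerSwizzlingRow;
    int64_t stride_offset = (int64_t)m * elemsPerSwizzlingRow;
    return leading_offset + stride_offset + k % elemsPerSwizzlingRow;
  }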

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp (12 additions, 9 deletions)

@@ -1384,6 +1384,7 @@ static LogicalResult iterateGatherScatterIndices(
     function_ref<void(Value, Value, Value, ArrayRef<Value>)> callback) {
   MLIRContext *ctx = op->getContext();
   Location loc = op->getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
 
   StringAttr kDim0 = str_attr("dim0");
   StringAttr kDim1 = str_attr("dim1");
@@ -1461,24 +1462,25 @@ static LogicalResult iterateGatherScatterIndices(
 
   Value warpId = rewriter.create<nvgpu::WarpIdOp>(loc);
   // Each block has separate shared memory. Multiple CTAs don't work anyways.
-  Value blockId = i32_val(0);
+  Value blockId = b.i32_val(0);
 
   // Mask out warps with redundant x offsets.
-  pred = and_(pred, icmp_eq(i32_val(0), and_(warpId, i32_val(warpMask))));
+  pred = b.and_(pred,
+                b.icmp_eq(b.i32_val(0), b.and_(warpId, b.i32_val(warpMask))));
   // Select one thread in each warp to issue the gather4 messages.
-  pred = and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter));
+  pred = b.and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter));
 
   SmallVector<Value> xOffsets = unpackLLElements(loc, xOffsetsValue, rewriter);
   // Lane ID doesn't matter.
-  Value laneId = i32_val(0);
+  Value laneId = b.i32_val(0);
   for (auto regId : seq<unsigned>(0, xOffsets.size(), 4)) {
     // Skip redundant x offsets within a thread.
     if ((regMask & regId) != 0)
       continue;
-    Value regIdVal = i32_val(regId);
+    Value regIdVal = b.i32_val(regId);
 
     for (auto msgId : llvm::seq(numMessagesPerRow)) {
-      Value msgIdVal = i32_val(msgId);
+      Value msgIdVal = b.i32_val(msgId);
 
       auto result = applyLinearLayout(loc, rewriter, msgToShared,
                                       {{kMsg, msgIdVal},
@@ -1492,8 +1494,8 @@ static LogicalResult iterateGatherScatterIndices(
       // Because we checked that the memdesc's allocshape and shape match, we
       // can ignore the strides and directly index into the shmem object.
       Value shMemPtr =
-          gep(elemPtrTy, llvmElemTy, smemObj.getBase(), shMemOffset);
-      Value yOffset = add(yOffsetValue, i32_val(msgId * msgSize));
+          b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), shMemOffset);
+      Value yOffset = b.add(yOffsetValue, b.i32_val(msgId * msgSize));
 
       callback(pred, shMemPtr, yOffset, ArrayRef(xOffsets).slice(regId, 4));
     };
@@ -1571,6 +1573,7 @@ LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite(
     triton::nvidia_gpu::AsyncTMAScatterOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   MLIRContext *ctx = getContext();
   LLVM::LLVMVoidType voidTy = void_ty(op->getContext());
 
@@ -1601,7 +1604,7 @@ LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite(
   if (failed(iterateGatherScatterIndices(
           op, rewriter, *getTypeConverter(), op.getXOffsets(), op.getSrc(),
           adaptor.getSrc(), adaptor.getXOffsets(), adaptor.getYOffset(),
-          /*pred=*/true_val(), callback)))
+          /*pred=*/b.true_val(), callback)))
     return failure();
 
   // TODO: Separate the syncronizations operations into separate TTGIR ops to
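The predicate built in iterateGatherScatterIndices keeps one warp per distinct group of x offsets and one elected lane within that warp. A hedged restatement of which threads end up issuing the gather4 messages; the helper name and the electedLane flag (standing in for LLVM::NVIDIA::createElectPredicate) are hypothetical:

  // Hypothetical boolean mirror of the predicate above, for illustration only.
  static bool issuesGather4(bool pred, unsigned warpId, unsigned warpMask,
                            bool electedLane) {
    bool nonRedundantWarp = (warpId & warpMask) == 0; // drop duplicate warps
    return pred && nonRedundantWarp && electedLane;   // one lane of that warp
  }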
