Skip to content

Commit 893a254

Browse files
committed
Replace macros for llvm ops with TritonLLVMOpBuilder
1 parent c8d05e1 commit 893a254

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -439,16 +439,15 @@ struct AsyncCopyGlobalToLocalOpConversion
439439
matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor,
440440
ConversionPatternRewriter &rewriter) const override {
441441

442-
MLIRContext *ctx = rewriter.getContext();
443442
auto loc = op.getLoc();
443+
auto b = TritonLLVMOpBuilder(loc, rewriter);
444444

445445
auto srcTy = op.getSrc().getType();
446446
auto srcEncoding = srcTy.getEncoding();
447447
assert((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
448448
"Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion"));
449-
auto srcShape = srcTy.getShape();
450-
assert(srcShape.size() <= 2 && "Async copy only supports 1d and 2d "
451-
"tensors: Unexpected rank of %src");
449+
assert(srcTy.getShape().size() <= 2 && "Async copy only supports 1d and 2d "
450+
"tensors: Unexpected rank of %src");
452451

453452
auto dstTy = op.getResult().getType();
454453
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -479,7 +478,7 @@ struct AsyncCopyGlobalToLocalOpConversion
479478
shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth());
480479
LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
481480

482-
auto kLane = str_attr("lane");
481+
StringAttr kLane = rewriter.getStringAttr("lane");
483482
for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) {
484483
auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0];
485484
unsigned expected = maxVec * (1 << inLane);
@@ -510,9 +509,9 @@ struct AsyncCopyGlobalToLocalOpConversion
510509

511510
int vecBytes = vecBits / 8;
512511
assert(llvm::isPowerOf2_32(vecBytes));
513-
Value vecBytesVal = i32_val(vecBytes);
512+
Value vecBytesVal = b.i32_val(vecBytes);
514513

515-
Value cacheModifiers = i32_val(
514+
Value cacheModifiers = b.i32_val(
516515
getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo));
517516

518517
Value llMask = adaptor.getMask();
@@ -535,7 +534,7 @@ struct AsyncCopyGlobalToLocalOpConversion
535534

536535
if (!mask) {
537536
rewriter.create<ROCDL::GlobalLoadLDSOp>(
538-
loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/i32_val(0),
537+
loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
539538
cacheModifiers);
540539
} else {
541540
Block *currentBlock = rewriter.getInsertionBlock();
@@ -546,8 +545,9 @@ struct AsyncCopyGlobalToLocalOpConversion
546545
rewriter.create<LLVM::CondBrOp>(loc, maskElems[srcIdx], loadBlock,
547546
afterLoad);
548547
rewriter.setInsertionPointToStart(loadBlock);
549-
rewriter.create<ROCDL::GlobalLoadLDSOp>(
550-
loc, srcPtr, shmemAddrs[i], vecBytesVal, i32_val(0), i32_val(0));
548+
rewriter.create<ROCDL::GlobalLoadLDSOp>(loc, srcPtr, shmemAddrs[i],
549+
vecBytesVal, b.i32_val(0),
550+
cacheModifiers);
551551

552552
rewriter.create<LLVM::BrOp>(loc, afterLoad);
553553
rewriter.setInsertionPointToStart(afterLoad);
@@ -556,7 +556,7 @@ struct AsyncCopyGlobalToLocalOpConversion
556556
packElementRangeIntoVector(rewriter, this->getTypeConverter(),
557557
loc, vecTy, otherElems, srcIdx);
558558
llStore(rewriter, loc, shmemAddrs[i], storeVal,
559-
icmp_ne(maskElems[srcIdx], true_val()), 0, op.getCache());
559+
b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache());
560560
}
561561
}
562562
}
@@ -1648,8 +1648,9 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16481648
ConversionPatternRewriter &rewriter) const override {
16491649

16501650
auto loc = op->getLoc();
1651+
auto b = TritonLLVMOpBuilder(loc, rewriter);
16511652
rewriter.create<ROCDL::WaitcntOp>(loc, op.getNum());
1652-
rewriter.replaceOp(op, i32_val(0));
1653+
rewriter.replaceOp(op, b.i32_val(0));
16531654
return success();
16541655
}
16551656
};
@@ -1669,7 +1670,8 @@ struct AsyncCommitGroupConversion
16691670
ConversionPatternRewriter &rewriter) const override {
16701671
// Drop the result token
16711672
auto loc = op->getLoc();
1672-
rewriter.replaceOp(op, i32_val(0));
1673+
auto b = TritonLLVMOpBuilder(loc, rewriter);
1674+
rewriter.replaceOp(op, b.i32_val(0));
16731675
return success();
16741676
}
16751677
};

0 commit comments

Comments
 (0)