@@ -439,16 +439,15 @@ struct AsyncCopyGlobalToLocalOpConversion
439439 matchAndRewrite (triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor,
440440 ConversionPatternRewriter &rewriter) const override {
441441
442- MLIRContext *ctx = rewriter.getContext ();
443442 auto loc = op.getLoc ();
443+ auto b = TritonLLVMOpBuilder (loc, rewriter);
444444
445445 auto srcTy = op.getSrc ().getType ();
446446 auto srcEncoding = srcTy.getEncoding ();
447447 assert ((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
448448 " Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion" ));
449- auto srcShape = srcTy.getShape ();
450- assert (srcShape.size () <= 2 && " Async copy only supports 1d and 2d "
451- " tensors: Unexpected rank of %src" );
449+ assert (srcTy.getShape ().size () <= 2 && " Async copy only supports 1d and 2d "
450+ " tensors: Unexpected rank of %src" );
452451
453452 auto dstTy = op.getResult ().getType ();
454453 auto resElemTy = getTypeConverter ()->convertType (dstTy.getElementType ());
@@ -479,7 +478,7 @@ struct AsyncCopyGlobalToLocalOpConversion
479478 shape, dstTy.getEncoding (), resElemTy.getIntOrFloatBitWidth ());
480479 LinearLayout srcToSharedLayout = srcLayout.invertAndCompose (sharedLayout);
481480
482- auto kLane = str_attr (" lane" );
481+ StringAttr kLane = rewriter. getStringAttr (" lane" );
483482 for (int inLane : llvm::seq (srcToSharedLayout.getInDimSizeLog2 (kLane ))) {
484483 auto basis = srcToSharedLayout.getBasis (kLane , inLane)[0 ];
485484 unsigned expected = maxVec * (1 << inLane);
@@ -510,9 +509,9 @@ struct AsyncCopyGlobalToLocalOpConversion
510509
511510 int vecBytes = vecBits / 8 ;
512511 assert (llvm::isPowerOf2_32 (vecBytes));
513- Value vecBytesVal = i32_val (vecBytes);
512+ Value vecBytesVal = b. i32_val (vecBytes);
514513
515- Value cacheModifiers = i32_val (
514+ Value cacheModifiers = b. i32_val (
516515 getCtrlBitsForCacheModifierOnTarget (op.getCache (), false , targetInfo));
517516
518517 Value llMask = adaptor.getMask ();
@@ -535,7 +534,7 @@ struct AsyncCopyGlobalToLocalOpConversion
535534
536535 if (!mask) {
537536 rewriter.create <ROCDL::GlobalLoadLDSOp>(
538- loc, srcPtr, shmemAddrs[i], vecBytesVal, /* offset=*/ i32_val (0 ),
537+ loc, srcPtr, shmemAddrs[i], vecBytesVal, /* offset=*/ b. i32_val (0 ),
539538 cacheModifiers);
540539 } else {
541540 Block *currentBlock = rewriter.getInsertionBlock ();
@@ -546,8 +545,9 @@ struct AsyncCopyGlobalToLocalOpConversion
546545 rewriter.create <LLVM::CondBrOp>(loc, maskElems[srcIdx], loadBlock,
547546 afterLoad);
548547 rewriter.setInsertionPointToStart (loadBlock);
549- rewriter.create <ROCDL::GlobalLoadLDSOp>(
550- loc, srcPtr, shmemAddrs[i], vecBytesVal, i32_val (0 ), i32_val (0 ));
548+ rewriter.create <ROCDL::GlobalLoadLDSOp>(loc, srcPtr, shmemAddrs[i],
549+ vecBytesVal, b.i32_val (0 ),
550+ cacheModifiers);
551551
552552 rewriter.create <LLVM::BrOp>(loc, afterLoad);
553553 rewriter.setInsertionPointToStart (afterLoad);
@@ -556,7 +556,7 @@ struct AsyncCopyGlobalToLocalOpConversion
556556 packElementRangeIntoVector (rewriter, this ->getTypeConverter (),
557557 loc, vecTy, otherElems, srcIdx);
558558 llStore (rewriter, loc, shmemAddrs[i], storeVal,
559- icmp_ne (maskElems[srcIdx], true_val ()), 0 , op.getCache ());
559+ b. icmp_ne (maskElems[srcIdx], b. true_val ()), 0 , op.getCache ());
560560 }
561561 }
562562 }
@@ -1648,8 +1648,9 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16481648 ConversionPatternRewriter &rewriter) const override {
16491649
16501650 auto loc = op->getLoc ();
1651+ auto b = TritonLLVMOpBuilder (loc, rewriter);
16511652 rewriter.create <ROCDL::WaitcntOp>(loc, op.getNum ());
1652- rewriter.replaceOp (op, i32_val (0 ));
1653+ rewriter.replaceOp (op, b. i32_val (0 ));
16531654 return success ();
16541655 }
16551656};
@@ -1669,7 +1670,8 @@ struct AsyncCommitGroupConversion
16691670 ConversionPatternRewriter &rewriter) const override {
16701671 // Drop the result token
16711672 auto loc = op->getLoc ();
1672- rewriter.replaceOp (op, i32_val (0 ));
1673+ auto b = TritonLLVMOpBuilder (loc, rewriter);
1674+ rewriter.replaceOp (op, b.i32_val (0 ));
16731675 return success ();
16741676 }
16751677};
0 commit comments