From ba17a9e89718700169e0d0db6e18d3bcac8ebf5f Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 17 Jan 2025 14:27:41 +0000 Subject: [PATCH 01/29] Basic lowering AsyncCommitGroup and AsyncWait --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 63 ++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 18f5cfc68abe..194daebe42e3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1459,6 +1459,65 @@ struct AtomicRMWOpConversion return endBlock->getArgument(0); } }; + +struct AsyncWaitConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AsyncWaitConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // TODO Alex: correctly handle pending count + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.waitcnt", {}, + {i32_val(0)}); + + // Drop the result token. + Value zero = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); + rewriter.replaceOp(op, zero); + return success(); + + return success(); + } +}; + +struct AsyncCommitGroupConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AsyncCommitGroupConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // TODO Alex: correctly handle pending count + // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.waitcnt", + // {}, + // {i32_val(0)}); + + // Drop the result token. 
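+    // Note: there is no hardware counterpart for the async token on this
+    // path; the pattern below materializes a dummy i32 constant so that
+    // later uses of the op's token result still have a value after the
+    // rewrite.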
+ Value zero = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); + rewriter.replaceOp(op, zero); + return success(); + } +}; + } // namespace namespace mlir::triton::AMD { @@ -1470,7 +1529,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit) { patterns.add( + BufferStoreOpConversion, BufferAtomicRMWOpConversion, + AsyncLoadOpConversion, AsyncCommitGroupConversion, + AsyncWaitConversion, AsyncCommitGroupConversion>( typeConverter, targetInfo, axisInfoAnalysis, benefit); } } // namespace mlir::triton::AMD From 2327587d5d326114c1041de96517aa3526f1e00a Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 17 Jan 2025 14:28:03 +0000 Subject: [PATCH 02/29] WIP lowering of AsyncCopy --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 216 ++++++++++++++++++ 1 file changed, 216 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 194daebe42e3..72f34530f526 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -396,6 +396,222 @@ struct BufferLoadOpConversion } }; +struct AsyncLoadOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern; + + AsyncLoadOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + Value res = op.getResult(); + Value mask = op.getMask(); + Value other = op.getOther(); + // assert(!mask && "GlobalLoadToLDS with mask is not implemented yet!"); + // assert(!other && "GlobalLoadToLDS with other is not implemented yet!"); + auto funcOp = op->getParentOfType(); + + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + auto srcLayout = srcTy.getEncoding(); + + assert((isa(srcLayout) && + "Unexpected srcLayout in AsyncCopyGlobalToLocalOpConversion")); + auto resSharedLayout = cast(dstTy.getEncoding()); + auto srcShape = srcTy.getShape(); + assert( + (srcShape.size() <= 2) && + "Async copy only supports 1d and 2d tensors: Unexpected rank of %src"); + + Value llDst = adaptor.getResult(); + Value llSrc = adaptor.getSrc(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + // %src + auto srcElems = unpackLLElements(loc, llSrc, rewriter); + // llvm::outs() << "Src elems count: " << srcElems.size() << "\n"; + + // %dst + auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( + loc, llDst, resElemTy, rewriter); + // %mask + SmallVector maskElems; + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(srcElems.size() == maskElems.size()); + } + + // %other + SmallVector otherElems; + if (llOther) { + // assert(false && "Not implemented"); + // FIXME(Keren): assume other is 0 for now. + // + // It's not necessary for now because the pipeline pass will skip + // generating insert_slice_async if the load op has any "other" tensor. 
+ otherElems = unpackLLElements(loc, llOther, rewriter); + assert(srcElems.size() == otherElems.size()); + } + + // We can load N elements at a time if: + // 1. Every group of N source pointers are contiguous. For example, if + // N=2, then the pointers should be [x, x+1, y, y+1, ...]. + unsigned vec = getVectorSize(op.getSrc()); + // llvm::outs() << "Vec size: " << vec << "\n"; + + unsigned maxVec = getContiguity(op.getSrc()); + // llvm::outs() << "Max vec: " << maxVec << "\n"; + if (mask) { + maxVec = std::min(maxVec, getMaskAlignment(mask)); + } + llvm::outs() << "Max Vec: " << maxVec << "\n"; + llvm::outs().flush(); + + // Addresses to store into, one per `vecTy`. + VectorType vecTy; + SmallVector shmemAddrs; + bool ok = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }); + assert(ok); + llvm::outs() << "Shared to reg\n"; + llvm::outs().flush(); + + int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8; + assert(llvm::isPowerOf2_32(vecBytes)); + if (vecBytes < 4) { + return emitError( + loc, + "direct load to lds does not support transfers smaller than " + "4 bytes; calculated this as ") + << vecBytes << " bytes"; + } + if (vecBytes > 8) { + llvm::outs() << "here\n"; + // TODO we should probably emit a perf warning here + // return emitWarning( + // loc, + // "direct load to lds does not support transfers larger than " + // "8 bytes; calculated this as ") + // << vecBytes << " bytes, which means less bandwidth"; + llvm::outs() << "There\n"; + } + + // Value zr = rewriter.create( + // op.getLoc(), IntegerType::get(op.getContext(), 32), + // rewriter.getI32IntegerAttr(0)); + // Value two = rewriter.create( + // op.getLoc(), IntegerType::get(op.getContext(), 32), + // rewriter.getI32IntegerAttr(2)); + StringRef funcName = "llvm.amdgcn.global.load.lds"; + // Type funcType = getFunctionType( + // void_ty(getContext()), + // ValueRange({srcElems[0], smemObj.getBase(), zr, zr, zr})); + // LLVM::LLVMFuncOp llFuncOp = + // appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + for (int i = 0; i < shmemAddrs.size(); i++) { + // It's possible that vecTy is larger than 128 bits, in which case we + // have to use multiple cp.async instructions. 
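+      // wordBytes is capped at 4 because llvm.amdgcn.global.load.lds accepts
+      // data sizes of 1/2/4 bytes (12/16 only on gfx950, see the intrinsic
+      // signature comment further down), so a wider vecTy is issued as
+      // numWordsInVec separate calls rather than one cp.async-style transfer.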
+ int wordBytes = std::min(vecBytes, 4); + int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth(); + int numWordsInVec = std::max(1, vecBytes / wordBytes); + llvm::outs() << "Create intrinsics\n"; + llvm::outs().flush(); + + if (wordBytes < 1 && wordBytes > 8) { + return emitError(loc, "[GlobalLoadToLDS] Unsupported load size of: " + + std::to_string(wordBytes) + "bytes"); + } + + // [LLVMQualPointerType<1>, // Base global pointer to load from + // LLVMQualPointerType<3>, // LDS base pointer to store to + // llvm_i32_ty, // Data byte size: 1/2/4 (/12/16 + // for gfx950) llvm_i32_ty, // imm offset (applied + // to both global and LDS address) llvm_i32_ty], // + // auxiliary data (imm, cachepolicy (bit 0 = sc0, + // // bit 1 = sc1, + // // bit 4 = scc)) + std::string intrinsic = "llvm.amdgcn.global.load.lds"; + llvm::outs() << "Wordbytes: " << wordBytes << "\n"; + Value loadWidth = rewriter.create( + op.getLoc(), rewriter.getI32Type(), + rewriter.getI32IntegerAttr(wordBytes)); + + Value sizeValue = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(wordBytes)); + + for (int j = 0; j < numWordsInVec; j++) { + llvm::outs() << "Word bytes: " << wordBytes << "\n"; + llvm::outs() << "Word elems: " << wordElems << "\n"; + llvm::outs() << "Word nums : " << numWordsInVec << "\n"; + // Tune CG and CA. + // TODO Alex select correct cache modifier + // CacheModifier srcCacheModifier = + // wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA; + // assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); + + Value offsetValue = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + // rewriter.getI32IntegerAttr(j)); + rewriter.getI32IntegerAttr(j * wordBytes)); + + int elemIdx = i * vecTy.getNumElements() + j * wordElems; + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, + {srcElems[0], shmemAddrs[0], loadWidth, + /*imm + offset=*/i32_val(0), i32_val(2)}); + // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + // "llvm.amdgcn.s.waitcnt", + // {}, {i32_val(0)}); + + // Block *currentBlock = rewriter.getInsertionBlock(); + // Block *afterLoad = + // rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + // Block *loadBlock = rewriter.createBlock(afterLoad); + // rewriter.setInsertionPointToEnd(currentBlock); + // rewriter.create(loc, maskElems[elemIdx], loadBlock, + // afterLoad); + // rewriter.setInsertionPointToStart(loadBlock); + // rewriter + // .create(loc, llFuncOp, + // ValueRange({srcElems[elemIdx], + // shmemAddrs[i], + // loadWidth, offsetValue, two})) + // .getResult(); + // rewriter.create(loc, afterLoad); + // rewriter.setInsertionPointToStart(afterLoad); + + llvm::outs() << "Word nums : " << numWordsInVec << "\n"; + } + } + + // Drop the result token. 
+ Value zero = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); + rewriter.replaceOp(op, zero); + return success(); + } +}; + struct StoreOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; From 1d4edf6d6a5a79473a33436e9776c219b14f3368 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 24 Jan 2025 16:58:14 +0000 Subject: [PATCH 03/29] Added layout checks for asynccopy lowering --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 289 +++++++++--------- 1 file changed, 137 insertions(+), 152 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 72f34530f526..b17bbbeed09b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -20,9 +20,11 @@ using namespace mlir::triton::gpu; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryBase; +using ::mlir::LLVM::AMD::getContiguity; using ::mlir::LLVM::AMD::getVectorSize; using ::mlir::LLVM::AMD::llLoad; using ::mlir::LLVM::AMD::llStore; +using ::mlir::triton::AMD::ISAFamily; using ::mlir::triton::gpu::getTotalElemsPerThread; namespace { @@ -414,13 +416,15 @@ struct AsyncLoadOpConversion matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); Value res = op.getResult(); Value mask = op.getMask(); Value other = op.getOther(); - // assert(!mask && "GlobalLoadToLDS with mask is not implemented yet!"); - // assert(!other && "GlobalLoadToLDS with other is not implemented yet!"); - auto funcOp = op->getParentOfType(); + if (other) { + return emitError(loc, "ttg.AsyncLoad does not support other values use " + "tt.load/store instead"); + } auto srcTy = op.getSrc().getType(); auto dstTy = op.getResult().getType(); @@ -442,7 +446,6 @@ struct AsyncLoadOpConversion // %src auto srcElems = unpackLLElements(loc, llSrc, rewriter); - // llvm::outs() << "Src elems count: " << srcElems.size() << "\n"; // %dst auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( @@ -454,37 +457,73 @@ struct AsyncLoadOpConversion assert(srcElems.size() == maskElems.size()); } - // %other - SmallVector otherElems; - if (llOther) { - // assert(false && "Not implemented"); - // FIXME(Keren): assume other is 0 for now. - // - // It's not necessary for now because the pipeline pass will skip - // generating insert_slice_async if the load op has any "other" tensor. - otherElems = unpackLLElements(loc, llOther, rewriter); - assert(srcElems.size() == otherElems.size()); - } - - // We can load N elements at a time if: - // 1. Every group of N source pointers are contiguous. For example, if - // N=2, then the pointers should be [x, x+1, y, y+1, ...]. 
- unsigned vec = getVectorSize(op.getSrc()); - // llvm::outs() << "Vec size: " << vec << "\n"; + // global.load.lds has a shared dst register so we cannot have per thread + // offsets This means our load size has to align with the load_width of - unsigned maxVec = getContiguity(op.getSrc()); - // llvm::outs() << "Max vec: " << maxVec << "\n"; + unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass); if (mask) { maxVec = std::min(maxVec, getMaskAlignment(mask)); } - llvm::outs() << "Max Vec: " << maxVec << "\n"; - llvm::outs().flush(); + + llvm::SmallSetVector supportedLoadBits; + switch (targetInfo.getISAFamily()) { + case mlir::triton::AMD::ISAFamily::CDNA3: + supportedLoadBits.insert(8); + supportedLoadBits.insert(16); + supportedLoadBits.insert(32); + break; + case mlir::triton::AMD::ISAFamily::CDNA4: + supportedLoadBits.insert(8); + supportedLoadBits.insert(16); + supportedLoadBits.insert(32); + supportedLoadBits.insert(98); + supportedLoadBits.insert(128); + break; + default: + return emitError(loc, "Async copy not supported on target ISA"); + } + + unsigned int loadStoreBitWidth = maxVec * resElemTy.getIntOrFloatBitWidth(); + + if (!supportedLoadBits.contains(loadStoreBitWidth)) { + return emitError(loc, "Async copy does not supported the required load " + "vectorization, got ") + << loadStoreBitWidth << "bits"; + } + + { + + auto shape = dstTy.getShape(); + LinearLayout regLayout = + triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); + LinearLayout sharedLayout = triton::gpu::toLinearLayout( + shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); + LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + + // We need to check if the lane basis is contigeous because + // global.load.lds does not support per lane offset + auto kLane = str_attr("lane"); + auto kBlock = str_attr("block"); + auto kWarp = str_attr("warp"); + auto kRegister = str_attr("register"); + + for (int inLane : llvm::seq(regToSharedLayout.getInDimSize(kLane))) { + auto idx = regToSharedLayout.apply( + {{kRegister, 0}, {kLane, inLane}, {kWarp, 0}, {kBlock, 0}}); + int32_t offset = idx[0].second; + if (offset != (inLane * maxVec)) { + return emitError(loc, "Invalid layout in AsyncCopy: ") + << "Lane: " << inLane << " is " << offset << " should be " + << inLane << "\n"; + } + } + } // Addresses to store into, one per `vecTy`. 
VectorType vecTy; SmallVector shmemAddrs; bool ok = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo, + srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) { vecTy = vecTy_; shmemAddrs.push_back(shmemAddr); @@ -495,112 +534,77 @@ struct AsyncLoadOpConversion int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8; assert(llvm::isPowerOf2_32(vecBytes)); - if (vecBytes < 4) { - return emitError( - loc, - "direct load to lds does not support transfers smaller than " - "4 bytes; calculated this as ") - << vecBytes << " bytes"; - } - if (vecBytes > 8) { - llvm::outs() << "here\n"; - // TODO we should probably emit a perf warning here - // return emitWarning( - // loc, - // "direct load to lds does not support transfers larger than " - // "8 bytes; calculated this as ") - // << vecBytes << " bytes, which means less bandwidth"; - llvm::outs() << "There\n"; - } - // Value zr = rewriter.create( - // op.getLoc(), IntegerType::get(op.getContext(), 32), - // rewriter.getI32IntegerAttr(0)); - // Value two = rewriter.create( - // op.getLoc(), IntegerType::get(op.getContext(), 32), - // rewriter.getI32IntegerAttr(2)); - StringRef funcName = "llvm.amdgcn.global.load.lds"; - // Type funcType = getFunctionType( - // void_ty(getContext()), - // ValueRange({srcElems[0], smemObj.getBase(), zr, zr, zr})); - // LLVM::LLVMFuncOp llFuncOp = - // appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + std::string intrinsic = "llvm.amdgcn.global.load.lds"; + Value loadStoreByteWidthVal = i32_val(loadStoreBitWidth / 8); + llvm::outs() << "Load byte width: " << loadStoreByteWidthVal << "\n"; + llvm::outs() << "Shem addr count: " << shmemAddrs.size() << "\n"; for (int i = 0; i < shmemAddrs.size(); i++) { - // It's possible that vecTy is larger than 128 bits, in which case we - // have to use multiple cp.async instructions. - int wordBytes = std::min(vecBytes, 4); - int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth(); - int numWordsInVec = std::max(1, vecBytes / wordBytes); - llvm::outs() << "Create intrinsics\n"; - llvm::outs().flush(); - - if (wordBytes < 1 && wordBytes > 8) { - return emitError(loc, "[GlobalLoadToLDS] Unsupported load size of: " + - std::to_string(wordBytes) + "bytes"); - } - - // [LLVMQualPointerType<1>, // Base global pointer to load from - // LLVMQualPointerType<3>, // LDS base pointer to store to - // llvm_i32_ty, // Data byte size: 1/2/4 (/12/16 - // for gfx950) llvm_i32_ty, // imm offset (applied - // to both global and LDS address) llvm_i32_ty], // - // auxiliary data (imm, cachepolicy (bit 0 = sc0, - // // bit 1 = sc1, - // // bit 4 = scc)) - std::string intrinsic = "llvm.amdgcn.global.load.lds"; - llvm::outs() << "Wordbytes: " << wordBytes << "\n"; - Value loadWidth = rewriter.create( - op.getLoc(), rewriter.getI32Type(), - rewriter.getI32IntegerAttr(wordBytes)); - - Value sizeValue = rewriter.create( - op.getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(wordBytes)); - - for (int j = 0; j < numWordsInVec; j++) { - llvm::outs() << "Word bytes: " << wordBytes << "\n"; - llvm::outs() << "Word elems: " << wordElems << "\n"; - llvm::outs() << "Word nums : " << numWordsInVec << "\n"; - // Tune CG and CA. - // TODO Alex select correct cache modifier - // CacheModifier srcCacheModifier = - // wordBytes == 16 ? 
CacheModifier::CG : CacheModifier::CA; - // assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); - - Value offsetValue = rewriter.create( - op.getLoc(), IntegerType::get(op.getContext(), 32), - // rewriter.getI32IntegerAttr(j)); - rewriter.getI32IntegerAttr(j * wordBytes)); - - int elemIdx = i * vecTy.getNumElements() + j * wordElems; - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, - {srcElems[0], shmemAddrs[0], loadWidth, - /*imm - offset=*/i32_val(0), i32_val(2)}); - // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, - // "llvm.amdgcn.s.waitcnt", - // {}, {i32_val(0)}); - - // Block *currentBlock = rewriter.getInsertionBlock(); - // Block *afterLoad = - // rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); - // Block *loadBlock = rewriter.createBlock(afterLoad); - // rewriter.setInsertionPointToEnd(currentBlock); - // rewriter.create(loc, maskElems[elemIdx], loadBlock, - // afterLoad); - // rewriter.setInsertionPointToStart(loadBlock); - // rewriter - // .create(loc, llFuncOp, - // ValueRange({srcElems[elemIdx], - // shmemAddrs[i], - // loadWidth, offsetValue, two})) - // .getResult(); - // rewriter.create(loc, afterLoad); - // rewriter.setInsertionPointToStart(afterLoad); - - llvm::outs() << "Word nums : " << numWordsInVec << "\n"; - } + // Tune CG and CA. + // TODO Alex select correct cache modifier + // CacheModifier srcCacheModifier = + // wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA; + // assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); + + auto srcPtr = srcElems[i * maxVec]; + + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, + {srcPtr, shmemAddrs[i], + loadStoreByteWidthVal, + /*imm + offset=*/i32_val(0), i32_val(0)}); + + // bool useIntrinsic = true; + // llvm::outs() << "Use intrinsics: " << useIntrinsic << "\n"; + + // auto basePtr = getPtrFromFirstLane(targetInfo, dstPtr); + // Value threadId = tid_val(); + // Value laneId = urem(threadId, i32_val(64)); + // Value offset = mul(laneId, i32_val(1)); + // Value dstPtrWithOffset = gep(elemPtrTy, resElemTy, basePtr, offset); + + // Build blocks to bypass the global.load.lds + // auto *curBlock = rewriter.getInsertionBlock(); + // auto *endBlock = + // curBlock->splitBlock(rewriter.getInsertionPoint()); auto + // *atomicBlock = rewriter.createBlock( + // curBlock->getParent(), std::next(Region::iterator(curBlock))); + + // // Fill entry block with global memory barrier and conditional + // branch. rewriter.setInsertionPointToEnd(curBlock); auto tid = + // tid_val(); Value pred = icmp_eq(tid, i32_val(i)); + // rewriter.create(loc, pred, atomicBlock, endBlock); + + // Build main block with atomic_cmpxchg. 
+ // rewriter.setInsertionPointToEnd(atomicBlock); + // Value l = load(smemObj.getBaseElemType(), dstPtrWithOffset); + // store(l, srcPtr); + // LLVM::createLLVMIntrinsicCallOp( + // rewriter, loc, "llvm.amdgcn.s.waitcnt", {}, {i32_val(0)}); + // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + // "llvm.amdgcn.wave.barrier", {}, + // {}); + // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + // "llvm.amdgcn.s.waitcnt", + // {}, {i32_val(0)}); + + // Block *currentBlock = rewriter.getInsertionBlock(); + // Block *afterLoad = + // rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + // Block *loadBlock = rewriter.createBlock(afterLoad); + // rewriter.setInsertionPointToEnd(currentBlock); + // rewriter.create(loc, maskElems[elemIdx], loadBlock, + // afterLoad); + // rewriter.setInsertionPointToStart(loadBlock); + // rewriter + // .create(loc, llFuncOp, + // ValueRange({srcElems[elemIdx], + // shmemAddrs[i], + // loadWidth, offsetValue, two})) + // .getResult(); + // rewriter.create(loc, afterLoad); + // rewriter.setInsertionPointToStart(afterLoad); } // Drop the result token. @@ -1688,19 +1692,9 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // MemBar already added the barrier for us so we can dimply drop it auto loc = op->getLoc(); - - // TODO Alex: correctly handle pending count - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.waitcnt", {}, - {i32_val(0)}); - - // Drop the result token. - Value zero = rewriter.create( - op.getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(0)); - rewriter.replaceOp(op, zero); - return success(); - + rewriter.replaceOp(op, i32_val(0)); return success(); } }; @@ -1718,18 +1712,9 @@ struct AsyncCommitGroupConversion LogicalResult matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // We do not have that concept so simply drop it auto loc = op->getLoc(); - - // TODO Alex: correctly handle pending count - // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.waitcnt", - // {}, - // {i32_val(0)}); - - // Drop the result token. 
- Value zero = rewriter.create( - op.getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(0)); - rewriter.replaceOp(op, zero); + rewriter.replaceOp(op, i32_val(0)); return success(); } }; From ead4915fed620ed52562aed1e598d4c007e010c2 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Mon, 27 Jan 2025 11:37:26 +0000 Subject: [PATCH 04/29] Support direct to lds --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 14 +-- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 102 ++++++++++++++++-- 2 files changed, 100 insertions(+), 16 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index b17bbbeed09b..a26412c74ae0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -466,18 +466,16 @@ struct AsyncLoadOpConversion } llvm::SmallSetVector supportedLoadBits; + // TODO look up if we support it on mi200 switch (targetInfo.getISAFamily()) { case mlir::triton::AMD::ISAFamily::CDNA3: supportedLoadBits.insert(8); supportedLoadBits.insert(16); supportedLoadBits.insert(32); - break; - case mlir::triton::AMD::ISAFamily::CDNA4: - supportedLoadBits.insert(8); - supportedLoadBits.insert(16); - supportedLoadBits.insert(32); - supportedLoadBits.insert(98); - supportedLoadBits.insert(128); + if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) { + supportedLoadBits.insert(98); + supportedLoadBits.insert(128); + } break; default: return emitError(loc, "Async copy not supported on target ISA"); @@ -499,6 +497,8 @@ struct AsyncLoadOpConversion LinearLayout sharedLayout = triton::gpu::toLinearLayout( shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + llvm::outs() << "Reg to shared: \n" + << regToSharedLayout.toString() << "\n"; // We need to check if the lane basis is contigeous because // global.load.lds does not support per lane offset diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index b53d4b55421c..1732e21fd7d1 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -257,12 +257,35 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, Value other = loadOp.getOther(); ttg::MemDescType allocTy = cast(alloc.getType()); + + auto sharedEncodingAttr = + cast(allocTy.getEncoding()); + llvm::outs() << "Shared alloc: \n"; + alloc.print(llvm::outs()); + llvm::outs() << "\n"; + + bool emitAsyncCopy = false; + + auto srcTy = dyn_cast(src.getType()); + // We can use AsyncCopy if we do not swizzle into smem + // TODO (alex) ensure it's 2D + if (sharedEncodingAttr.getPerPhase() == 1 && + sharedEncodingAttr.getMaxPhase() == 1 && + llvm::equal(sharedEncodingAttr.getOrder(), + ttg::getOrder(srcTy.getEncoding()))) { + emitAsyncCopy = true; + } + llvm::outs() << "Emit async: " << emitAsyncCopy << "\n"; + SmallVector copyOffsets(allocTy.getRank(), zero); - Operation *copy = builder.clone(*loadOp); - auto [stage, cluster] = schedule[loadOp]; - schedule.erase(loadOp); - schedule.insert(copy, stage, cluster); + Operation *newLoadOp{}; + if (!emitAsyncCopy) { + newLoadOp = builder.clone(*loadOp); + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(newLoadOp, stage, cluster); + } // Extract 
part. SmallVector loadOffsets(allocTy.getRank(), zero); @@ -274,6 +297,58 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + + if (emitAsyncCopy) { + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + llvm::outs() << "INVALID SRC!\n"; + } + // We need to ensure we read coalesced into LDS so we adjust the blocked to + // read coalesced for now + + auto shape = subviewTy.getShape(); + auto order = sharedEncodingAttr.getOrder(); + // Aim to use wider loads + llvm::SmallVector sizePerThread{1, 1}; + sizePerThread[order[0]] = + 32 / allocTy.getElementType().getIntOrFloatBitWidth(); + llvm::SmallVector threadsPerWarp{1, 1}; + assert((shape[order[0]] % sizePerThread[0]) == 0); + threadsPerWarp[order[0]] = shape[order[0]] / sizePerThread[order[0]]; + unsigned warpSize = 64; + threadsPerWarp[order[1]] = + std::max(1, warpSize / threadsPerWarp[order[0]]); + + auto srcEncoding = srcTy.getEncoding(); + auto newLayout = ttg::BlockedEncodingAttr::get( + loadOp->getContext(), + sizePerThread, //{1, 1}, // triton::gpu::getSizePerThread(srcEncoding), + threadsPerWarp, //{2, 32}, // + // triton::gpu::getThreadsPerWarp(srcEncoding), + triton::gpu::getWarpsPerCTA(srcEncoding), + triton::gpu::getOrder(srcEncoding), + triton::gpu::getCTALayout(srcEncoding)); + llvm::outs() << "New src encoding: "; + newLayout.printStripped(llvm::outs()); + llvm::outs() << "\n"; + RankedTensorType newArgType = RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), newLayout); + llvm::outs() << "Source encoding: "; + srcTy.getEncoding().print(llvm::outs()); + llvm::outs() << "\n"; + auto cvt = + builder.create(loadOp.getLoc(), newArgType, src); + + newLoadOp = builder.create( + loadOp.getLoc(), cvt.getResult(), viewLoad, mask, other, + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(cvt, stage, cluster); + schedule.insert(newLoadOp, stage, cluster); + } + // Clean up old local caches. SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { @@ -286,10 +361,15 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, alloc.erase(); // Prefetch load ahead of the dot stage if is used by the dot. - auto storeOp = - builder.create(loc, copy->getResult(0), viewLoad); - scheduleOp(viewLoad, SCHED_LOCAL_STORE); - scheduleOp(storeOp, SCHED_LOCAL_STORE); + Operation *storeOp; + if (emitAsyncCopy) { + scheduleOp(newLoadOp, SCHED_LOCAL_STORE); + } else { + storeOp = builder.create(loc, newLoadOp->getResult(0), + viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); + } // Create local load auto sharedLoad = @@ -304,7 +384,11 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, // instruction scheduling hints to correctly count the emitted `ds_write` // instructions for each GEMM tile. 
if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { - storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + if (emitAsyncCopy) { + newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + } else { + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + } } loadOp->replaceAllUsesWith(ValueRange{result}); From 3141ba47da5bac97dfd1009770411ce3f7a93a1e Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Mon, 27 Jan 2025 14:26:38 +0000 Subject: [PATCH 05/29] Enable non working masking --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 4 ---- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 18 ++++++++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a26412c74ae0..be4ae539899a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -421,10 +421,6 @@ struct AsyncLoadOpConversion Value res = op.getResult(); Value mask = op.getMask(); Value other = op.getOther(); - if (other) { - return emitError(loc, "ttg.AsyncLoad does not support other values use " - "tt.load/store instead"); - } auto srcTy = op.getSrc().getType(); auto dstTy = op.getResult().getType(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 1732e21fd7d1..fc32671dd894 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -314,8 +314,9 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, 32 / allocTy.getElementType().getIntOrFloatBitWidth(); llvm::SmallVector threadsPerWarp{1, 1}; assert((shape[order[0]] % sizePerThread[0]) == 0); - threadsPerWarp[order[0]] = shape[order[0]] / sizePerThread[order[0]]; unsigned warpSize = 64; + threadsPerWarp[order[0]] = + std::min(warpSize, shape[order[0]] / sizePerThread[order[0]]); threadsPerWarp[order[1]] = std::max(1, warpSize / threadsPerWarp[order[0]]); @@ -336,16 +337,25 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, llvm::outs() << "Source encoding: "; srcTy.getEncoding().print(llvm::outs()); llvm::outs() << "\n"; - auto cvt = + auto cvtSrc = builder.create(loadOp.getLoc(), newArgType, src); + auto maskTy = + dyn_cast(loadOp.getMask().getType()); + RankedTensorType newMaskTy = RankedTensorType::get( + maskTy.getShape(), maskTy.getElementType(), newLayout); + RankedTensorType newMaskType = RankedTensorType::get( + allocTy.getShape(), srcTy.getElementType(), newLayout); + auto cvtMask = builder.create( + loadOp->getLoc(), newMaskTy, loadOp.getMask()); + newLoadOp = builder.create( - loadOp.getLoc(), cvt.getResult(), viewLoad, mask, other, + loadOp.getLoc(), cvtSrc.getResult(), viewLoad, cvtMask, other, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); auto [stage, cluster] = schedule[loadOp]; schedule.erase(loadOp); - schedule.insert(cvt, stage, cluster); + schedule.insert(cvtSrc, stage, cluster); schedule.insert(newLoadOp, stage, cluster); } From 644aa1eeec933137b32122cd9338472231530d91 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 09:43:09 +0000 Subject: [PATCH 06/29] Add support to enable disable direct to lds with env var AMDGCN_USE_DIRECT_TO_LDS --- include/triton/Tools/Sys/GetEnv.hpp | 1 + 
third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index b11d90be436d..cd36a473333a 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -14,6 +14,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { // clang-format off "AMDGCN_ENABLE_DUMP", "AMDGCN_USE_BUFFER_OPS", + "AMDGCN_USE_DIRECT_TO_LDS", "DISABLE_FAST_REDUCTION", "DISABLE_LLVM_OPT", "DISABLE_MMA_V3", diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index fc32671dd894..dd0a443a3c26 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -10,6 +10,7 @@ #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// @@ -269,7 +270,8 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, auto srcTy = dyn_cast(src.getType()); // We can use AsyncCopy if we do not swizzle into smem // TODO (alex) ensure it's 2D - if (sharedEncodingAttr.getPerPhase() == 1 && + if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") && + sharedEncodingAttr.getPerPhase() == 1 && sharedEncodingAttr.getMaxPhase() == 1 && llvm::equal(sharedEncodingAttr.getOrder(), ttg::getOrder(srcTy.getEncoding()))) { From 7c9bab147dc863838d4a8deef72e4c8ab9de6a3c Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 09:44:42 +0000 Subject: [PATCH 07/29] Fix masking and others for direct to lds --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 78 +++++++------------ .../TritonAMDGPUTransforms/StreamPipeline.cpp | 18 ++--- 2 files changed, 35 insertions(+), 61 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index be4ae539899a..ca22d1e5c12a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -453,11 +453,20 @@ struct AsyncLoadOpConversion assert(srcElems.size() == maskElems.size()); } + SmallVector otherElems; + if (llOther) { + otherElems = unpackLLElements(loc, llOther, rewriter); + assert(srcElems.size() == otherElems.size()); + } + + // TODO check maxVec with mask alignment! + // global.load.lds has a shared dst register so we cannot have per thread // offsets This means our load size has to align with the load_width of unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass); if (mask) { + // TODO, if this changes maxVec we cannot use global.load.lds? maxVec = std::min(maxVec, getMaskAlignment(mask)); } @@ -543,64 +552,29 @@ struct AsyncLoadOpConversion // wordBytes == 16 ? 
CacheModifier::CG : CacheModifier::CA; // assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); - auto srcPtr = srcElems[i * maxVec]; - + auto srcIdx = i * maxVec; + auto srcPtr = srcElems[srcIdx]; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *loadBlock = rewriter.createBlock(afterLoad); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, maskElems[srcIdx], loadBlock, + afterLoad); + rewriter.setInsertionPointToStart(loadBlock); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, {srcPtr, shmemAddrs[i], loadStoreByteWidthVal, /*imm offset=*/i32_val(0), i32_val(0)}); + rewriter.create(loc, afterLoad); + rewriter.setInsertionPointToStart(afterLoad); - // bool useIntrinsic = true; - // llvm::outs() << "Use intrinsics: " << useIntrinsic << "\n"; - - // auto basePtr = getPtrFromFirstLane(targetInfo, dstPtr); - // Value threadId = tid_val(); - // Value laneId = urem(threadId, i32_val(64)); - // Value offset = mul(laneId, i32_val(1)); - // Value dstPtrWithOffset = gep(elemPtrTy, resElemTy, basePtr, offset); - - // Build blocks to bypass the global.load.lds - // auto *curBlock = rewriter.getInsertionBlock(); - // auto *endBlock = - // curBlock->splitBlock(rewriter.getInsertionPoint()); auto - // *atomicBlock = rewriter.createBlock( - // curBlock->getParent(), std::next(Region::iterator(curBlock))); - - // // Fill entry block with global memory barrier and conditional - // branch. rewriter.setInsertionPointToEnd(curBlock); auto tid = - // tid_val(); Value pred = icmp_eq(tid, i32_val(i)); - // rewriter.create(loc, pred, atomicBlock, endBlock); - - // Build main block with atomic_cmpxchg. - // rewriter.setInsertionPointToEnd(atomicBlock); - // Value l = load(smemObj.getBaseElemType(), dstPtrWithOffset); - // store(l, srcPtr); - // LLVM::createLLVMIntrinsicCallOp( - // rewriter, loc, "llvm.amdgcn.s.waitcnt", {}, {i32_val(0)}); - // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, - // "llvm.amdgcn.wave.barrier", {}, - // {}); - // LLVM::createLLVMIntrinsicCallOp(rewriter, loc, - // "llvm.amdgcn.s.waitcnt", - // {}, {i32_val(0)}); - - // Block *currentBlock = rewriter.getInsertionBlock(); - // Block *afterLoad = - // rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); - // Block *loadBlock = rewriter.createBlock(afterLoad); - // rewriter.setInsertionPointToEnd(currentBlock); - // rewriter.create(loc, maskElems[elemIdx], loadBlock, - // afterLoad); - // rewriter.setInsertionPointToStart(loadBlock); - // rewriter - // .create(loc, llFuncOp, - // ValueRange({srcElems[elemIdx], - // shmemAddrs[i], - // loadWidth, offsetValue, two})) - // .getResult(); - // rewriter.create(loc, afterLoad); - // rewriter.setInsertionPointToStart(afterLoad); + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); + llStore(rewriter, loc, shmemAddrs[i], storeVal, + icmp_ne(maskElems[srcIdx], true_val())); } // Drop the result token. 
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index dd0a443a3c26..efa6016416bf 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -342,17 +342,17 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, auto cvtSrc = builder.create(loadOp.getLoc(), newArgType, src); - auto maskTy = - dyn_cast(loadOp.getMask().getType()); - RankedTensorType newMaskTy = RankedTensorType::get( - maskTy.getShape(), maskTy.getElementType(), newLayout); - RankedTensorType newMaskType = RankedTensorType::get( - allocTy.getShape(), srcTy.getElementType(), newLayout); - auto cvtMask = builder.create( - loadOp->getLoc(), newMaskTy, loadOp.getMask()); + auto mask = loadOp.getMask(); + if (mask) { + auto maskTy = dyn_cast(mask.getType()); + RankedTensorType newMaskTy = RankedTensorType::get( + maskTy.getShape(), maskTy.getElementType(), newLayout); + auto cvtMask = builder.create( + loadOp->getLoc(), newMaskTy, loadOp.getMask()); + } newLoadOp = builder.create( - loadOp.getLoc(), cvtSrc.getResult(), viewLoad, cvtMask, other, + loadOp.getLoc(), cvtSrc.getResult(), viewLoad, mask, other, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); auto [stage, cluster] = schedule[loadOp]; From cb823d05334c66764095fcb120a9b9ddb4c3a30e Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 12:03:58 +0000 Subject: [PATCH 08/29] Fix when AsycCopy is lowered without a mask --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index ca22d1e5c12a..78b001bfb91f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -555,26 +555,42 @@ struct AsyncLoadOpConversion auto srcIdx = i * maxVec; auto srcPtr = srcElems[srcIdx]; - Block *currentBlock = rewriter.getInsertionBlock(); - Block *afterLoad = - rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); - Block *loadBlock = rewriter.createBlock(afterLoad); - rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, maskElems[srcIdx], loadBlock, - afterLoad); - rewriter.setInsertionPointToStart(loadBlock); - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, - {srcPtr, shmemAddrs[i], - loadStoreByteWidthVal, - /*imm - offset=*/i32_val(0), i32_val(0)}); - rewriter.create(loc, afterLoad); - rewriter.setInsertionPointToStart(afterLoad); - - Value storeVal = packElementRangeIntoVector( - rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); - llStore(rewriter, loc, shmemAddrs[i], storeVal, - icmp_ne(maskElems[srcIdx], true_val())); + if (mask) { + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *loadBlock = rewriter.createBlock(afterLoad); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, maskElems[srcIdx], loadBlock, + afterLoad); + rewriter.setInsertionPointToStart(loadBlock); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, + {srcPtr, shmemAddrs[i], + loadStoreByteWidthVal, + /*imm + offset=*/i32_val(0), i32_val(0)}); + rewriter.create(loc, afterLoad); + 
rewriter.setInsertionPointToStart(afterLoad); + if (other) { + Value storeVal = + packElementRangeIntoVector(rewriter, this->getTypeConverter(), + loc, vecTy, otherElems, srcIdx); + llStore(rewriter, loc, shmemAddrs[i], storeVal, + icmp_ne(maskElems[srcIdx], true_val())); + } + } else { + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, + {srcPtr, shmemAddrs[i], + loadStoreByteWidthVal, + /*imm + offset=*/i32_val(0), i32_val(0)}); + + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, + {srcPtr, shmemAddrs[i], + loadStoreByteWidthVal, + /*imm + offset=*/i32_val(0), i32_val(0)}); + } } // Drop the result token. From c097616c34de57de26cfe5f08644d244b8700063 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 12:28:51 +0000 Subject: [PATCH 09/29] Use ROCDL instead of intrinsics --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 26 +++++++------------ .../TritonAMDGPUTransforms/StreamPipeline.cpp | 21 ++++++++++++--- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 78b001bfb91f..5757593584a1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -564,11 +564,10 @@ struct AsyncLoadOpConversion rewriter.create(loc, maskElems[srcIdx], loadBlock, afterLoad); rewriter.setInsertionPointToStart(loadBlock); - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, - {srcPtr, shmemAddrs[i], - loadStoreByteWidthVal, - /*imm - offset=*/i32_val(0), i32_val(0)}); + rewriter.create(loc, srcPtr, shmemAddrs[i], + loadStoreByteWidthVal, + i32_val(0), i32_val(0)); + rewriter.create(loc, afterLoad); rewriter.setInsertionPointToStart(afterLoad); if (other) { @@ -579,17 +578,9 @@ struct AsyncLoadOpConversion icmp_ne(maskElems[srcIdx], true_val())); } } else { - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, - {srcPtr, shmemAddrs[i], - loadStoreByteWidthVal, - /*imm - offset=*/i32_val(0), i32_val(0)}); - - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, {}, - {srcPtr, shmemAddrs[i], - loadStoreByteWidthVal, - /*imm - offset=*/i32_val(0), i32_val(0)}); + rewriter.create(loc, srcPtr, shmemAddrs[i], + loadStoreByteWidthVal, + i32_val(0), i32_val(0)); } } @@ -1678,8 +1669,9 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // MemBar already added the barrier for us so we can dimply drop it + auto loc = op->getLoc(); + rewriter.create(loc, 0); rewriter.replaceOp(op, i32_val(0)); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index efa6016416bf..c01bd90e83dc 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -300,6 +300,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + Operation *wait{}; if (emitAsyncCopy) { auto srcTy = dyn_cast(src.getType()); if (!srcTy) { @@ -355,6 +356,8 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, loadOp.getLoc(), cvtSrc.getResult(), viewLoad, mask, other, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + wait = 
builder.create(loc, newLoadOp->getResult(0), 0); + auto [stage, cluster] = schedule[loadOp]; schedule.erase(loadOp); schedule.insert(cvtSrc, stage, cluster); @@ -373,8 +376,11 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, alloc.erase(); // Prefetch load ahead of the dot stage if is used by the dot. - Operation *storeOp; + Operation *storeOp{}; if (emitAsyncCopy) { + // FIXME: it should be scheduled as a local_load to hide latency but that + // currently breaks the scheduling as we require one more lds buffer to make + // that work scheduleOp(newLoadOp, SCHED_LOCAL_STORE); } else { storeOp = builder.create(loc, newLoadOp->getResult(0), @@ -384,9 +390,16 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, } // Create local load - auto sharedLoad = - builder.create(loc, loadOp.getType(), viewLoad); - Value result = sharedLoad.getResult(); + Operation *sharedLoad{}; + if (emitAsyncCopy) { + // scheduleOp(wait, SCHED_LOCAL_LOAD); + sharedLoad = builder.create(loc, loadOp.getType(), + viewLoad, wait->getResult(0)); + } else { + sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + } + Value result = sharedLoad->getResult(0); if (prefetch) scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); From 1a9f1e025bbedb9c9ceca86cb9517b2391f4f75e Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 15:14:42 +0000 Subject: [PATCH 10/29] Cleanup and simplify AsyncCopy lowering --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 159 ++++++++---------- 1 file changed, 69 insertions(+), 90 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 5757593584a1..8e6b0633a990 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -398,129 +398,111 @@ struct BufferLoadOpConversion } }; -struct AsyncLoadOpConversion +struct AsyncCopyToGlobalOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { using ConvertOpToLLVMPattern< triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern; - AsyncLoadOpConversion(LLVMTypeConverter &converter, - const AMD::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + AsyncCopyToGlobalOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + bool isLoadWidthSupported(unsigned bits, + const AMD::TargetInfo &targetInfo) const { + llvm::SmallSetVector supportedWidths; + switch (targetInfo.getISAFamily()) { + case mlir::triton::AMD::ISAFamily::CDNA2: + case mlir::triton::AMD::ISAFamily::CDNA3: + supportedWidths.insert(8); + supportedWidths.insert(16); + supportedWidths.insert(32); + if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) { + supportedWidths.insert(98); + supportedWidths.insert(128); + } + break; + default: + return false; + } + + return supportedWidths.contains(bits); + } + LogicalResult matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); - Value res = op.getResult(); + Value mask = op.getMask(); Value other = op.getOther(); auto srcTy = op.getSrc().getType(); + auto srcEncoding = srcTy.getEncoding(); + 
assert((isa(srcEncoding) && + "Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion")); + auto dstTy = op.getResult().getType(); auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); - auto srcLayout = srcTy.getEncoding(); - assert((isa(srcLayout) && - "Unexpected srcLayout in AsyncCopyGlobalToLocalOpConversion")); - auto resSharedLayout = cast(dstTy.getEncoding()); auto srcShape = srcTy.getShape(); - assert( - (srcShape.size() <= 2) && - "Async copy only supports 1d and 2d tensors: Unexpected rank of %src"); + assert(srcShape.size() <= 2 && "Async copy only supports 1d and 2d " + "tensors: Unexpected rank of %src"); - Value llDst = adaptor.getResult(); Value llSrc = adaptor.getSrc(); - Value llMask = adaptor.getMask(); - Value llOther = adaptor.getOther(); - // %src auto srcElems = unpackLLElements(loc, llSrc, rewriter); - // %dst + Value llDst = adaptor.getResult(); auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( loc, llDst, resElemTy, rewriter); - // %mask + + Value llMask = adaptor.getMask(); SmallVector maskElems; if (llMask) { maskElems = unpackLLElements(loc, llMask, rewriter); assert(srcElems.size() == maskElems.size()); } + Value llOther = adaptor.getOther(); SmallVector otherElems; if (llOther) { otherElems = unpackLLElements(loc, llOther, rewriter); assert(srcElems.size() == otherElems.size()); } - // TODO check maxVec with mask alignment! - // global.load.lds has a shared dst register so we cannot have per thread // offsets This means our load size has to align with the load_width of unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass); if (mask) { - // TODO, if this changes maxVec we cannot use global.load.lds? maxVec = std::min(maxVec, getMaskAlignment(mask)); } - llvm::SmallSetVector supportedLoadBits; - // TODO look up if we support it on mi200 - switch (targetInfo.getISAFamily()) { - case mlir::triton::AMD::ISAFamily::CDNA3: - supportedLoadBits.insert(8); - supportedLoadBits.insert(16); - supportedLoadBits.insert(32); - if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) { - supportedLoadBits.insert(98); - supportedLoadBits.insert(128); - } - break; - default: - return emitError(loc, "Async copy not supported on target ISA"); - } - - unsigned int loadStoreBitWidth = maxVec * resElemTy.getIntOrFloatBitWidth(); - - if (!supportedLoadBits.contains(loadStoreBitWidth)) { - return emitError(loc, "Async copy does not supported the required load " - "vectorization, got ") - << loadStoreBitWidth << "bits"; - } - - { - - auto shape = dstTy.getShape(); - LinearLayout regLayout = - triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); - LinearLayout sharedLayout = triton::gpu::toLinearLayout( - shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); - LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); - llvm::outs() << "Reg to shared: \n" - << regToSharedLayout.toString() << "\n"; - - // We need to check if the lane basis is contigeous because - // global.load.lds does not support per lane offset - auto kLane = str_attr("lane"); - auto kBlock = str_attr("block"); - auto kWarp = str_attr("warp"); - auto kRegister = str_attr("register"); - - for (int inLane : llvm::seq(regToSharedLayout.getInDimSize(kLane))) { - auto idx = regToSharedLayout.apply( - {{kRegister, 0}, {kLane, inLane}, {kWarp, 0}, {kBlock, 0}}); - int32_t offset = idx[0].second; - if (offset != (inLane * maxVec)) { - return emitError(loc, "Invalid layout in AsyncCopy: ") - << "Lane: " << inLane << " is " << offset << " should 
be " - << inLane << "\n"; - } + auto shape = dstTy.getShape(); + LinearLayout srcLayout = + triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); + LinearLayout sharedLayout = triton::gpu::toLinearLayout( + shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); + LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); + + // We need to check if the kLane basis is contigeous for the chose + // vectorization because global.load.lds does not support per lane offset + auto kLane = str_attr("lane"); + + for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { + auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; + unsigned expected = maxVec * (1 << inLane); + if (basis != expected) { + return emitError(loc, "Invalid layout in AsyncCopy: ") + << "Lane: " << 1 + inLane << " is " << basis << " should be " + << expected << "\n"; } } @@ -534,23 +516,22 @@ struct AsyncLoadOpConversion shmemAddrs.push_back(shmemAddr); }); assert(ok); - llvm::outs() << "Shared to reg\n"; - llvm::outs().flush(); - int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8; + int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); + if (!isLoadWidthSupported(vecBits, targetInfo)) { + return emitError(loc, "Async copy does not support the required load " + "vectorization, got ") + << vecBits << " bits"; + } + + int vecBytes = vecBits / 8; assert(llvm::isPowerOf2_32(vecBytes)); std::string intrinsic = "llvm.amdgcn.global.load.lds"; - Value loadStoreByteWidthVal = i32_val(loadStoreBitWidth / 8); - llvm::outs() << "Load byte width: " << loadStoreByteWidthVal << "\n"; + Value vecBytesVal = i32_val(vecBytes); - llvm::outs() << "Shem addr count: " << shmemAddrs.size() << "\n"; for (int i = 0; i < shmemAddrs.size(); i++) { - // Tune CG and CA. // TODO Alex select correct cache modifier - // CacheModifier srcCacheModifier = - // wordBytes == 16 ? 
CacheModifier::CG : CacheModifier::CA; - // assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); auto srcIdx = i * maxVec; auto srcPtr = srcElems[srcIdx]; @@ -564,9 +545,8 @@ struct AsyncLoadOpConversion rewriter.create(loc, maskElems[srcIdx], loadBlock, afterLoad); rewriter.setInsertionPointToStart(loadBlock); - rewriter.create(loc, srcPtr, shmemAddrs[i], - loadStoreByteWidthVal, - i32_val(0), i32_val(0)); + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, i32_val(0), i32_val(0)); rewriter.create(loc, afterLoad); rewriter.setInsertionPointToStart(afterLoad); @@ -578,9 +558,8 @@ struct AsyncLoadOpConversion icmp_ne(maskElems[srcIdx], true_val())); } } else { - rewriter.create(loc, srcPtr, shmemAddrs[i], - loadStoreByteWidthVal, - i32_val(0), i32_val(0)); + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, i32_val(0), i32_val(0)); } } @@ -1709,7 +1688,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, patterns.add( typeConverter, targetInfo, axisInfoAnalysis, benefit); } From a20b68669818614398e960fffcea409050b9dcec Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 15:50:14 +0000 Subject: [PATCH 11/29] CacheModifiers for AsyncCopy --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 8e6b0633a990..3ab5db63f76c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -21,6 +21,7 @@ using namespace mlir::triton::gpu; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryBase; using ::mlir::LLVM::AMD::getContiguity; +using mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget; using ::mlir::LLVM::AMD::getVectorSize; using ::mlir::LLVM::AMD::llLoad; using ::mlir::LLVM::AMD::llStore; @@ -477,14 +478,15 @@ struct AsyncCopyToGlobalOpConversion assert(srcElems.size() == otherElems.size()); } - // global.load.lds has a shared dst register so we cannot have per thread - // offsets This means our load size has to align with the load_width of - unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass); if (mask) { maxVec = std::min(maxVec, getMaskAlignment(mask)); } + // global.load.lds does not support per lane offsets. + // We need to ensure that we write coalesced into shared memory. 
+ // This means that the kLane dim needs to be contigeous based on the + // vectorization size auto shape = dstTy.getShape(); LinearLayout srcLayout = triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); @@ -492,10 +494,7 @@ struct AsyncCopyToGlobalOpConversion shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); - // We need to check if the kLane basis is contigeous for the chose - // vectorization because global.load.lds does not support per lane offset auto kLane = str_attr("lane"); - for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; unsigned expected = maxVec * (1 << inLane); @@ -530,13 +529,18 @@ struct AsyncCopyToGlobalOpConversion std::string intrinsic = "llvm.amdgcn.global.load.lds"; Value vecBytesVal = i32_val(vecBytes); - for (int i = 0; i < shmemAddrs.size(); i++) { - // TODO Alex select correct cache modifier + Value cacheModifiers = i32_val( + getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo)); + for (int i = 0; i < shmemAddrs.size(); i++) { auto srcIdx = i * maxVec; auto srcPtr = srcElems[srcIdx]; - if (mask) { + if (!mask) { + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/i32_val(0), + cacheModifiers); + } else { Block *currentBlock = rewriter.getInsertionBlock(); Block *afterLoad = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); @@ -555,11 +559,8 @@ struct AsyncCopyToGlobalOpConversion packElementRangeIntoVector(rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); llStore(rewriter, loc, shmemAddrs[i], storeVal, - icmp_ne(maskElems[srcIdx], true_val())); + icmp_ne(maskElems[srcIdx], true_val()), 0, op.getCache()); } - } else { - rewriter.create( - loc, srcPtr, shmemAddrs[i], vecBytesVal, i32_val(0), i32_val(0)); } } From 97d677d91ebd9839e0e2f398c7c028bbb0ce8131 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 17:37:47 +0000 Subject: [PATCH 12/29] Add lit test for AsyncCopy --- test/Conversion/amd/tritongpu_to_llvm.mlir | 68 ++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 144d32c71df8..dce60cbf2f84 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -294,3 +294,71 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy + tt.func public @async_copy(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + // CHECK: rocdl.global.load.lds + %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + 
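The lowering above only succeeds when every kLane basis of the source-to-shared layout equals maxVec * 2^i, i.e. consecutive lanes must write consecutive maxVec-sized chunks of LDS, since global.load.lds takes a single per-wave LDS base address. A minimal standalone sketch of that arithmetic, assuming a 64-lane wave and the 32x64 f16 tile with sizePerThread = [1, 2], threadsPerWarp = [2, 32] used in the test that follows (plain C++, no Triton APIs, values chosen only for illustration):

#include <cassert>
#include <cstdio>

int main() {
  const int maxVec = 2;       // sizePerThread along the contiguous dim
  const int lanesPerRow = 32; // threadsPerWarp along the contiguous dim
  const int rowStride = 64;   // elements per row of the 32x64 tile

  // Flattened LDS offset written by a lane for its first register.
  auto ldsOffset = [&](int lane) {
    int col = (lane % lanesPerRow) * maxVec;
    int row = lane / lanesPerRow;
    return row * rowStride + col;
  };

  for (int i = 0; i < 6; ++i) { // 64 lanes -> 6 lane basis vectors
    int basis = ldsOffset(1 << i) - ldsOffset(0);
    assert(basis == (maxVec << i)); // the condition checked in the lowering
    std::printf("lane bit %d: basis %d\n", i, basis);
  }
  return 0;
}

With a swizzled shared encoding (maxPhase > 1) the bases are no longer plain multiples of maxVec and the pattern bails out, which is why the tests here only use non-swizzled #shared layouts.
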
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_vectorized + tt.func public @async_copy_vectorized(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds + // CHECK-COUNT-4: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_vectorized + tt.func public @async_copy_vectorized(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds + // CHECK: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} From 30352ad5fb86ab6853292eeb02c5b25910e6ab33 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 17:41:45 +0000 Subject: [PATCH 13/29] Split AsyncCopy Lit for gfx950 --- test/Conversion/amd/tritongpu_to_llvm.mlir | 25 ------------------- .../amd/tritongpu_to_llvm_gfx950.mlir | 24 ++++++++++++++++++ 2 files changed, 24 insertions(+), 25 deletions(-) create mode 100644 test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir diff 
--git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index dce60cbf2f84..004b8bbfd059 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -337,28 +337,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.return } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: async_copy_vectorized - tt.func public @async_copy_vectorized(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, - %arg1: i32 {tt.divisibility = 16 : i32}, - %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // We need the index calculation so AxisAnalysis sees that we can vectorize the load - %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> - %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - - // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds - // CHECK: rocdl.global.load.lds - // CHECK-NOT: rocdl.global.load.lds - %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> - tt.return - } -} diff --git a/test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir b/test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir new file mode 100644 index 000000000000..bf9698f5e816 --- /dev/null +++ b/test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-builtin-func-to-llvm | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_vectorized + tt.func public @async_copy_vectorized(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> 
+ + // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds + // CHECK: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} From fe8619d5abbb6f1a2d98d42e5b1897e3a0cb57fd Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 17:46:47 +0000 Subject: [PATCH 14/29] Add const to getCtrlBitsForCacheModifierOnTarget --- third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 6 +++--- third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index d4b8d7abe01f..84c58fb824fe 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -517,9 +517,9 @@ static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) { // .cv: don't cache and fetch again // .wb: write-back, writes back data at all cache levels // .wt: write-through, write data directly to system memory -int32_t -getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier cm, bool isBufferLoad, - mlir::triton::AMD::TargetInfo &targetInfo) { +int32_t getCtrlBitsForCacheModifierOnTarget( + triton::CacheModifier cm, bool isBufferLoad, + const mlir::triton::AMD::TargetInfo &targetInfo) { if (targetInfo.getGPUKind() == llvm::AMDGPU::GK_GFX942) // gfx942 return getCtrlBitsForCacheModifierOnGFX942(cm, isBufferLoad); else diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 1dabe31db2d9..bd02cce16cc9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -54,8 +54,9 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, // Get flags for a predicated Load or Store std::pair getCacheModifierFlagsForPredicatedCall(LLVM::CallOp); // Get the cachepolicy value for a cache modifier -int32_t getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool, - mlir::triton::AMD::TargetInfo &); +int32_t +getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool, + const mlir::triton::AMD::TargetInfo &); // Get cache modifier information for buffer atomics int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1, From 7941a307a8510031359fbfdbf25d44a278540b9d Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 17:53:01 +0000 Subject: [PATCH 15/29] Cleanup StreamPipeliner changes --- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 52 ++++++++----------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index c01bd90e83dc..77cee19e7863 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -261,34 +261,23 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, auto sharedEncodingAttr = cast(allocTy.getEncoding()); - llvm::outs() << "Shared alloc: \n"; - alloc.print(llvm::outs()); - llvm::outs() << "\n"; + auto srcTy = dyn_cast(src.getType()); - bool emitAsyncCopy = false; + bool useAsyncCopy = false; - auto srcTy = dyn_cast(src.getType()); - // We can use AsyncCopy if we do not swizzle into smem 
- // TODO (alex) ensure it's 2D + // Note that we can only use AsyncCopy when have coalesced LDS writes (e.g. no + // swizzeling). if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") && sharedEncodingAttr.getPerPhase() == 1 && sharedEncodingAttr.getMaxPhase() == 1 && + sharedEncodingAttr.getOrder().size() == 2 && llvm::equal(sharedEncodingAttr.getOrder(), ttg::getOrder(srcTy.getEncoding()))) { - emitAsyncCopy = true; + useAsyncCopy = true; } - llvm::outs() << "Emit async: " << emitAsyncCopy << "\n"; SmallVector copyOffsets(allocTy.getRank(), zero); - Operation *newLoadOp{}; - if (!emitAsyncCopy) { - newLoadOp = builder.clone(*loadOp); - auto [stage, cluster] = schedule[loadOp]; - schedule.erase(loadOp); - schedule.insert(newLoadOp, stage, cluster); - } - // Extract part. SmallVector loadOffsets(allocTy.getRank(), zero); loadOffsets[0] = extractIdx; @@ -300,14 +289,20 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + Operation *newLoadOp{}; Operation *wait{}; - if (emitAsyncCopy) { + + if (!useAsyncCopy) { + newLoadOp = builder.clone(*loadOp); + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(newLoadOp, stage, cluster); + } else { auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - llvm::outs() << "INVALID SRC!\n"; - } + assert(srcTy); + // We need to ensure we read coalesced into LDS so we adjust the blocked to - // read coalesced for now + // read coalesced auto shape = subviewTy.getShape(); auto order = sharedEncodingAttr.getOrder(); @@ -325,19 +320,14 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, auto srcEncoding = srcTy.getEncoding(); auto newLayout = ttg::BlockedEncodingAttr::get( - loadOp->getContext(), - sizePerThread, //{1, 1}, // triton::gpu::getSizePerThread(srcEncoding), - threadsPerWarp, //{2, 32}, // - // triton::gpu::getThreadsPerWarp(srcEncoding), + loadOp->getContext(), sizePerThread, threadsPerWarp, triton::gpu::getWarpsPerCTA(srcEncoding), triton::gpu::getOrder(srcEncoding), triton::gpu::getCTALayout(srcEncoding)); - llvm::outs() << "New src encoding: "; newLayout.printStripped(llvm::outs()); llvm::outs() << "\n"; RankedTensorType newArgType = RankedTensorType::get( srcTy.getShape(), srcTy.getElementType(), newLayout); - llvm::outs() << "Source encoding: "; srcTy.getEncoding().print(llvm::outs()); llvm::outs() << "\n"; auto cvtSrc = @@ -377,7 +367,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, // Prefetch load ahead of the dot stage if is used by the dot. Operation *storeOp{}; - if (emitAsyncCopy) { + if (useAsyncCopy) { // FIXME: it should be scheduled as a local_load to hide latency but that // currently breaks the scheduling as we require one more lds buffer to make // that work @@ -391,7 +381,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, // Create local load Operation *sharedLoad{}; - if (emitAsyncCopy) { + if (useAsyncCopy) { // scheduleOp(wait, SCHED_LOCAL_LOAD); sharedLoad = builder.create(loc, loadOp.getType(), viewLoad, wait->getResult(0)); @@ -409,7 +399,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, // instruction scheduling hints to correctly count the emitted `ds_write` // instructions for each GEMM tile. 
if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { - if (emitAsyncCopy) { + if (useAsyncCopy) { newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } else { storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); From def9313bf515451b45b987bb69381860066a4825 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 17:55:15 +0000 Subject: [PATCH 16/29] Revert stream pipeline related changes --- include/triton/Tools/Sys/GetEnv.hpp | 1 - .../TritonAMDGPUTransforms/StreamPipeline.cpp | 125 ++---------------- 2 files changed, 13 insertions(+), 113 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index cd36a473333a..b11d90be436d 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -14,7 +14,6 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { // clang-format off "AMDGCN_ENABLE_DUMP", "AMDGCN_USE_BUFFER_OPS", - "AMDGCN_USE_DIRECT_TO_LDS", "DISABLE_FAST_REDUCTION", "DISABLE_LLVM_OPT", "DISABLE_MMA_V3", diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 77cee19e7863..b53d4b55421c 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -10,7 +10,6 @@ #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// @@ -258,25 +257,12 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, Value other = loadOp.getOther(); ttg::MemDescType allocTy = cast(alloc.getType()); - - auto sharedEncodingAttr = - cast(allocTy.getEncoding()); - auto srcTy = dyn_cast(src.getType()); - - bool useAsyncCopy = false; - - // Note that we can only use AsyncCopy when have coalesced LDS writes (e.g. no - // swizzeling). - if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") && - sharedEncodingAttr.getPerPhase() == 1 && - sharedEncodingAttr.getMaxPhase() == 1 && - sharedEncodingAttr.getOrder().size() == 2 && - llvm::equal(sharedEncodingAttr.getOrder(), - ttg::getOrder(srcTy.getEncoding()))) { - useAsyncCopy = true; - } - SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); // Extract part. 
SmallVector loadOffsets(allocTy.getRank(), zero); @@ -288,72 +274,6 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); - - Operation *newLoadOp{}; - Operation *wait{}; - - if (!useAsyncCopy) { - newLoadOp = builder.clone(*loadOp); - auto [stage, cluster] = schedule[loadOp]; - schedule.erase(loadOp); - schedule.insert(newLoadOp, stage, cluster); - } else { - auto srcTy = dyn_cast(src.getType()); - assert(srcTy); - - // We need to ensure we read coalesced into LDS so we adjust the blocked to - // read coalesced - - auto shape = subviewTy.getShape(); - auto order = sharedEncodingAttr.getOrder(); - // Aim to use wider loads - llvm::SmallVector sizePerThread{1, 1}; - sizePerThread[order[0]] = - 32 / allocTy.getElementType().getIntOrFloatBitWidth(); - llvm::SmallVector threadsPerWarp{1, 1}; - assert((shape[order[0]] % sizePerThread[0]) == 0); - unsigned warpSize = 64; - threadsPerWarp[order[0]] = - std::min(warpSize, shape[order[0]] / sizePerThread[order[0]]); - threadsPerWarp[order[1]] = - std::max(1, warpSize / threadsPerWarp[order[0]]); - - auto srcEncoding = srcTy.getEncoding(); - auto newLayout = ttg::BlockedEncodingAttr::get( - loadOp->getContext(), sizePerThread, threadsPerWarp, - triton::gpu::getWarpsPerCTA(srcEncoding), - triton::gpu::getOrder(srcEncoding), - triton::gpu::getCTALayout(srcEncoding)); - newLayout.printStripped(llvm::outs()); - llvm::outs() << "\n"; - RankedTensorType newArgType = RankedTensorType::get( - srcTy.getShape(), srcTy.getElementType(), newLayout); - srcTy.getEncoding().print(llvm::outs()); - llvm::outs() << "\n"; - auto cvtSrc = - builder.create(loadOp.getLoc(), newArgType, src); - - auto mask = loadOp.getMask(); - if (mask) { - auto maskTy = dyn_cast(mask.getType()); - RankedTensorType newMaskTy = RankedTensorType::get( - maskTy.getShape(), maskTy.getElementType(), newLayout); - auto cvtMask = builder.create( - loadOp->getLoc(), newMaskTy, loadOp.getMask()); - } - - newLoadOp = builder.create( - loadOp.getLoc(), cvtSrc.getResult(), viewLoad, mask, other, - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - - wait = builder.create(loc, newLoadOp->getResult(0), 0); - - auto [stage, cluster] = schedule[loadOp]; - schedule.erase(loadOp); - schedule.insert(cvtSrc, stage, cluster); - schedule.insert(newLoadOp, stage, cluster); - } - // Clean up old local caches. SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { @@ -366,30 +286,15 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, alloc.erase(); // Prefetch load ahead of the dot stage if is used by the dot. 
- Operation *storeOp{}; - if (useAsyncCopy) { - // FIXME: it should be scheduled as a local_load to hide latency but that - // currently breaks the scheduling as we require one more lds buffer to make - // that work - scheduleOp(newLoadOp, SCHED_LOCAL_STORE); - } else { - storeOp = builder.create(loc, newLoadOp->getResult(0), - viewLoad); - scheduleOp(viewLoad, SCHED_LOCAL_STORE); - scheduleOp(storeOp, SCHED_LOCAL_STORE); - } + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); // Create local load - Operation *sharedLoad{}; - if (useAsyncCopy) { - // scheduleOp(wait, SCHED_LOCAL_LOAD); - sharedLoad = builder.create(loc, loadOp.getType(), - viewLoad, wait->getResult(0)); - } else { - sharedLoad = - builder.create(loc, loadOp.getType(), viewLoad); - } - Value result = sharedLoad->getResult(0); + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + Value result = sharedLoad.getResult(); if (prefetch) scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); @@ -399,11 +304,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, // instruction scheduling hints to correctly count the emitted `ds_write` // instructions for each GEMM tile. if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { - if (useAsyncCopy) { - newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); - } else { - storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); - } + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } loadOp->replaceAllUsesWith(ValueRange{result}); From 318caa2611525f6d2d8aed4310049b7ebdf47adb Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 18:01:44 +0000 Subject: [PATCH 17/29] Add missing CDNA1 to AsyncCopy support list --- third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 3ab5db63f76c..cd21c6b072b9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -417,6 +417,7 @@ struct AsyncCopyToGlobalOpConversion const AMD::TargetInfo &targetInfo) const { llvm::SmallSetVector supportedWidths; switch (targetInfo.getISAFamily()) { + case mlir::triton::AMD::ISAFamily::CDNA1: case mlir::triton::AMD::ISAFamily::CDNA2: case mlir::triton::AMD::ISAFamily::CDNA3: supportedWidths.insert(8); From 6600138a1f7d800bbadbbe6dbd8a5446f6b381f2 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 28 Jan 2025 18:05:57 +0000 Subject: [PATCH 18/29] Cleanup --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 58 +++++++++---------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index cd21c6b072b9..fddf024c0421 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -399,16 +399,16 @@ struct BufferLoadOpConversion } }; -struct AsyncCopyToGlobalOpConversion +struct AsyncCopyGlobalToLocalOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { using ConvertOpToLLVMPattern< triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern; - AsyncCopyToGlobalOpConversion(LLVMTypeConverter &converter, - 
const AMD::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} @@ -442,21 +442,17 @@ struct AsyncCopyToGlobalOpConversion MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); - Value mask = op.getMask(); - Value other = op.getOther(); - auto srcTy = op.getSrc().getType(); auto srcEncoding = srcTy.getEncoding(); assert((isa(srcEncoding) && "Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion")); - - auto dstTy = op.getResult().getType(); - auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); - auto srcShape = srcTy.getShape(); assert(srcShape.size() <= 2 && "Async copy only supports 1d and 2d " "tensors: Unexpected rank of %src"); + auto dstTy = op.getResult().getType(); + auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + Value llSrc = adaptor.getSrc(); auto srcElems = unpackLLElements(loc, llSrc, rewriter); @@ -465,21 +461,9 @@ struct AsyncCopyToGlobalOpConversion auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( loc, llDst, resElemTy, rewriter); - Value llMask = adaptor.getMask(); - SmallVector maskElems; - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(srcElems.size() == maskElems.size()); - } - - Value llOther = adaptor.getOther(); - SmallVector otherElems; - if (llOther) { - otherElems = unpackLLElements(loc, llOther, rewriter); - assert(srcElems.size() == otherElems.size()); - } - unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass); + + Value mask = op.getMask(); if (mask) { maxVec = std::min(maxVec, getMaskAlignment(mask)); } @@ -526,13 +510,25 @@ struct AsyncCopyToGlobalOpConversion int vecBytes = vecBits / 8; assert(llvm::isPowerOf2_32(vecBytes)); - - std::string intrinsic = "llvm.amdgcn.global.load.lds"; Value vecBytesVal = i32_val(vecBytes); Value cacheModifiers = i32_val( getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo)); + Value llMask = adaptor.getMask(); + SmallVector maskElems; + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(srcElems.size() == maskElems.size()); + } + + Value other = op.getOther(); + SmallVector otherElems; + if (other) { + otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter); + assert(srcElems.size() == otherElems.size()); + } + for (int i = 0; i < shmemAddrs.size(); i++) { auto srcIdx = i * maxVec; auto srcPtr = srcElems[srcIdx]; @@ -1652,7 +1648,7 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - rewriter.create(loc, 0); + rewriter.create(loc, op.getNum()); rewriter.replaceOp(op, i32_val(0)); return success(); } @@ -1671,7 +1667,7 @@ struct AsyncCommitGroupConversion LogicalResult matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // We do not have that concept so simply drop it + // Drop the result token auto loc = op->getLoc(); rewriter.replaceOp(op, i32_val(0)); return success(); @@ -1690,7 +1686,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, patterns.add( typeConverter, targetInfo, axisInfoAnalysis, benefit); } From ea02c3cbea2f6d8f36878f2ca7106e097312d2d4 Mon Sep 17 
00:00:00 2001 From: Alexander Weinrauch Date: Wed, 29 Jan 2025 09:51:27 +0000 Subject: [PATCH 19/29] Replace macros for llvm ops with TritonLLVMOpBuilder --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index fddf024c0421..554460cf3cf2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -439,16 +439,15 @@ struct AsyncCopyGlobalToLocalOpConversion matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcTy = op.getSrc().getType(); auto srcEncoding = srcTy.getEncoding(); assert((isa(srcEncoding) && "Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion")); - auto srcShape = srcTy.getShape(); - assert(srcShape.size() <= 2 && "Async copy only supports 1d and 2d " - "tensors: Unexpected rank of %src"); + assert(srcTy.getShape().size() <= 2 && "Async copy only supports 1d and 2d " + "tensors: Unexpected rank of %src"); auto dstTy = op.getResult().getType(); auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); @@ -479,7 +478,7 @@ struct AsyncCopyGlobalToLocalOpConversion shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); - auto kLane = str_attr("lane"); + StringAttr kLane = rewriter.getStringAttr("lane"); for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; unsigned expected = maxVec * (1 << inLane); @@ -510,9 +509,9 @@ struct AsyncCopyGlobalToLocalOpConversion int vecBytes = vecBits / 8; assert(llvm::isPowerOf2_32(vecBytes)); - Value vecBytesVal = i32_val(vecBytes); + Value vecBytesVal = b.i32_val(vecBytes); - Value cacheModifiers = i32_val( + Value cacheModifiers = b.i32_val( getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo)); Value llMask = adaptor.getMask(); @@ -535,7 +534,7 @@ struct AsyncCopyGlobalToLocalOpConversion if (!mask) { rewriter.create( - loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/i32_val(0), + loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), cacheModifiers); } else { Block *currentBlock = rewriter.getInsertionBlock(); @@ -546,8 +545,9 @@ struct AsyncCopyGlobalToLocalOpConversion rewriter.create(loc, maskElems[srcIdx], loadBlock, afterLoad); rewriter.setInsertionPointToStart(loadBlock); - rewriter.create( - loc, srcPtr, shmemAddrs[i], vecBytesVal, i32_val(0), i32_val(0)); + rewriter.create(loc, srcPtr, shmemAddrs[i], + vecBytesVal, b.i32_val(0), + cacheModifiers); rewriter.create(loc, afterLoad); rewriter.setInsertionPointToStart(afterLoad); @@ -556,7 +556,7 @@ struct AsyncCopyGlobalToLocalOpConversion packElementRangeIntoVector(rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); llStore(rewriter, loc, shmemAddrs[i], storeVal, - icmp_ne(maskElems[srcIdx], true_val()), 0, op.getCache()); + b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache()); } } } @@ -1648,8 +1648,9 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto b = 
TritonLLVMOpBuilder(loc, rewriter); rewriter.create(loc, op.getNum()); - rewriter.replaceOp(op, i32_val(0)); + rewriter.replaceOp(op, b.i32_val(0)); return success(); } }; @@ -1669,7 +1670,8 @@ struct AsyncCommitGroupConversion ConversionPatternRewriter &rewriter) const override { // Drop the result token auto loc = op->getLoc(); - rewriter.replaceOp(op, i32_val(0)); + auto b = TritonLLVMOpBuilder(loc, rewriter); + rewriter.replaceOp(op, b.i32_val(0)); return success(); } }; From 13419bb87cf304c3fe1403ccd241528aa67953cf Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 30 Jan 2025 10:21:22 +0000 Subject: [PATCH 20/29] Fix wrong value in supported bit width for global.to.lds --- .../amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 554460cf3cf2..bb254975c3ae 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -424,7 +424,7 @@ struct AsyncCopyGlobalToLocalOpConversion supportedWidths.insert(16); supportedWidths.insert(32); if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) { - supportedWidths.insert(98); + supportedWidths.insert(96); supportedWidths.insert(128); } break; @@ -545,9 +545,9 @@ struct AsyncCopyGlobalToLocalOpConversion rewriter.create(loc, maskElems[srcIdx], loadBlock, afterLoad); rewriter.setInsertionPointToStart(loadBlock); - rewriter.create(loc, srcPtr, shmemAddrs[i], - vecBytesVal, b.i32_val(0), - cacheModifiers); + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), + cacheModifiers); rewriter.create(loc, afterLoad); rewriter.setInsertionPointToStart(afterLoad); From ca8b441d4b5bf68dd557d286d50f3d0571664f58 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 31 Jan 2025 11:08:29 +0000 Subject: [PATCH 21/29] Addressing review comments --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 129 +++++++++--------- 1 file changed, 61 insertions(+), 68 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index bb254975c3ae..3c95ba3c5dd9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -20,12 +20,9 @@ using namespace mlir::triton::gpu; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryBase; -using ::mlir::LLVM::AMD::getContiguity; -using mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget; using ::mlir::LLVM::AMD::getVectorSize; using ::mlir::LLVM::AMD::llLoad; using ::mlir::LLVM::AMD::llStore; -using ::mlir::triton::AMD::ISAFamily; using ::mlir::triton::gpu::getTotalElemsPerThread; namespace { @@ -402,24 +399,21 @@ struct BufferLoadOpConversion struct AsyncCopyGlobalToLocalOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { - using ConvertOpToLLVMPattern< - triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern; - AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, - benefit), + : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} - bool isLoadWidthSupported(unsigned bits, - const 
AMD::TargetInfo &targetInfo) const { + bool supportsLoadWidth(unsigned bits, + const AMD::TargetInfo &targetInfo) const { llvm::SmallSetVector supportedWidths; + using mlir::triton::AMD::ISAFamily; switch (targetInfo.getISAFamily()) { - case mlir::triton::AMD::ISAFamily::CDNA1: - case mlir::triton::AMD::ISAFamily::CDNA2: - case mlir::triton::AMD::ISAFamily::CDNA3: + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: supportedWidths.insert(8); supportedWidths.insert(16); supportedWidths.insert(32); @@ -444,10 +438,12 @@ struct AsyncCopyGlobalToLocalOpConversion auto srcTy = op.getSrc().getType(); auto srcEncoding = srcTy.getEncoding(); - assert((isa(srcEncoding) && - "Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion")); - assert(srcTy.getShape().size() <= 2 && "Async copy only supports 1d and 2d " - "tensors: Unexpected rank of %src"); + + if (!isa(srcEncoding)) + return rewriter.notifyMatchFailure( + op, "requires Blocked or Slice encoding for src"); + if (srcTy.getShape().size() != 2) + return rewriter.notifyMatchFailure(op, "only supports 2d tensors"); auto dstTy = op.getResult().getType(); auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); @@ -460,8 +456,14 @@ struct AsyncCopyGlobalToLocalOpConversion auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( loc, llDst, resElemTy, rewriter); - unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass); - + // We can load N elements at a time if: + // 1. Every group of N source pointers are contiguous. For example, if + // N=2, then the pointers should be [x, x+1, y, y+1, ...]. + // 2. The mask (if present) has "alignment" N, meaning that each group of N + // mask bits are the same. For example if N=2, the mask must be + // [x, x, y, y, ...]. 
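A small self-contained illustration of how this N interacts with supportsLoadWidth above, assuming f16 elements; the concrete contiguity and mask numbers are made up for the example, only the 8/16/32-bit (plus 96/128-bit on gfx950) limits come from the patch:

#include <algorithm>
#include <cstdio>

int main() {
  unsigned contiguity = 8;    // consecutive source pointers per thread
  unsigned maskAlignment = 4; // mask values repeat in groups of 4 elements
  unsigned elemBits = 16;     // f16

  unsigned maxVec = std::min(contiguity, maskAlignment); // 4
  unsigned vecBits = maxVec * elemBits;                  // 64

  // A 64-bit width is not in the supported set, so this combination is
  // rejected; without the mask restriction maxVec would be 8 and the
  // resulting 128-bit load is accepted on gfx950 only.
  std::printf("maxVec=%u -> %u-bit global.load.lds\n", maxVec, vecBits);
  return 0;
}

Whether that width can actually be used also depends on the kLane contiguity check further down.
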
+ unsigned maxVec = + mlir::LLVM::AMD::getContiguity(op.getSrc(), axisAnalysisPass); Value mask = op.getMask(); if (mask) { maxVec = std::min(maxVec, getMaskAlignment(mask)); @@ -483,9 +485,12 @@ struct AsyncCopyGlobalToLocalOpConversion auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; unsigned expected = maxVec * (1 << inLane); if (basis != expected) { - return emitError(loc, "Invalid layout in AsyncCopy: ") - << "Lane: " << 1 + inLane << " is " << basis << " should be " - << expected << "\n"; + LDBG("detected uncoalesced layout from blocked to shared in async copy " + "for lane " + << 1 + inLane << "; given " << basis << " but expected " + << expected); + return rewriter.notifyMatchFailure(op, + "does not write coalesced into LDS"); } } @@ -501,18 +506,18 @@ struct AsyncCopyGlobalToLocalOpConversion assert(ok); int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); - if (!isLoadWidthSupported(vecBits, targetInfo)) { - return emitError(loc, "Async copy does not support the required load " - "vectorization, got ") - << vecBits << " bits"; + if (!supportsLoadWidth(vecBits, targetInfo)) { + return rewriter.notifyMatchFailure( + op, "Async copy does not support the required load vectorization"); } int vecBytes = vecBits / 8; assert(llvm::isPowerOf2_32(vecBytes)); Value vecBytesVal = b.i32_val(vecBytes); - Value cacheModifiers = b.i32_val( - getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo)); + Value cacheModifiers = + b.i32_val(mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget( + op.getCache(), false, targetInfo)); Value llMask = adaptor.getMask(); SmallVector maskElems; @@ -536,28 +541,28 @@ struct AsyncCopyGlobalToLocalOpConversion rewriter.create( loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), cacheModifiers); - } else { - Block *currentBlock = rewriter.getInsertionBlock(); - Block *afterLoad = - rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); - Block *loadBlock = rewriter.createBlock(afterLoad); - rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, maskElems[srcIdx], loadBlock, - afterLoad); - rewriter.setInsertionPointToStart(loadBlock); - rewriter.create( - loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), - cacheModifiers); + continue; + } - rewriter.create(loc, afterLoad); - rewriter.setInsertionPointToStart(afterLoad); - if (other) { - Value storeVal = - packElementRangeIntoVector(rewriter, this->getTypeConverter(), - loc, vecTy, otherElems, srcIdx); - llStore(rewriter, loc, shmemAddrs[i], storeVal, - b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache()); - } + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *loadBlock = rewriter.createBlock(afterLoad); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, maskElems[srcIdx], loadBlock, + afterLoad); + rewriter.setInsertionPointToStart(loadBlock); + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), + cacheModifiers); + + rewriter.create(loc, afterLoad); + rewriter.setInsertionPointToStart(afterLoad); + if (other) { + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); + llStore(rewriter, loc, shmemAddrs[i], storeVal, + b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache()); } } @@ -1636,13 +1641,6 @@ struct AtomicRMWOpConversion struct AsyncWaitConversion : public ConvertOpToLLVMPattern { 
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - AsyncWaitConversion(LLVMTypeConverter &converter, - const AMD::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - LogicalResult matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -1659,12 +1657,6 @@ struct AsyncCommitGroupConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - AsyncCommitGroupConversion(LLVMTypeConverter &converter, - const AMD::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - LogicalResult matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -1685,11 +1677,12 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add( - typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns + .add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns.add(typeConverter, + benefit); } } // namespace mlir::triton::AMD From 6aa3554302ee06c639b25e3b2ea6fd838721715a Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 31 Jan 2025 14:11:29 +0000 Subject: [PATCH 22/29] Unified async ops lit tests --- test/Conversion/amd/tritongpu_to_llvm.mlir | 43 ------------------- .../amd/tritongpu_to_llvm_gfx950.mlir | 24 ----------- 2 files changed, 67 deletions(-) delete mode 100644 test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 004b8bbfd059..144d32c71df8 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -294,46 +294,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr tt.return } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: async_copy - tt.func public @async_copy(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, - %arg1: i32 {tt.divisibility = 16 : i32}, - %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // We need the splat to allow the AxisAnalysis to work during lowering - %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> - // CHECK: rocdl.global.load.lds - %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> - tt.return - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: async_copy_vectorized - tt.func public @async_copy_vectorized(%arg0: !tt.ptr 
{tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, - %arg1: i32 {tt.divisibility = 16 : i32}, - %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // We need the index calculation so AxisAnalysis sees that we can vectorize the load - %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> - %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - - // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds - // CHECK-COUNT-4: rocdl.global.load.lds - // CHECK-NOT: rocdl.global.load.lds - %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> - tt.return - } -} diff --git a/test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir b/test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir deleted file mode 100644 index bf9698f5e816..000000000000 --- a/test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-builtin-func-to-llvm | FileCheck %s - -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: async_copy_vectorized - tt.func public @async_copy_vectorized(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, - %arg1: i32 {tt.divisibility = 16 : i32}, - %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // We need the index calculation so AxisAnalysis sees that we can vectorize the load - %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> - %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - - // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds - // CHECK: rocdl.global.load.lds - // CHECK-NOT: rocdl.global.load.lds - %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> - tt.return - } -} From 04fad93b5bb9ac806d8f992251b24d0ab8ec438c Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 31 Jan 2025 18:34:18 +0000 Subject: [PATCH 23/29] Emit correct wmcnt wait instead of waiting on all cnts --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 3c95ba3c5dd9..d2b74c203574 100644 --- 
a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1647,7 +1647,28 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern { auto loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); - rewriter.create(loc, op.getNum()); + + // global.load.lds uses vmcnt to synchronize + // The rocdl op stores all possible coutners in a single int32 value (v) + // The vmcnt (6 bits) is split into a lower 3:0 and higher part 5:4 + // The lower parts is stored in 3:0 of v and the higher part in bits 15:14 + // We have to set all other bits in v to 1 to signal we are not interested + // in those + + int vmCnt = op.getNum(); + if (vmCnt >= 64) { + return emitError(loc, "AsyncWait does not support values >= 64"); + } + + // Extract low and high bits and combine while setting all other bits to 1 + unsigned lowBits = vmCnt & 0xF; + unsigned highBits = vmCnt >> 4 << 14; + unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set + unsigned waitValue = lowBits | highBits | otherCnts; + + rewriter.create(loc, waitValue); + + // Drop the result AsyncToken rewriter.replaceOp(op, b.i32_val(0)); return success(); } @@ -1660,7 +1681,7 @@ struct AsyncCommitGroupConversion LogicalResult matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Drop the result token + // Drop the result AsyncToken auto loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); rewriter.replaceOp(op, b.i32_val(0)); From f6cbe22eee868929eb7a5d682bb1167d6f4a6d10 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 31 Jan 2025 18:45:43 +0000 Subject: [PATCH 24/29] Add tests for AsyncWait/AsyncCommitGroup --- test/Conversion/amd/async_ops_to_llvm.mlir | 119 +++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 test/Conversion/amd/async_ops_to_llvm.mlir diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir new file mode 100644 index 000000000000..d92308c73a7e --- /dev/null +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -0,0 +1,119 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950 +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy + tt.func public @async_copy(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds + // CHECK-COUNT-8: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } 
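The waitcnt constants checked in the async_wait test further below come from the vmcnt packing introduced in the patch above. A minimal sketch reproducing those values, assuming the bit layout described there (vmcnt bits 3:0 and 5:4 go to bits 3:0 and 15:14 of the waitcnt operand, every other counter field is left at all-ones):

#include <cassert>
#include <cstdint>
#include <cstdio>

int32_t packVmcnt(unsigned vmCnt) {
  assert(vmCnt < 64 && "vmcnt is a 6-bit field");
  unsigned lowBits = vmCnt & 0xF;
  unsigned highBits = (vmCnt >> 4) << 14;
  unsigned otherCnts = ~0xC00Fu; // leave every other counter untouched
  // Reinterpret the bit pattern as the signed i32 printed in the MLIR.
  return static_cast<int32_t>(lowBits | highBits | otherCnts);
}

int main() {
  // Matches the CHECK lines in the async_wait test: num = 0, 1, 62, 63.
  const unsigned nums[] = {0, 1, 62, 63};
  for (unsigned n : nums)
    std::printf("num=%u -> waitcnt %d\n", n, packVmcnt(n));
  return 0;
}

For num = 0 this yields 0xFFFF3FF0, i.e. -49168 as a signed i32, matching the first CHECK: rocdl.waitcnt line; num = 63 sets all vmcnt bits and gives -1.
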
+} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_vectorized_2xf16 + tt.func public @async_copy_vectorized_2xf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds + // CHECK-COUNT-4: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // GFX950-LABEL: async_copy_vectorized_8xf16 + tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds + // GFX950: rocdl.global.load.lds + // GFX950-NOT: rocdl.global.load.lds + + // GFX942 does not support vectorization > 4bytes + // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}} + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = 
false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_wait + tt.func public @async_wait(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + // The value of the rocdl.waitcnt is explained in the lowering of async_wait + + // CHECK: rocdl.waitcnt -49168 + // CHECK: rocdl.barrier + ttg.async_wait {num = 0 : i32} + // CHECK: rocdl.waitcnt -49167 + // CHECK: rocdl.barrier + ttg.async_wait {num = 1 : i32} + // CHECK: rocdl.waitcnt -2 + // CHECK: rocdl.barrier + ttg.async_wait {num = 62 : i32} + // CHECK: rocdl.waitcnt -1 + // CHECK: rocdl.barrier + ttg.async_wait {num = 63 : i32} + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_commit_group + tt.func public @async_commit_group(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.return + ttg.async_commit_group + tt.return + } +} From 3d30f43b6bd5a2f91e2d4aad10b2eda2aa495b96 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Mon, 3 Feb 2025 18:54:44 +0000 Subject: [PATCH 25/29] Limit AsyncWait conversion to gfx9 --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index d2b74c203574..36f5bec36b5a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1639,17 +1639,33 @@ struct AtomicRMWOpConversion } }; -struct AsyncWaitConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { + AsyncWaitOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + LogicalResult matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + using mlir::triton::AMD::ISAFamily; + + switch (targetInfo.getISAFamily()) { + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + break; + default: + return rewriter.notifyMatchFailure( + op, "Only supported on target architecture"); + } + auto loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); // global.load.lds uses vmcnt to synchronize - // The rocdl op stores all possible coutners in a single int32 value (v) + // The rocdl op stores all available counters in a single int32 value (v) // The vmcnt (6 bits) is split into a lower 3:0 and higher part 5:4 // 
The lower parts is stored in 3:0 of v and the higher part in bits 15:14 // We have to set all other bits in v to 1 to signal we are not interested @@ -1672,9 +1688,12 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern { rewriter.replaceOp(op, b.i32_val(0)); return success(); } + +private: + const AMD::TargetInfo &targetInfo; }; -struct AsyncCommitGroupConversion +struct AsyncCommitGroupOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1703,7 +1722,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, StoreOpConversion, BufferLoadOpConversion, BufferStoreOpConversion, BufferAtomicRMWOpConversion, AsyncCopyGlobalToLocalOpConversion>( typeConverter, targetInfo, axisInfoAnalysis, benefit); - patterns.add(typeConverter, - benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD From 0c382dbbc51c7242d87011c38ad06a9871180a68 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Mon, 3 Feb 2025 19:08:14 +0000 Subject: [PATCH 26/29] Add AsyncCopy lowering lit test with masking and other values --- test/Conversion/amd/async_ops_to_llvm.mlir | 62 +++++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index d92308c73a7e..3c6d9303ce25 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -63,7 +63,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds // GFX950: rocdl.global.load.lds - // GFX950-NOT: rocdl.global.load.lds + // GFX950-NEXT: llvm.return // GFX942 does not support vectorization > 4bytes // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}} %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> tt.return } } @@ -112,8 +112,66 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { // CHECK: llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.return + // CHECK-NEXT: llvm.return ttg.async_commit_group tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_mask_other + tt.func public @async_copy_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>, + %arg3: i32 {tt.divisibility = 16 : i32}) { + // We need the splat to allow the AxisAnalysis to work during lowering + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %29 = arith.addi %arg3, %c31_i32 : i32 + %30 = arith.divsi %29, %c32_i32 : i32 + %31 = arith.cmpi sgt, %30, %c0_i32 : i32 + + %51 = tt.make_range {end = 32 : i32, start
= 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %65 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked> + %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + + %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked> + %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked> + + // Each thread needs to load 4 elements and we load 1 (sizePerThread) per global.load.lds + // Note that mask/other alignment is 1 so we need 4 conditionals + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + %2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} From f560aebae085b889aa6e08407a9210c0f3dae81b Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 5 Feb 2025 07:38:17 +0000 Subject: [PATCH 27/29] Added async copy lit tests with cache modifiers --- test/Conversion/amd/async_ops_to_llvm.mlir | 44 +++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index 3c6d9303ce25..83d8edb639b3 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -1,5 +1,5 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950 -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefix=GFX942 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> @@ -175,3 +175,45 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // GFX942-LABEL: async_copy_cache_mods + tt.func public @async_copy_cache_mods(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + // Each thread needs to load 1 element 
and we load 1 (sizePerThread) per global.load.lds + + // GFX942: llvm.getelementptr + // GFX942: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + %2 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_cg:.*]] = llvm.mlir.constant(0 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]] + %3 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_cs:.*]] = llvm.mlir.constant(3 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cs]] + %5 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cs: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_cv:.*]] = llvm.mlir.constant(9 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]] + %6 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_wb:.*]] = llvm.mlir.constant(0 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wb]] + %7 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wb: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_wt:.*]] = llvm.mlir.constant(8 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wt]] + %8 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wt: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} From d90ffbecd656c612543df9c68c5c8d4b08f39abe Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 5 Feb 2025 13:27:49 +0000 Subject: [PATCH 28/29] Adjust to shared encoding changes --- test/Conversion/amd/async_ops_to_llvm.mlir | 18 ++++++++---------- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 4 ++-- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index 83d8edb639b3..7b6d03de70b8 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -2,7 +2,7 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefix=GFX942 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: async_copy @@ -22,7 +22,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, 
order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: async_copy_vectorized_2xf16 @@ -47,7 +47,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // GFX950-LABEL: async_copy_vectorized_8xf16 @@ -75,16 +75,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: async_wait tt.func public @async_wait(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // We need the index calculation so AxisAnalysis sees that we can vectorize the load - // The value of the rocdl.waitcnt is explained in the lowering of async_wait - + // The waitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on // CHECK: rocdl.waitcnt -49168 // CHECK: rocdl.barrier ttg.async_wait {num = 0 : i32} @@ -104,7 +102,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: async_commit_group @@ -121,7 +119,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: async_copy_mask_other @@ -179,7 +177,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 
4 : i32, ttg.shar // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}> -#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // GFX942-LABEL: async_copy_cache_mods diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 36f5bec36b5a..cac8148d1154 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -476,8 +476,8 @@ struct AsyncCopyGlobalToLocalOpConversion auto shape = dstTy.getShape(); LinearLayout srcLayout = triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); - LinearLayout sharedLayout = triton::gpu::toLinearLayout( - shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth()); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(shape, dstTy.getEncoding()); LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); StringAttr kLane = rewriter.getStringAttr("lane"); From 5356802bfeed9c6a77a0059e585f9da8d5886012 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 5 Feb 2025 15:45:13 +0000 Subject: [PATCH 29/29] Fix a few small issues --- test/Conversion/amd/async_ops_to_llvm.mlir | 2 +- .../lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index 7b6d03de70b8..afea228c3e67 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -109,7 +109,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.func public @async_commit_group(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: llvm.return ttg.async_commit_group tt.return } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index cac8148d1154..757802cb912a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -470,9 +470,9 @@ struct AsyncCopyGlobalToLocalOpConversion } // global.load.lds does not support per lane offsets. - // We need to ensure that we write coalesced into shared memory. - // This means that the kLane dim needs to be contigeous based on the - // vectorization size + // We need to ensure that we write coalesced into shared memory. This means + // that the kLane dim needs to be contiguous based on the vectorization + // size.
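+  // (Concretely: if each lane stores a vector of N contiguous bytes, lane i+1
+  // must cover the N bytes of the LDS tile directly after lane i, because the
+  // instruction derives every lane's LDS address from a single base operand.)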
auto shape = dstTy.getShape(); LinearLayout srcLayout = triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); @@ -1665,11 +1665,11 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { auto b = TritonLLVMOpBuilder(loc, rewriter); // global.load.lds uses vmcnt to synchronize - // The rocdl op stores all available counters in a single int32 value (v) - // The vmcnt (6 bits) is split into a lower 3:0 and higher part 5:4 - // The lower parts is stored in 3:0 of v and the higher part in bits 15:14 - // We have to set all other bits in v to 1 to signal we are not interested - // in those + // The rocdl op stores all available counters in a single int32 value (v). + // The vmcnt (6 bits) is split into a lower 3:0 and higher 5:4 parts. + // The lower part is stored in bits 3:0 of v and the higher part in bits + // 15:14. We have to set all other bits in v to 1 to signal we are not + // interested in those. int vmCnt = op.getNum(); if (vmCnt >= 64) {
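    // Illustrative aside (not part of the patch): the encoding can be
    // sanity-checked standalone, e.g. at namespace scope in a unit test.
    // `encodeVmcnt` is a hypothetical helper that mirrors the
    // lowBits/highBits/otherCnts logic below; the asserted values are the
    // rocdl.waitcnt constants expected by
    // test/Conversion/amd/async_ops_to_llvm.mlir.
    constexpr int encodeVmcnt(unsigned vmCnt) {
      unsigned lowBits = vmCnt & 0xF;         // vmcnt bits 3:0 stay in v[3:0]
      unsigned highBits = (vmCnt >> 4) << 14; // vmcnt bits 5:4 move to v[15:14]
      unsigned otherCnts = ~0xC00Fu;          // saturate every other counter field
      return static_cast<int>(lowBits | highBits | otherCnts);
    }
    static_assert(encodeVmcnt(0) == -49168, "ttg.async_wait {num = 0}");
    static_assert(encodeVmcnt(1) == -49167, "ttg.async_wait {num = 1}");
    static_assert(encodeVmcnt(62) == -2, "ttg.async_wait {num = 62}");
    static_assert(encodeVmcnt(63) == -1, "ttg.async_wait {num = 63}");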