diff --git a/include/triton/Analysis/BufferRegion.h b/include/triton/Analysis/BufferRegion.h index ca66cb4d8a36..ca2e48403ed5 100644 --- a/include/triton/Analysis/BufferRegion.h +++ b/include/triton/Analysis/BufferRegion.h @@ -162,6 +162,8 @@ class BufferRegionAnalysis : public dataflow::SparseForwardDataFlowAnalysis< private: // Global registry of all regions std::set usedBufferRegions[NUM_REGION_TYPES]; + + static void verifyOpIsSupported(Operation *op); }; } // namespace mlir::triton diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index 6f13b9c8955a..732544751607 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -117,25 +117,27 @@ class FunctionBuilder { // from the visibility bitmask. We know this is safe because there cannot be // outstanding writes to this buffer at this point. void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, - uint64_t threadMask, Value pred, - MemType memType, Operation *insertPoint); + uint32_t length, uint64_t threadMask, + Value pred, MemType memType, + Operation *insertPoint); // setReadVisibility: add the threads set in threadMask to the buffer's read // visibility bitmask. void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, - uint64_t threadMask, Value pred, - MemType memType, Operation *insertPoint); + uint32_t length, uint64_t threadMask, + Value pred, MemType memType, + Operation *insertPoint); // clearWriteTracking: clear all the information about threads writing to a // buffer. void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf, - Value pred, MemType memType, - Operation *insertPoint); + uint32_t length, Value pred, + MemType memType, Operation *insertPoint); // clearReadVisibility: clear the read visibility for a buffer. void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, - Value pred, MemType memType, - Operation *insertPoint); + uint32_t length, Value pred, + MemType memType, Operation *insertPoint); // clearReadTracking: clear the read tracking for a buffer. void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf, - Value pred, MemType memType, + uint32_t length, Value pred, MemType memType, Operation *insertPoint); // trackVisibleWrites: snapshot buffers currently visible to the thread into // the tracking table for a barrier. @@ -160,15 +162,15 @@ class FunctionBuilder { // verifyWriteVisibility: ensure the thread either sees the latest write or no // other thread is writing the buffer. void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, - int thread, StringRef operandName, - Value pred, MemType memType, - Operation *insertPoint); + uint32_t length, int thread, + StringRef operandName, Value pred, + MemType memType, Operation *insertPoint); // verifyReadVisibility: ensure all reads from the buffer are visible to the // thread. void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, - int thread, StringRef operandName, - Value pred, MemType memType, - Operation *insertPoint); + uint32_t length, int thread, + StringRef operandName, Value pred, + MemType memType, Operation *insertPoint); // copyWriteVisibility: replicate the write visibility bit of sourceThread to // every destination thread in destMask. 
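+  // Example (illustrative values, not from this patch): with sourceThread = 1
+  // and destMask = 0b0110, the write-visibility bit recorded for thread 1 is
+  // replicated into the visibility entries of threads 1 and 2.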
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread, @@ -182,7 +184,8 @@ class FunctionBuilder { // stageAccessForCommit: mark the buffer as staged (value -1) in the // outstanding commit table for this thread. void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf, - int thread, Value pred, MemType memType, + uint32_t length, int thread, Value pred, + MemType memType, CommitKind::Kind commitKind, Operation *insertPoint); // commitAccesses: convert staged entries to 1 and increment outstanding @@ -207,7 +210,7 @@ class FunctionBuilder { // checkOutstandingCommits: assert that the outstanding commit row for the // buffer is zero before the access described by pendingAccessType. void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf, - int thread, + uint32_t length, int thread, StringRef pendingAccessType, Value pred, MemType memType, CommitKind::Kind commitKind, diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td index ab97ddb890fa..b74c45c33eb1 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -34,30 +34,33 @@ def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [ } -def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [Pure]> { - let summary = "definte an array of pointers to shared memory buffers"; +def TTI_ExperimentalBufferDescriptorsOp + : TTI_Op<"experimental_buffer_descriptors", [Pure]> { + let summary = "define an array of buffer descriptors"; let description = [{ - Create a tensor of pointers to shared memory buffers. + Create a tensor of buffer descriptors packing 32-bit pointer offsets and + 32-bit lengths into 64-bit elements. }]; - let arguments = (ins DenseI32ArrayAttr:$offsets, TT_MemTypeAttr:$memType); + let arguments = (ins DenseI32ArrayAttr:$offsets, DenseI32ArrayAttr:$lengths, + TT_MemTypeAttr:$memType); let results = (outs TT_Tensor:$result); let assemblyFormat = [{ - $offsets `,` $memType attr-dict `:` type($result) + $offsets `,` $lengths `,` $memType attr-dict `:` type($result) }]; } -def TTI_ExperimentalMemDescToI64Op : TTI_Op<"experimental_memdesc_to_i64", [Pure]> { - let summary = "Convert a memdesc into its base pointer as i64"; +def TTI_ExperimentalMemDescToI32Op : TTI_Op<"experimental_memdesc_to_i32", [Pure]> { + let summary = "Convert a memdesc into its base pointer as i32"; let description = [{ - Extract the base pointer from the given memdesc and return it as a 64-bit + Extract the base pointer from the given memdesc and return it as a 32-bit integer. This can be used to compare the memdesc against tensors of barrier pointers maintained by the concurrency sanitizer. 
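+    Example (illustrative values): a buffer descriptor built with offset 0x400
+    and length 0x200 packs as (0x200 << 32) | 0x400 = 0x0000020000000400; the
+    i32 produced by this op corresponds to the low 32 bits of such a
+    descriptor.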
}]; let arguments = (ins TTG_MemDescType:$memdesc); - let results = (outs I64:$result); + let results = (outs I32:$result); let builders = [ OpBuilder<(ins "Value":$memdesc), [{ - build($_builder, $_state, $_builder.getI64Type(), memdesc); + build($_builder, $_state, $_builder.getI32Type(), memdesc); }]> ]; let assemblyFormat = "$memdesc attr-dict `:` type($memdesc)"; diff --git a/include/triton/Dialect/TritonInstrument/IR/Utility.h b/include/triton/Dialect/TritonInstrument/IR/Utility.h index a7f35c9ba162..4204b65f76f7 100644 --- a/include/triton/Dialect/TritonInstrument/IR/Utility.h +++ b/include/triton/Dialect/TritonInstrument/IR/Utility.h @@ -1,6 +1,7 @@ #ifndef TRITONINSTRUMENT_UTILITY_H #define TRITONINSTRUMENT_UTILITY_H +#include "triton/Analysis/BufferRegion.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonInstrument/IR/Dialect.h" @@ -74,15 +75,17 @@ struct AuxDataMap { RegionToValueMap readVisibility[numMemTypes]; RegionToValueMap readTracking[numMemTypes]; RegionToValueMap commits[CommitKind::NumCommitKinds]; + RegionToValueMap aliasMatrices[numMemTypes]; RegionToValueMap lock; RegionToValueMap waiting; void populateAndPassToWarpSpecialize(ModuleOp module); private: - void getBuffersAndBarriers(ModuleOp module, - SmallVector, 2> &bufValues, - SmallVector &barrierValues); + void getBuffersAndBarriers( + ModuleOp module, + SmallVector, 2> &bufRegions, + SmallVector &barrierRegions); void passToWarpSpecialize(triton::FuncOp func, ValueType value, RegionToValueMap &map); void createInWarpSpecialize( diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td index 6f94d52a833a..9003fa303b66 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td @@ -15,6 +15,9 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> { InterfaceMethod<"Return the A operand.", "::mlir::TypedValue<::mlir::triton::gpu::MemDescType>", "getA">, + InterfaceMethod<"Return the B operand.", + "::mlir::TypedValue<::mlir::triton::gpu::MemDescType>", + "getB">, InterfaceMethod<"Return the accumulator init flag.", "::mlir::Value", "useAccumulator">, @@ -22,6 +25,12 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> { "void", "setUseAccumulator", (ins "::mlir::Value":$flag)>, + InterfaceMethod<"Return the completion barriers of this MMAv5 op.", + "::mlir::ValueRange", + "getCompletionBarriers">, + InterfaceMethod<"Return the completion barrier predicates of this MMAv5 op.", + "::mlir::ValueRange", + "getCompletionBarrierPreds">, InterfaceMethod<"Associate a new completion barrier to this MMAv5 op.", "void", "addCompletionBarrier", diff --git a/lib/Analysis/BufferRegion.cpp b/lib/Analysis/BufferRegion.cpp index 33f93f47222c..86c4d8b32e5b 100644 --- a/lib/Analysis/BufferRegion.cpp +++ b/lib/Analysis/BufferRegion.cpp @@ -2,7 +2,9 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; @@ -33,14 +35,22 @@ uint64_t getAllocationOffset(ttng::TMEMAllocOp op) { return colOffset | (rowOffset << 16); } 
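+// The packed offset above keeps the row index in the upper 16 bits; e.g.
+// rowOffset = 2 and colOffset = 0x40 yield 0x00020040 (illustrative values).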
-unsigned getAllocSize(ttg::LocalAllocOp op) { - ttg::MemDescType ty = op.getType(); +unsigned getMemDescSize(ttg::MemDescType ty) { + if (isa(ty.getMemorySpace())) { + return ttng::getTmemAllocSizes(ty).numCols; + } + assert(isa(ty.getMemorySpace()) && + "Unsupported memory space"); unsigned elSize = ty.getElementType().getIntOrFloatBitWidth() / 8; return product(ty.getShape()) * elSize; } +unsigned getAllocSize(ttg::LocalAllocOp op) { + return getMemDescSize(op.getType()); +} + unsigned getAllocSize(ttng::TMEMAllocOp op) { - return ttng::getTmemAllocSizes(op.getType()).numCols; + return getMemDescSize(op.getType()); } unsigned getNumBuffers(ttg::MemDescIndexOp memdescIndexOp) { @@ -49,22 +59,32 @@ unsigned getNumBuffers(ttg::MemDescIndexOp memdescIndexOp) { return ty.getShape()[0]; } -Value getBarrierOperand(Operation *op) { +llvm::DenseSet getBarrierOperands(Operation *op) { if (auto initBarrierOp = dyn_cast(op)) { - return initBarrierOp.getOperand(); + return {initBarrierOp.getOperand()}; + } + if (auto barrierExpectOp = dyn_cast(op)) { + return {barrierExpectOp.getAlloc()}; + } + if (auto invalBarrierOp = dyn_cast(op)) { + return {invalBarrierOp.getAlloc()}; } if (auto asyncOp = dyn_cast(op)) { - return asyncOp.getBarrier(); + return {asyncOp.getBarrier()}; } if (auto gatherOp = dyn_cast(op)) { - return gatherOp.getBarrier(); + return {gatherOp.getBarrier()}; } - return nullptr; + if (auto mmaV5Op = dyn_cast(op)) { + return llvm::DenseSet(mmaV5Op.getCompletionBarriers().begin(), + mmaV5Op.getCompletionBarriers().end()); + } + return llvm::DenseSet{}; } bool isUsedAsBarrier(Value v) { for (auto user : v.getUsers()) { - if (v == getBarrierOperand(user)) { + if (getBarrierOperands(user).contains(v)) { return true; } } @@ -83,6 +103,67 @@ bool isUsedAsTensorMemory(Value v) { isa_and_nonnull(type.getMemorySpace()); } +uint32_t getMemDescSubsliceByteOffset(ttg::MemDescSubsliceOp op) { + auto srcTy = op.getSrc().getType(); + auto offsets = op.getOffsets(); + if (offsets.empty()) + return 0; + + Attribute encoding = srcTy.getEncoding(); + mlir::triton::LinearLayout layout; + if (auto padded = dyn_cast(encoding)) { + layout = padded.getLinearComponent(); + } else { + layout = ttg::toLinearLayout(srcTy); + } + + MLIRContext *ctx = op->getContext(); + SmallVector dimNames = + mlir::triton::standardOutDimNames(ctx, srcTy.getRank()); + SmallVector> logicalOffsets; + logicalOffsets.reserve(offsets.size()); + for (auto &&[dimName, offset] : llvm::zip_equal(dimNames, offsets)) { + logicalOffsets.push_back({dimName, static_cast(offset)}); + } + + StringAttr offsetDim = StringAttr::get(ctx, "offset"); + layout = layout.sublayout({offsetDim}, dimNames); + mlir::triton::LinearLayout inverse = layout.invert(); + auto mapped = inverse.apply(logicalOffsets); + assert(mapped.size() == 1 && mapped[0].first == offsetDim && + "expected single offset dimension after inversion"); + uint64_t elementOffset = static_cast(mapped[0].second); + + uint64_t elementSizeBytes = + srcTy.getElementType().getIntOrFloatBitWidth() / 8; + assert(elementSizeBytes > 0 && "element size must be non-zero"); + uint64_t byteOffset = elementOffset * elementSizeBytes; + + if (auto padded = dyn_cast(encoding)) { + uint64_t padBytes = 0; + for (auto &&[interval, padding] : + llvm::zip_equal(padded.getIntervals(), padded.getPaddings())) { + if (interval == 0 || padding == 0) + continue; + uint64_t intervalScaled = + static_cast(interval) * elementSizeBytes; + uint64_t paddingScaled = + static_cast(padding) * elementSizeBytes; + 
assert(llvm::isPowerOf2_64(intervalScaled) && + llvm::isPowerOf2_64(paddingScaled) && + "interval and padding must be powers of two in bytes"); + unsigned intervalLog2 = llvm::Log2_64(intervalScaled); + unsigned paddingLog2 = llvm::Log2_64(paddingScaled); + padBytes += (byteOffset >> intervalLog2) << paddingLog2; + } + byteOffset += padBytes; + } + + assert(byteOffset <= std::numeric_limits::max() && + "memdesc_subslice offset exceeds 32-bit range"); + return static_cast(byteOffset); +} + std::optional getRegionType(Value v) { if (isUsedAsBarrier(v)) { return triton::BufferRegionAnalysis::RegionType::BARRIER; @@ -134,6 +215,7 @@ LogicalResult BufferRegionAnalysis::visitOperation( getOrCreate(getProgramPointBefore(&entry)); propagateIfChanged(exec, exec->setToLive()); } + return success(); } if (auto localAllocOp = dyn_cast(op)) { uint32_t offset = getAllocationOffset(localAllocOp); @@ -143,6 +225,7 @@ LogicalResult BufferRegionAnalysis::visitOperation( for (auto *r : results) { propagateIfChanged(r, r->join(regionInfo)); } + return success(); } if (auto tmemAllocOp = dyn_cast(op)) { uint32_t offset = getAllocationOffset(tmemAllocOp); @@ -152,13 +235,14 @@ LogicalResult BufferRegionAnalysis::visitOperation( for (auto *r : results) { propagateIfChanged(r, r->join(regionInfo)); } + return success(); } if (auto memdescIndexOp = dyn_cast(op)) { RegionInfo in = operands[0]->getValue(); int numSubBuffers = getNumBuffers(memdescIndexOp); for (auto ®ion : in.regions) { for (int i = 0; i < numSubBuffers; i++) { - uint32_t subBufferSize = region.length / numSubBuffers; + uint32_t subBufferSize = getMemDescSize(memdescIndexOp.getType()); regionInfo.regions.insert( {region.baseOffset + i * subBufferSize, subBufferSize}); } @@ -167,7 +251,48 @@ LogicalResult BufferRegionAnalysis::visitOperation( for (auto *r : results) { propagateIfChanged(r, r->join(regionInfo)); } + return success(); } + if (auto memdescSubsliceOp = dyn_cast(op)) { + RegionInfo in = operands[0]->getValue(); + uint32_t subBufferSize = getMemDescSize(memdescSubsliceOp.getType()); + uint32_t relativeOffset = getMemDescSubsliceByteOffset(memdescSubsliceOp); + for (auto ®ion : in.regions) { + regionInfo.regions.insert( + {region.baseOffset + relativeOffset, subBufferSize}); + } + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + if (auto tmemSubsliceOp = dyn_cast(op)) { + RegionInfo in = operands[0]->getValue(); + uint32_t subBufferSize = getMemDescSize(tmemSubsliceOp.getType()); + uint32_t relativeOffset = tmemSubsliceOp.getN(); + for (auto ®ion : in.regions) { + regionInfo.regions.insert( + {region.baseOffset + relativeOffset, subBufferSize}); + } + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + // "Passthrough" ops that don't modify the buffer regions. + if (isa(op)) { + // Just propagate the regions from the operand. 
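+    // Example (illustrative values): a passthrough view over a region
+    // {baseOffset = 0x100, length = 0x80} forwards the identical
+    // {0x100, 0x80} region to its result.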
+ RegionInfo in = operands[0]->getValue(); + for (auto ®ion : in.regions) { + regionInfo.regions.insert(region); + } + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + verifyOpIsSupported(op); return success(); } @@ -225,7 +350,8 @@ bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) { if (isa( + ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp, ttng::InitBarrierOp, + ttng::BarrierExpectOp, ttng::InvalBarrierOp, ttng::WaitBarrierOp>( op)) { return true; } @@ -240,4 +366,20 @@ bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) { return false; } +void BufferRegionAnalysis::verifyOpIsSupported(Operation *op) { + bool hasMemoryOperands = llvm::any_of(op->getOperands(), [](Value v) { + return isUsedAsSharedMemory(v) || isUsedAsTensorMemory(v); + }); + if (!hasMemoryOperands) { + return; + } + if (isMemoryAccessOperation(op)) { + return; + } + op->emitError( + "Operation accessing memory unaccounted for in buffer region analysis"); + llvm::report_fatal_error( + "Operation accessing memory unaccounted for in buffer region analysis"); +} + } // namespace mlir::triton diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index 75ef836873e6..63540a042e25 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -23,12 +23,13 @@ namespace ttng = mlir::triton::nvidia_gpu; // Utility functions //////////////////////////////////////////// -Value createMemDescToI64(RewriterBase &rewriter, Location loc, +Value createMemDescToI32(RewriterBase &rewriter, Location loc, const LLVMTypeConverter *typeConverter, ttg::MemDescType memDescTy, Value sharedMemStruct) { TritonLLVMOpBuilder b(loc, rewriter); + auto i32Ty = rewriter.getIntegerType(32); if (isa(memDescTy.getMemorySpace())) { - return b.ptrtoint(rewriter.getIntegerType(64), sharedMemStruct); + return b.ptrtoint(i32Ty, sharedMemStruct); } assert(isa(memDescTy.getEncoding()) && "Unsupported memory encoding"); @@ -38,9 +39,7 @@ Value createMemDescToI64(RewriterBase &rewriter, Location loc, auto offset = smemObj.getShmemOffset(loc, rewriter, memDescTy); auto elemSize = srcElemTy.getIntOrFloatBitWidth() / 8; offset = b.mul(offset, b.i32_val(elemSize)); - auto i64Ty = rewriter.getIntegerType(64); - offset = b.zext(i64Ty, offset); - return b.add(offset, b.ptrtoint(i64Ty, smemObj.getBase())); + return b.add(offset, b.ptrtoint(i32Ty, smemObj.getBase())); } std::tuple @@ -163,48 +162,77 @@ struct AssertInThreadOpConversion const TargetInfoBase &targetInfo; }; -struct BufferPointersOpConversion - : public ConvertOpToLLVMPattern { +struct BufferDescriptorsOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(tti::ExperimentalBufferPointersOp op, OpAdaptor adaptor, + matchAndRewrite(tti::ExperimentalBufferDescriptorsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto module = op->getParentOfType(); - auto values = adaptor.getOffsets(); auto encoding = cast(op.getResult().getType().getEncoding()); - auto bufPointers = - createInitializedIntArrayTensor(rewriter, loc, encoding, values); - Value base = nullptr; + auto offsets = adaptor.getOffsets(); + auto lengths = adaptor.getLengths(); + assert(offsets.size() == lengths.size() && 
"Mismatched descriptor arrays"); + + auto tensorType = cast(op.getResult().getType()); + + SmallVector offsetVals; + offsetVals.reserve(offsets.size()); + for (int32_t offset : offsets) + offsetVals.push_back(static_cast(offset)); + Value pointerTensor = + createInitializedIntArrayTensor(rewriter, loc, encoding, offsetVals); + + TritonLLVMOpBuilder b(loc, rewriter); + auto i64Ty = rewriter.getIntegerType(64); + Value baseTensor = nullptr; if (op.getMemType() == tti::MemType::SHARED_MEM) { - base = getSharedMemoryBase(rewriter, - op->getParentOfType()); + auto func = op->getParentOfType(); + Value base = getSharedMemoryBase(rewriter, func); + baseTensor = triton::SplatOp::create(rewriter, loc, tensorType, base); } else { assert(op.getMemType() == tti::MemType::TENSOR_MEM && "Unsupported memory type"); - TritonLLVMOpBuilder b(loc, rewriter); - base = nvgpu::TensorMemoryBaseAddress::create(rewriter, loc); - base = b.ptrtoint(i32_ty, base); + Value basePtr = nvgpu::TensorMemoryBaseAddress::create(rewriter, loc); + Value base = b.ptrtoint(i64Ty, basePtr); + baseTensor = triton::SplatOp::create(rewriter, loc, tensorType, base); } - bufPointers = arith::AddIOp::create( - rewriter, loc, bufPointers, - triton::SplatOp::create(rewriter, loc, bufPointers.getType(), base)); - rewriter.replaceOp(op, bufPointers); + + pointerTensor = arith::AddIOp::create( + rewriter, loc, pointerTensor.getType(), pointerTensor, baseTensor); + + SmallVector maskVals(offsets.size(), 0xffffffffu); + Value maskTensor = + createInitializedIntArrayTensor(rewriter, loc, encoding, maskVals); + Value trimmedPointers = arith::AndIOp::create( + rewriter, loc, pointerTensor.getType(), pointerTensor, maskTensor); + + SmallVector lengthVals; + lengthVals.reserve(lengths.size()); + for (int32_t length : lengths) + lengthVals.push_back(static_cast(static_cast(length)) + << 32); + Value lengthTensor = + createInitializedIntArrayTensor(rewriter, loc, encoding, lengthVals); + + auto bufDescriptors = + arith::OrIOp::create(rewriter, loc, trimmedPointers.getType(), + trimmedPointers, lengthTensor); + rewriter.replaceOp(op, bufDescriptors); return success(); } Value createInitializedIntArrayTensor(OpBuilder &builder, Location loc, BlockedEncodingAttr encoding, - ArrayRef values) const { + ArrayRef values) const { int64_t size = values.size(); assert(llvm::isPowerOf2_64(size) && "Expected power of 2"); auto tensorType = RankedTensorType::get({size}, builder.getIntegerType(64), encoding); SmallVector apInts = llvm::to_vector( - llvm::map_range(values, [](int32_t v) { return APInt(64, v); })); + llvm::map_range(values, [](uint64_t v) { return APInt(64, v); })); auto denseAttr = DenseElementsAttr::get(tensorType, apInts); return arith::ConstantOp::create(builder, loc, tensorType, denseAttr); } @@ -212,12 +240,10 @@ struct BufferPointersOpConversion Value getSharedMemoryBase(ConversionPatternRewriter &rewriter, FunctionOpInterface func) const { Location loc = func.getLoc(); - Value base = LLVM::getStackPointer(rewriter, func); - // Bitcast to i64 + Value basePtr = LLVM::getStackPointer(rewriter, func); auto i64Ty = rewriter.getIntegerType(64); TritonLLVMOpBuilder b(loc, rewriter); - base = b.ptrtoint(i64Ty, base); - return base; + return b.ptrtoint(i64Ty, basePtr); } }; @@ -307,18 +333,18 @@ struct LockReleaseOpConversion } }; -struct MemDescToI64OpConversion - : public ConvertOpToLLVMPattern { +struct MemDescToI32OpConversion + : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern< - 
tti::ExperimentalMemDescToI64Op>::ConvertOpToLLVMPattern; + tti::ExperimentalMemDescToI32Op>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(tti::ExperimentalMemDescToI64Op op, OpAdaptor adaptor, + matchAndRewrite(tti::ExperimentalMemDescToI32Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value converted = - createMemDescToI64(rewriter, loc, getTypeConverter(), + createMemDescToI32(rewriter, loc, getTypeConverter(), op.getMemdesc().getType(), adaptor.getMemdesc()); rewriter.replaceOp(op, converted); return success(); @@ -331,8 +357,8 @@ void mlir::triton::populateInstrumentationToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter); + patterns.add(typeConverter); patterns.add(typeConverter); patterns.add(typeConverter); - patterns.add(typeConverter); + patterns.add(typeConverter); } diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index c54e238653ca..8cbb1f0edc7d 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -7,10 +7,12 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "triton/Dialect/TritonInstrument/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" namespace mlir::triton::instrument { namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; namespace tti = mlir::triton::instrument; namespace { @@ -141,6 +143,28 @@ void createCallToCachedFunction( } } +Value createBufferDescriptor(ImplicitLocOpBuilder &b, Value offsetI32, + Value lengthI32) { + auto i64Type = b.getI64Type(); + Value offsetI64 = arith::ExtUIOp::create(b, i64Type, offsetI32); + Value lengthI64 = arith::ExtUIOp::create(b, i64Type, lengthI32); + Value shiftAmount = arith::ConstantIntOp::create(b, 32, 64); + Value lengthShifted = arith::ShLIOp::create(b, lengthI64, shiftAmount); + return arith::OrIOp::create(b, lengthShifted, offsetI64); +} + +uint32_t getMemDescLength(Value buf) { + auto memDescType = cast(buf.getType()); + if (isa(memDescType.getEncoding())) { + unsigned elSize = memDescType.getElementType().getIntOrFloatBitWidth() / 8; + return static_cast(product(memDescType.getShape()) * elSize); + } + if (isa(memDescType.getMemorySpace())) { + return ttng::getTmemAllocSizes(memDescType).numCols; + } + llvm_unreachable("Unsupported memory space for memdesc"); +} + std::tuple createIfBlock(ImplicitLocOpBuilder &b, Value cnd) { // #prevBlock @@ -188,6 +212,18 @@ Value createConvertLayout(ImplicitLocOpBuilder &b, Value tensor, return ttg::ConvertLayoutOp::create(b, dstType, tensor); } +Value expandAliases(ImplicitLocOpBuilder &b, Value bufferMask, + Value aliasMatrix, RankedTensorType aliasMatrixType) { + assert(aliasMatrixType.getRank() == 2 && + "Alias matrix expected to be rank-2"); + auto bufferMaskType = cast(bufferMask.getType()); + Value bufMaskMatrix = + convertAndBroadcast(b, bufferMask, /*dim=*/1, aliasMatrixType); + Value aliasingMask = arith::AndIOp::create(b, aliasMatrix, bufMaskMatrix); + Value aliasVector = createBitwiseOrReduce(b, aliasingMask, /*axis=*/0); + return createConvertLayout(b, aliasVector, bufferMaskType.getEncoding()); +} + Value createOneHot(ImplicitLocOpBuilder &b, int size, int index, Attribute encoding) { 
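+  // Presumed semantics (inferred from the name, not stated in this patch):
+  // builds a rank-1 one-hot tensor of `size` elements that is nonzero only at
+  // `index`, laid out with `encoding`.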
auto loc = b.getLoc(); @@ -298,27 +334,32 @@ void FunctionBuilder::createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, Value waitingVal = auxData.waiting.at(insertPoint).value; auto waitingType = cast(auxData.waiting.at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, threadVal, phase, - pred, barriersVal, waitingVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, threadVal, phase, + pred, barriersVal, waitingVal}; createCallToCachedFunction( b, "set_waiting", args, /*assertInfo=*/std::nullopt, {barriersType, waitingType}, [barriersType, waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value mbarI64 = entryBlock->getArgument(0); - Value baseThread = entryBlock->getArgument(1); - Value phase = entryBlock->getArgument(2); - Value pred = entryBlock->getArgument(3); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value baseThread = entryBlock->getArgument(2); + Value phase = entryBlock->getArgument(3); + Value pred = entryBlock->getArgument(4); - Value barriers = entryBlock->getArgument(4); - Value waitingPtr = entryBlock->getArgument(5); + Value barriers = entryBlock->getArgument(5); + Value waitingPtr = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value waiting = tti::createLoadScratchMemory(fb, fb.getLoc(), waitingPtr, waitingType); - Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, mbarI64); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); Value bitsPerThread = arith::ConstantIntOp::create(fb, WaitingBits::bitsPerThread, 32); @@ -392,25 +433,31 @@ void FunctionBuilder::createClearWaitingCall(ImplicitLocOpBuilder &b, auto waitingType = cast(auxData.waiting.at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, threadVal, pred, barriersVal, waitingVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, threadVal, + pred, barriersVal, waitingVal}; createCallToCachedFunction( b, "clear_waiting", args, /*assertInfo=*/std::nullopt, {barriersType, waitingType}, [barriersType, waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value mbarI64 = entryBlock->getArgument(0); - Value baseThread = entryBlock->getArgument(1); - Value pred = entryBlock->getArgument(2); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value baseThread = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); - Value barriers = entryBlock->getArgument(3); - Value waitingPtr = entryBlock->getArgument(4); + Value barriers = entryBlock->getArgument(4); + Value waitingPtr = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value waiting = tti::createLoadScratchMemory(fb, fb.getLoc(), waitingPtr, waitingType); - Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, mbarI64); + Value descriptor = createBufferDescriptor(fb, 
mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); Value bitsPerThread = arith::ConstantIntOp::create(fb, WaitingBits::bitsPerThread, 32); @@ -545,22 +592,27 @@ void FunctionBuilder::createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; auto barrierStatesType = cast(auxData.barrierStates.at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, countVal, barriersVal, barrierStatesVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, countVal, barriersVal, + barrierStatesVal}; createCallToCachedFunction( b, "init_barrier_state", args, /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value mbarI64 = entryBlock->getArgument(0); - Value count = entryBlock->getArgument(1); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value count = entryBlock->getArgument(2); - Value barriers = entryBlock->getArgument(2); - Value statesPtr = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(3); + Value statesPtr = entryBlock->getArgument(4); Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, barrierStatesType); - Value mask = createCmpIntTensorScalar(fb, barriers, mbarI64); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); Value countMask = arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); @@ -604,9 +656,11 @@ void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; auto barrierStatesType = cast(auxData.barrierStates.at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, countVal, pred, barriersVal, - barrierStatesVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, countVal, + pred, barriersVal, barrierStatesVal}; AssertInfo assertInfo{ "Barrier arrive underflow: current count would become negative", barrierStatesType.cloneWith(std::nullopt, b.getI1Type())}; @@ -615,16 +669,18 @@ void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, {barriersType, barrierStatesType}, [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value mbarI64 = entryBlock->getArgument(0); - Value count = entryBlock->getArgument(1); - Value pred = entryBlock->getArgument(2); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value count = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); - Value barriers = entryBlock->getArgument(3); - Value statesPtr = entryBlock->getArgument(4); + Value barriers = entryBlock->getArgument(4); + Value statesPtr = entryBlock->getArgument(5); Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, barrierStatesType); - Value mask = createCmpIntTensorScalar(fb, barriers, mbarI64); + Value descriptor = 
createBufferDescriptor(fb, mbarOffset, lengthVal); + Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); Value zero32 = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); @@ -675,27 +731,31 @@ void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; auto barrierStatesType = cast(auxData.barrierStates.at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, countVal, pred, barriersVal, - barrierStatesVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, countVal, + pred, barriersVal, barrierStatesVal}; createCallToCachedFunction( b, "update_barrier_state", args, /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value mbarI64 = entryBlock->getArgument(0); - Value count = entryBlock->getArgument(1); - Value pred = entryBlock->getArgument(2); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value count = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); - Value barriers = entryBlock->getArgument(3); - Value statesPtr = entryBlock->getArgument(4); + Value barriers = entryBlock->getArgument(4); + Value statesPtr = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, barrierStatesType); - Value mask = createCmpIntTensorScalar(fb, barriers, mbarI64); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); Value zero32 = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); @@ -750,7 +810,7 @@ void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, } void FunctionBuilder::createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, - Value buf, + Value buf, uint32_t length, uint64_t threadMask, Value pred, MemType memType, Operation *insertPoint) { @@ -769,27 +829,30 @@ void FunctionBuilder::createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, auxData.writeVisibility[(int)memType].at(insertPoint).value; auto writeVisibilityType = cast( auxData.writeVisibility[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, threadMaskVal, buffersVal, - writeVisibilityVal}; + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, + threadMaskVal, buffersVal, writeVisibilityVal}; createCallToCachedFunction( b, "set_write_visibility", args, /*assertInfo=*/std::nullopt, {buffersType, writeVisibilityType, (int)memType}, [buffersType, writeVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadMaskVal = entryBlock->getArgument(2); - Value buffers = entryBlock->getArgument(3); - Value writeVisibilityPtr = entryBlock->getArgument(4); + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + 
Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value writeVisibility = tti::createLoadScratchMemory( fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); auto elemType = cast(writeVisibilityType.getElementType()); Value threadMaskElem = adjustIntegerWidth(fb, threadMaskVal, elemType); Value threadMaskTensor = @@ -805,13 +868,14 @@ void FunctionBuilder::createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, } void FunctionBuilder::createSetReadVisibilityCall(ImplicitLocOpBuilder &b, - Value buf, + Value buf, uint32_t length, uint64_t threadMask, Value pred, MemType memType, Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || - auxData.readVisibility[(int)memType].empty()) { + auxData.readVisibility[(int)memType].empty() || + auxData.aliasMatrices[(int)memType].empty()) { return; } if (!pred) @@ -824,27 +888,30 @@ void FunctionBuilder::createSetReadVisibilityCall(ImplicitLocOpBuilder &b, auxData.readVisibility[(int)memType].at(insertPoint).value; auto readVisibilityType = cast( auxData.readVisibility[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, threadMaskVal, buffersVal, - readVisibilityVal}; + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, + threadMaskVal, buffersVal, readVisibilityVal}; createCallToCachedFunction( b, "set_read_visibility", args, /*assertInfo=*/std::nullopt, {buffersType, readVisibilityType, (int)memType}, [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadMaskVal = entryBlock->getArgument(2); - Value buffers = entryBlock->getArgument(3); - Value readVisibilityPtr = entryBlock->getArgument(4); + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value readVisibility = tti::createLoadScratchMemory( fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, readVisibilityType); auto elemType = cast(readVisibilityType.getElementType()); @@ -868,8 +935,8 @@ void FunctionBuilder::createSetReadVisibilityCall(ImplicitLocOpBuilder &b, } void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, - Value buf, Value pred, - MemType memType, + Value buf, uint32_t length, + Value pred, MemType memType, Operation *insertPoint) { 
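+  // Emits a call, guarded by `pred`, that zeroes the write-tracking row of
+  // the buffer whose packed (offset, length) descriptor matches `buf`.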
if (auxData.buffers[(int)memType].empty() || auxData.writeTracking[(int)memType].empty()) { @@ -884,25 +951,29 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, auxData.writeTracking[(int)memType].at(insertPoint).value; auto writeTrackingType = cast( auxData.writeTracking[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, buffersVal, writeTrackingVal}; + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, buffersVal, + writeTrackingVal}; createCallToCachedFunction( b, "clear_write_tracking", args, /*assertInfo=*/std::nullopt, {buffersType, writeTrackingType, (int)memType}, [buffersType, writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value buffers = entryBlock->getArgument(2); - Value writeTrackingPtr = entryBlock->getArgument(3); + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value buffers = entryBlock->getArgument(3); + Value writeTrackingPtr = entryBlock->getArgument(4); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value writeTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, writeTrackingType); Value zero = @@ -918,8 +989,8 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, } void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, - Value buf, Value pred, - MemType memType, + Value buf, uint32_t length, + Value pred, MemType memType, Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || auxData.readVisibility[(int)memType].empty()) { @@ -934,25 +1005,29 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, auxData.readVisibility[(int)memType].at(insertPoint).value; auto readVisibilityType = cast( auxData.readVisibility[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, buffersVal, readVisibilityVal}; + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, buffersVal, + readVisibilityVal}; createCallToCachedFunction( b, "clear_read_visibility", args, /*assertInfo=*/std::nullopt, {buffersType, readVisibilityType, (int)memType}, [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value buffers = entryBlock->getArgument(2); - Value readVisibilityPtr = entryBlock->getArgument(3); + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value buffers = entryBlock->getArgument(3); + Value readVisibilityPtr = entryBlock->getArgument(4); auto [prevBlock, ifBlock, thenBlock] = 
createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value readVisibility = tti::createLoadScratchMemory( fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, readVisibilityType); Value zero = @@ -968,8 +1043,8 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, } void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, - Value buf, Value pred, - MemType memType, + Value buf, uint32_t length, + Value pred, MemType memType, Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || @@ -985,25 +1060,29 @@ void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, auxData.readTracking[(int)memType].at(insertPoint).value; auto readTrackingType = cast( auxData.readTracking[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, buffersVal, readTrackingVal}; + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, buffersVal, + readTrackingVal}; createCallToCachedFunction( b, "clear_read_tracking", args, /*assertInfo=*/std::nullopt, {buffersType, readTrackingType, (int)memType}, [buffersType, readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value buffers = entryBlock->getArgument(2); - Value readTrackingPtr = entryBlock->getArgument(3); + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value buffers = entryBlock->getArgument(3); + Value readTrackingPtr = entryBlock->getArgument(4); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value readTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), readTrackingPtr, readTrackingType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, readTrackingType); Value zero = @@ -1041,22 +1120,25 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, auxData.writeTracking[(int)memType].at(insertPoint).value; auto writeTrackingType = cast( auxData.writeTracking[(int)memType].at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = { - mbarI64, pred, threadVal, barriersVal, writeVisibilityVal, - writeTrackingVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadVal, barriersVal, writeVisibilityVal, + writeTrackingVal}; createCallToCachedFunction( b, "track_visible_writes", args, /*assertInfo=*/std::nullopt, {barriersType, writeVisibilityType, writeTrackingType, (int)memType}, [barriersType, writeVisibilityType, writeTrackingType](ImplicitLocOpBuilder &fb, Block 
*entryBlock) { - Value bar = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadVal = entryBlock->getArgument(2); - Value barriers = entryBlock->getArgument(3); - Value writeVisibilityPtr = entryBlock->getArgument(4); - Value writeTrackingPtr = entryBlock->getArgument(5); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + Value writeTrackingPtr = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1065,7 +1147,9 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); Value writeTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); - Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, bar); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, writeTrackingType); Value threadI64 = @@ -1118,22 +1202,25 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, auxData.readTracking[(int)memType].at(insertPoint).value; auto readTrackingType = cast( auxData.readTracking[(int)memType].at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, pred, - threadVal, barriersVal, - readVisibilityVal, readTrackingVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadVal, barriersVal, readVisibilityVal, + readTrackingVal}; createCallToCachedFunction( b, "track_visible_reads", args, /*assertInfo=*/std::nullopt, {barriersType, readVisibilityType, readTrackingType, (int)memType}, [barriersType, readVisibilityType, readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bar = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadVal = entryBlock->getArgument(2); - Value barriers = entryBlock->getArgument(3); - Value readVisibilityPtr = entryBlock->getArgument(4); - Value readTrackingPtr = entryBlock->getArgument(5); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + Value readTrackingPtr = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1142,7 +1229,9 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); Value readTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), readTrackingPtr, readTrackingType); - Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, bar); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = 
convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, readTrackingType); Value threadColumnMask = @@ -1189,22 +1278,25 @@ void FunctionBuilder::createTransferVisibleWritesCall( auxData.writeTracking[(int)memType].at(insertPoint).value; auto writeTrackingType = cast( auxData.writeTracking[(int)memType].at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = { - mbarI64, pred, threadMaskVal, barriersVal, writeVisibilityVal, - writeTrackingVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadMaskVal, barriersVal, writeVisibilityVal, + writeTrackingVal}; createCallToCachedFunction( b, "transfer_visible_writes", args, /*assertInfo=*/std::nullopt, {barriersType, writeVisibilityType, writeTrackingType, (int)memType}, [barriersType, writeVisibilityType, writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bar = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadMaskVal = entryBlock->getArgument(2); - Value barriers = entryBlock->getArgument(3); - Value writeVisibilityPtr = entryBlock->getArgument(4); - Value writeTrackingPtr = entryBlock->getArgument(5); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + Value writeTrackingPtr = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1213,7 +1305,9 @@ void FunctionBuilder::createTransferVisibleWritesCall( fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); Value writeTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); - Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, bar); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, writeTrackingType); Value zeroTracking = @@ -1271,22 +1365,25 @@ void FunctionBuilder::createTransferVisibleReadsCall( auxData.readTracking[(int)memType].at(insertPoint).value; auto readTrackingType = cast( auxData.readTracking[(int)memType].at(insertPoint).type); - Value mbarI64 = tti::ExperimentalMemDescToI64Op::create(b, mbar); - SmallVector args = {mbarI64, pred, - threadMaskVal, barriersVal, - readVisibilityVal, readTrackingVal}; + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadMaskVal, barriersVal, readVisibilityVal, + readTrackingVal}; createCallToCachedFunction( b, "transfer_visible_reads", args, /*assertInfo=*/std::nullopt, {barriersType, readVisibilityType, readTrackingType, (int)memType}, [barriersType, readVisibilityType, readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bar = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadMaskVal = entryBlock->getArgument(2); - Value barriers = entryBlock->getArgument(3); - Value readVisibilityPtr = 
entryBlock->getArgument(4); - Value readTrackingPtr = entryBlock->getArgument(5); + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + Value readTrackingPtr = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1295,7 +1392,9 @@ void FunctionBuilder::createTransferVisibleReadsCall( fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); Value readTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), readTrackingPtr, readTrackingType); - Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, bar); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, readTrackingType); Value readTrackingZero = @@ -1320,10 +1419,12 @@ void FunctionBuilder::createTransferVisibleReadsCall( } void FunctionBuilder::createVerifyWriteVisibilityCall( - ImplicitLocOpBuilder &b, Value buf, int thread, StringRef operandName, - Value pred, MemType memType, Operation *insertPoint) { + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, + StringRef operandName, Value pred, MemType memType, + Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || - auxData.writeVisibility[(int)memType].empty()) { + auxData.writeVisibility[(int)memType].empty() || + auxData.aliasMatrices[(int)memType].empty()) { return; } if (!pred) @@ -1336,9 +1437,15 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( auxData.writeVisibility[(int)memType].at(insertPoint).value; auto writeVisibilityType = cast( auxData.writeVisibility[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, threadVal, buffersVal, - writeVisibilityVal}; + Value aliasMatrixVal = + auxData.aliasMatrices[(int)memType].at(insertPoint).value; + auto aliasMatrixType = cast( + auxData.aliasMatrices[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, writeVisibilityVal, + aliasMatrixVal}; std::string message = "Buffer being accessed has outstanding writes."; if (!operandName.empty()) message += " Operand: " + operandName.str(); @@ -1346,18 +1453,23 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( buffersType.cloneWith(std::nullopt, b.getI1Type())}; createCallToCachedFunction( b, "verify_write_visibility", args, assertInfo, - {buffersType, writeVisibilityType, (int)memType}, - [buffersType, writeVisibilityType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadVal = entryBlock->getArgument(2); - Value buffers = entryBlock->getArgument(3); - Value writeVisibilityPtr = entryBlock->getArgument(4); + {buffersType, writeVisibilityType, aliasMatrixType, (int)memType}, + [buffersType, writeVisibilityType, + aliasMatrixType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = 
entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + Value aliasMatrix = entryBlock->getArgument(6); Value writeVisibility = tti::createLoadScratchMemory( fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = + expandAliases(fb, buffersEqBuf, aliasMatrix, aliasMatrixType); Value writeVisibilityZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); Value bufVisibility = arith::SelectOp::create( @@ -1388,8 +1500,9 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( } void FunctionBuilder::createVerifyReadVisibilityCall( - ImplicitLocOpBuilder &b, Value buf, int thread, StringRef operandName, - Value pred, MemType memType, Operation *insertPoint) { + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, + StringRef operandName, Value pred, MemType memType, + Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || auxData.readVisibility[(int)memType].empty()) { return; @@ -1404,9 +1517,15 @@ void FunctionBuilder::createVerifyReadVisibilityCall( auxData.readVisibility[(int)memType].at(insertPoint).value; auto readVisibilityType = cast( auxData.readVisibility[(int)memType].at(insertPoint).type); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, threadVal, buffersVal, - readVisibilityVal}; + Value aliasMatrixVal = + auxData.aliasMatrices[(int)memType].at(insertPoint).value; + auto aliasMatrixType = cast( + auxData.aliasMatrices[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, readVisibilityVal, + aliasMatrixVal}; std::string message = "Buffer being accessed has outstanding reads"; if (!operandName.empty()) message += ". 
Operand: " + operandName.str(); @@ -1414,18 +1533,23 @@ void FunctionBuilder::createVerifyReadVisibilityCall( buffersType.cloneWith(std::nullopt, b.getI1Type())}; createCallToCachedFunction( b, "verify_read_visibility", args, assertInfo, - {buffersType, readVisibilityType, (int)memType}, - [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadVal = entryBlock->getArgument(2); - Value buffers = entryBlock->getArgument(3); - Value readVisibilityPtr = entryBlock->getArgument(4); + {buffersType, readVisibilityType, aliasMatrixType, (int)memType}, + [buffersType, readVisibilityType, + aliasMatrixType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + Value aliasMatrix = entryBlock->getArgument(6); Value readVisibility = tti::createLoadScratchMemory( fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = + expandAliases(fb, buffersEqBuf, aliasMatrix, aliasMatrixType); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, readVisibilityType); Value readVisibilityZero = @@ -1590,8 +1714,8 @@ void FunctionBuilder::createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, } void FunctionBuilder::createStageAccessForCommitCall( - ImplicitLocOpBuilder &b, Value buf, int thread, Value pred, MemType memType, - CommitKind::Kind commitKind, Operation *insertPoint) { + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, Value pred, + MemType memType, CommitKind::Kind commitKind, Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || auxData.commits[commitKind].empty()) { return; @@ -1603,18 +1727,21 @@ void FunctionBuilder::createStageAccessForCommitCall( auto buffersType = cast(buffers.type); auto commitsType = cast(outstandingCommits.type); Value threadVal = arith::ConstantIntOp::create(b, thread, 32); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); - SmallVector args = {bufI64, pred, threadVal, buffers.value, - outstandingCommits.value}; + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, + pred, threadVal, + buffers.value, outstandingCommits.value}; createCallToCachedFunction( b, "stage_access_for_commit", args, /*assertInfo=*/std::nullopt, {buffersType, commitsType}, [buffersType, commitsType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadVal = entryBlock->getArgument(2); - Value buffers = entryBlock->getArgument(3); - Value outstandingCommitsPtr = entryBlock->getArgument(4); + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value outstandingCommitsPtr = entryBlock->getArgument(5); (void)threadVal; @@ 
-1623,7 +1750,8 @@ void FunctionBuilder::createStageAccessForCommitCall( Value commits = tti::createLoadScratchMemory( fb, fb.getLoc(), outstandingCommitsPtr, commitsType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, commitsType); Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); @@ -1863,25 +1991,31 @@ void FunctionBuilder::createClearOutstandingCommitsTransferReadsCall( } void FunctionBuilder::createCheckOutstandingCommitsCall( - ImplicitLocOpBuilder &b, Value buf, int thread, StringRef pendingAccessType, - Value pred, MemType memType, CommitKind::Kind commitKind, - Operation *insertPoint) { + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, + StringRef pendingAccessType, Value pred, MemType memType, + CommitKind::Kind commitKind, Operation *insertPoint) { if (auxData.buffers[(int)memType].empty() || - auxData.commits[commitKind].empty()) { + auxData.commits[commitKind].empty() || + auxData.aliasMatrices[(int)memType].empty()) { return; } ValueType buffers = auxData.buffers[(int)memType].at(insertPoint); ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); + ValueType aliasMatrix = auxData.aliasMatrices[(int)memType].at(insertPoint); assert(thread < NUM_THREADS && "Commit-count tracking must operate on base threads"); - Value bufI64 = tti::ExperimentalMemDescToI64Op::create(b, buf); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); if (!pred) pred = arith::ConstantIntOp::create(b, 1, 1); auto buffersType = cast(buffers.type); auto commitsType = cast(outstandingCommits.type); Value threadVal = arith::ConstantIntOp::create(b, thread, 32); - SmallVector args = {bufI64, pred, threadVal, buffers.value, - outstandingCommits.value}; + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + auto aliasMatrixType = cast(aliasMatrix.type); + SmallVector args = { + bufOffset, lengthVal, pred, + threadVal, buffers.value, outstandingCommits.value, + aliasMatrix.value}; std::string message = "Accessing buffer with pending access. 
Pending access type: " + pendingAccessType.str(); @@ -1889,17 +2023,23 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( commitsType.cloneWith(std::nullopt, b.getI1Type())}; createCallToCachedFunction( b, "check_outstanding_commits", args, assertInfo, - {buffersType, commitsType, (int)thread}, - [buffersType, commitsType](ImplicitLocOpBuilder &fb, Block *entryBlock) { - Value bufI64 = entryBlock->getArgument(0); - Value pred = entryBlock->getArgument(1); - Value threadVal = entryBlock->getArgument(2); - Value buffers = entryBlock->getArgument(3); - Value outstandingCommitsPtr = entryBlock->getArgument(4); + {buffersType, commitsType, aliasMatrixType, (int)thread}, + [buffersType, commitsType, aliasMatrixType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value outstandingCommitsPtr = entryBlock->getArgument(5); + Value aliasMatrix = entryBlock->getArgument(6); Value outstandingCommits = tti::createLoadScratchMemory( fb, fb.getLoc(), outstandingCommitsPtr, commitsType); - Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufI64); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = + expandAliases(fb, buffersEqBuf, aliasMatrix, aliasMatrixType); buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, commitsType); Value zeroTensor = diff --git a/lib/Dialect/TritonInstrument/IR/Utility.cpp b/lib/Dialect/TritonInstrument/IR/Utility.cpp index 1b104d606360..dbc4848135a0 100644 --- a/lib/Dialect/TritonInstrument/IR/Utility.cpp +++ b/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -14,6 +14,7 @@ using namespace mlir::triton; using namespace mlir::triton::gpu; using namespace mlir::triton::nvidia_gpu; using namespace mlir::triton::instrument; +using mlir::triton::BufferRegion; namespace { @@ -59,19 +60,54 @@ RankedTensorType getIntTensorType(Region *region, ArrayRef shape, } std::pair -createBufferPointersTensor(ImplicitLocOpBuilder &builder, MemType memType, - SmallVector values) { - int64_t size = values.size(); +createBufferDescriptorsTensor(ImplicitLocOpBuilder &builder, MemType memType, + ArrayRef regions) { + int64_t size = regions.size(); assert(llvm::isPowerOf2_64(size) && "Expected power of 2"); auto tensorType = getIntTensorType(builder.getInsertionBlock()->getParent(), {size}, 64); - auto valuesI32 = llvm::to_vector(llvm::map_range( - values, [](uint32_t v) { return static_cast(v); })); - return {ExperimentalBufferPointersOp::create(builder, tensorType, valuesI32, - memType), + SmallVector offsets; + SmallVector lengths; + offsets.reserve(size); + lengths.reserve(size); + for (const auto ®ion : regions) { + offsets.push_back(static_cast(region.baseOffset)); + lengths.push_back(static_cast(region.length)); + } + return {ExperimentalBufferDescriptorsOp::create(builder, tensorType, offsets, + lengths, memType), tensorType}; } +SmallVector> +createAliasingMatrix(ArrayRef regions) { + SmallVector> matrix; + size_t numRegions = regions.size(); + matrix.resize(numRegions); + for (size_t i = 0; i < numRegions; ++i) + matrix[i].assign(numRegions, /*Value=*/0); + + for (size_t i = 0; i < numRegions; ++i) { + uint64_t startI = regions[i].baseOffset; + uint64_t endI = startI + regions[i].length; + if (regions[i].length == 0) + 
continue; + // Include self-aliasing + for (size_t j = i; j < numRegions; ++j) { + uint64_t startJ = regions[j].baseOffset; + uint64_t endJ = startJ + regions[j].length; + if (regions[j].length == 0) + continue; + bool alias = (startI < endJ) && (startJ < endI); + if (alias) { + matrix[i][j] = 1; + matrix[j][i] = 1; + } + } + } + return matrix; +} + Value createInitializedScratchMemory(ImplicitLocOpBuilder &b, TypedValue tensor) { Type elType = tensor.getType().getElementType(); @@ -97,6 +133,30 @@ Value createZeroInitStateTensor(ImplicitLocOpBuilder &b, int m, int n, return createInitializedScratchMemory(b, tensor); } +TypedValue +createAliasMatrixTensor(ImplicitLocOpBuilder &b, + ArrayRef> matrix, Region *region) { + size_t rows = matrix.size(); + if (rows == 0) + return {}; + size_t cols = matrix.front().size(); + for (const auto &row : matrix) + assert(row.size() == cols && "Expected square alias matrix"); + + auto type = getIntTensorType( + region, {static_cast(rows), static_cast(cols)}, + /*bitWidth=*/1); + SmallVector values; + values.reserve(rows * cols); + for (const auto &row : matrix) + for (uint8_t v : row) + values.emplace_back(/*numBits=*/1, v); + + auto denseAttr = DenseElementsAttr::get(type, values); + Value constValue = arith::ConstantOp::create(b, b.getLoc(), type, denseAttr); + return cast>(constValue); +} + bool hasCpAsync(ModuleOp module) { bool hasCpAsync = false; module.walk([&](Operation *op) { @@ -266,8 +326,8 @@ Region *AuxDataMap::RegionToValueMap::getEnclosingParitionOrFunctionRegion( if (isa(region->getParentOp())) { ModuleOp module = op->getParentOfType(); assert(getEntryPoint(module) == region->getParentOp() && - "For now we support" - " only one function in the module"); + "Concurrency sanitizer supports only one instrumented " + "function in the module"); return region; } region = region->getParentRegion(); @@ -277,9 +337,9 @@ Region *AuxDataMap::RegionToValueMap::getEnclosingParitionOrFunctionRegion( } void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { - SmallVector, 2> bufValues(numMemTypes); - SmallVector barrierValues; - getBuffersAndBarriers(module, bufValues, barrierValues); + SmallVector, numMemTypes> bufRegions(numMemTypes); + SmallVector barrierRegions; + getBuffersAndBarriers(module, bufRegions, barrierRegions); FuncOp entryPoint = getEntryPoint(module); assert(entryPoint); @@ -290,21 +350,37 @@ void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { int iMemType = (int)memType; - if (bufValues[iMemType].empty()) { + if (bufRegions[iMemType].empty()) { continue; } buffers[iMemType].insert( entryRegion, - {createBufferPointersTensor(b, memType, bufValues[iMemType])}); - // Buffer pointers are rematerialized in the warp specialize region, + {createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])}); + // Buffer descriptors are rematerialized in the warp specialize region, // not passed as an argument. 
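+ // (The alias matrix constant created below is rematerialized the same way.) + // The matrix marks every pair of regions whose half-open byte ranges + // [baseOffset, baseOffset + length) intersect; with illustrative values, + // {base 0, len 4096} and {base 2048, len 4096} alias because 0 < 6144 && 2048 < 4096, + // while zero-length padding regions never alias anything. Each i64 descriptor + // element appears to pack the 32-bit offset in the low half and the length in + // the high half, as the constants checked in tritoninstrument_to_llvm.mlir suggest.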
createInWarpSpecialize( entryPoint, buffers[iMemType], [&](ImplicitLocOpBuilder &b) { return ValueType{ - createBufferPointersTensor(b, memType, bufValues[iMemType])}; + createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])}; }); - int numBufs = bufValues[iMemType].size(); + int numBufs = bufRegions[iMemType].size(); + + auto aliasMatrixData = createAliasingMatrix(bufRegions[iMemType]); + if (!aliasMatrixData.empty()) { + auto aliasTensor = + createAliasMatrixTensor(b, aliasMatrixData, entryRegion); + aliasMatrices[iMemType].insert(entryRegion, + {aliasTensor, aliasTensor.getType()}); + createInWarpSpecialize( + entryPoint, aliasMatrices[iMemType], + [aliasMatrixData](ImplicitLocOpBuilder &nestedBuilder) { + Region *region = nestedBuilder.getInsertionBlock()->getParent(); + auto tensor = + createAliasMatrixTensor(nestedBuilder, aliasMatrixData, region); + return ValueType{tensor, tensor.getType()}; + }); + } writeVisibility[iMemType].insert( entryRegion, {createZeroInitStateTensor(b, numBufs, 0, 64), @@ -319,18 +395,18 @@ void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { readVisibility[iMemType]); } - if (!barrierValues.empty()) { + if (!barrierRegions.empty()) { // Barriers allocations are in shared memory - barriers.insert(entryRegion, {createBufferPointersTensor( - b, MemType::SHARED_MEM, barrierValues)}); + barriers.insert(entryRegion, {createBufferDescriptorsTensor( + b, MemType::SHARED_MEM, barrierRegions)}); // Barriers allocations are rematerialized in the warp specialize region, // not passed as an argument. createInWarpSpecialize(entryPoint, barriers, [&](ImplicitLocOpBuilder &b) { - return ValueType{ - createBufferPointersTensor(b, MemType::SHARED_MEM, barrierValues)}; + return ValueType{createBufferDescriptorsTensor(b, MemType::SHARED_MEM, + barrierRegions)}; }); - int numBarriers = barrierValues.size(); + int numBarriers = barrierRegions.size(); barrierStates.insert(entryRegion, {createZeroInitStateTensor(b, numBarriers, 0, 32), getIntTensorType(entryRegion, {numBarriers}, 32)}); @@ -347,8 +423,8 @@ void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { int iMemType = (int)memType; // Create state tensors: - int numBufs = bufValues[iMemType].size(); - int numBarriers = barrierValues.size(); + int numBufs = bufRegions[iMemType].size(); + int numBarriers = barrierRegions.size(); if (numBufs > 0) { writeTracking[iMemType].insert( entryRegion, @@ -373,7 +449,7 @@ void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { passToWarpSpecialize(entryPoint, lock.at(entryRegion), lock); auto createCommitTensor = [&](CommitKind::Kind commitKind) { - int numBufs = bufValues[(int)MemType::SHARED_MEM].size(); + int numBufs = bufRegions[(int)MemType::SHARED_MEM].size(); if (numBufs == 0) return; // NUM_THREADS instead of THREADS_BITMASK_SIZE as commit-count tracking @@ -402,8 +478,8 @@ void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { } void AuxDataMap::getBuffersAndBarriers( - ModuleOp module, SmallVector, 2> &bufValues, - SmallVector &barrierValues) { + ModuleOp module, SmallVector, 2> &bufRegions, + SmallVector &barrierRegions) { // Collect shared memory buffers allocated in the module std::unique_ptr solver = createDataFlowSolver(); triton::BufferRegionAnalysis *analysis = @@ -412,30 +488,26 @@ void AuxDataMap::getBuffersAndBarriers( return; analysis->calculateUsedBufferRegions(module); - bufValues[(int)MemType::SHARED_MEM] = 
llvm::to_vector(llvm::map_range( - analysis->getAllUsedBufferRegions( - BufferRegionAnalysis::RegionType::SHARED_MEMORY), - [](const BufferRegion ®ion) { return region.baseOffset; })); - bufValues[(int)MemType::TENSOR_MEM] = llvm::to_vector(llvm::map_range( - analysis->getAllUsedBufferRegions( - BufferRegionAnalysis::RegionType::TENSOR_MEMORY), - [](const BufferRegion ®ion) { return region.baseOffset; })); - barrierValues = llvm::to_vector(llvm::map_range( - analysis->getAllUsedBufferRegions( - BufferRegionAnalysis::RegionType::BARRIER), - [](const BufferRegion ®ion) { return region.baseOffset; })); - - if (!barrierValues.empty()) { - barrierValues.resize(llvm::NextPowerOf2(barrierValues.size() - 1), 0); + bufRegions[(int)MemType::SHARED_MEM] = analysis->getAllUsedBufferRegions( + BufferRegionAnalysis::RegionType::SHARED_MEMORY); + bufRegions[(int)MemType::TENSOR_MEM] = analysis->getAllUsedBufferRegions( + BufferRegionAnalysis::RegionType::TENSOR_MEMORY); + barrierRegions = analysis->getAllUsedBufferRegions( + BufferRegionAnalysis::RegionType::BARRIER); + + if (!barrierRegions.empty()) { + barrierRegions.resize(llvm::NextPowerOf2(barrierRegions.size() - 1), + BufferRegion{0, 0}); } for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { int iMemType = (int)memType; - if (bufValues[iMemType].empty()) { + if (bufRegions[iMemType].empty()) { continue; } - bufValues[iMemType].resize( - llvm::NextPowerOf2(bufValues[iMemType].size() - 1), 0); + bufRegions[iMemType].resize( + llvm::NextPowerOf2(bufRegions[iMemType].size() - 1), + BufferRegion{0, 0}); } } diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index 14c9aff29909..ab4e2ea575ee 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -137,6 +137,18 @@ int getActiveMask(Operation *op) { return activeMask; } +uint32_t getMemDescLength(Value buf) { + auto memDescType = cast(buf.getType()); + if (isa(memDescType.getEncoding())) { + unsigned elSize = memDescType.getElementType().getIntOrFloatBitWidth() / 8; + return static_cast(product(memDescType.getShape()) * elSize); + } + if (isa(memDescType.getMemorySpace())) { + return ttng::getTmemAllocSizes(memDescType).numCols; + } + llvm_unreachable("Unsupported memory space for memdesc"); +} + } // namespace class ConcurrencySanitizerPass @@ -264,6 +276,11 @@ class ConcurrencySanitizerPass enum RW { Read, Write } rw; Value buf; std::string operandName = ""; + uint32_t length = 0; + + Effects(RW rw, Value buf, std::string operandName = "") + : rw(rw), buf(buf), operandName(operandName), + length(getMemDescLength(buf)) {} }; struct BarrierInfo { Value barrier; @@ -309,38 +326,45 @@ class ConcurrencySanitizerPass if (effect.rw == MemEffectsOpInfo::Effects::Read) { // For op that is reading, we only need to check if anything else // is writing to the same buffer. 
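+ // The checks below also receive the access length, so the runtime compares + // whole [offset, offset + length) descriptors rather than raw base pointers; + // expandAliases then widens a descriptor match to every region the alias + // matrix marks as overlapping, so an access through an overlapping sub-buffer + // is flagged as well.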
- addWriteChecks(b, funcBuilder, op, buf, pred, memType, thread, - effect.operandName); + addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, + thread, effect.operandName); if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { - funcBuilder.createSetReadVisibilityCall( - b, buf, getThreadPeersMask(thread), pred, memType, op); + funcBuilder.createSetReadVisibilityCall(b, buf, effect.length, + getThreadPeersMask(thread), + pred, memType, op); } if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::CommitCount) { assert(memType == MemType::SHARED_MEM); - funcBuilder.createStageAccessForCommitCall( - b, buf, baseThread, pred, memType, opInfo->commitKind, op); + funcBuilder.createStageAccessForCommitCall(b, buf, effect.length, + baseThread, pred, memType, + opInfo->commitKind, op); } } if (effect.rw == MemEffectsOpInfo::Effects::Write) { // Op is writing to the buffer, we need to check if anything else // is reading or writing to the same buffer. - addWriteChecks(b, funcBuilder, op, buf, pred, memType, thread, - effect.operandName); - addReadChecks(b, funcBuilder, op, buf, pred, memType, thread, - effect.operandName); + addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, + thread, effect.operandName); + addReadChecks(b, funcBuilder, op, buf, effect.length, pred, memType, + thread, effect.operandName); if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { - funcBuilder.createSetWriteVisibilityCall( - b, buf, getThreadPeersMask(thread), pred, memType, op); - funcBuilder.createClearWriteTrackingCall(b, buf, pred, memType, op); - funcBuilder.createClearReadVisibilityCall(b, buf, pred, memType, op); - funcBuilder.createClearReadTrackingCall(b, buf, pred, memType, op); + funcBuilder.createSetWriteVisibilityCall(b, buf, effect.length, + getThreadPeersMask(thread), + pred, memType, op); + funcBuilder.createClearWriteTrackingCall(b, buf, effect.length, pred, + memType, op); + funcBuilder.createClearReadVisibilityCall(b, buf, effect.length, pred, + memType, op); + funcBuilder.createClearReadTrackingCall(b, buf, effect.length, pred, + memType, op); } if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::CommitCount) { assert(memType == MemType::SHARED_MEM); - funcBuilder.createStageAccessForCommitCall( - b, buf, baseThread, pred, memType, opInfo->commitKind, op); + funcBuilder.createStageAccessForCommitCall(b, buf, effect.length, + baseThread, pred, memType, + opInfo->commitKind, op); } } } @@ -372,31 +396,32 @@ class ConcurrencySanitizerPass void addWriteChecks(ImplicitLocOpBuilder &b, tti::FunctionBuilder &funcBuilder, Operation *op, - Value buf, Value pred, MemType memType, int thread, - const std::string &operandName) { - funcBuilder.createVerifyWriteVisibilityCall(b, buf, thread, operandName, - pred, memType, op); + Value buf, uint32_t length, Value pred, MemType memType, + int thread, const std::string &operandName) { + funcBuilder.createVerifyWriteVisibilityCall(b, buf, length, thread, + operandName, pred, memType, op); // commit-num-based synchronization is only supported for shared memory if (memType == MemType::SHARED_MEM) { funcBuilder.createCheckOutstandingCommitsCall( - b, buf, getBaseThread(thread), "async_copy_global_to_shared", pred, - memType, CommitKind::AsyncCp, op); + b, buf, length, getBaseThread(thread), "async_copy_global_to_shared", + pred, memType, CommitKind::AsyncCp, op); } } void addReadChecks(ImplicitLocOpBuilder &b, tti::FunctionBuilder &funcBuilder, - Operation *op, Value buf, Value pred, 
MemType memType, - int thread, const std::string &operandName) { - funcBuilder.createVerifyReadVisibilityCall(b, buf, thread, operandName, - pred, memType, op); + Operation *op, Value buf, uint32_t length, Value pred, + MemType memType, int thread, + const std::string &operandName) { + funcBuilder.createVerifyReadVisibilityCall(b, buf, length, thread, + operandName, pred, memType, op); // commit-num-based synchronization is only supported for shared memory if (memType == MemType::SHARED_MEM) { funcBuilder.createCheckOutstandingCommitsCall( - b, buf, getBaseThread(thread), "warpgroup_mma operand read", pred, - memType, CommitKind::Wgmma, op); + b, buf, length, getBaseThread(thread), "warpgroup_mma operand read", + pred, memType, CommitKind::Wgmma, op); funcBuilder.createCheckOutstandingCommitsCall( - b, buf, getBaseThread(thread), "async_copy_shared_to_global", pred, - memType, CommitKind::TmaStore, op); + b, buf, length, getBaseThread(thread), "async_copy_shared_to_global", + pred, memType, CommitKind::TmaStore, op); } } @@ -407,97 +432,93 @@ class ConcurrencySanitizerPass info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; info->pred = copyOp.getPred(); info->barriers.push_back({copyOp.getBarrier(), nullptr, 1}); - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/copyOp.getResult()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + copyOp.getResult()); } if (auto storeOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; info->commitKind = CommitKind::TmaStore; info->implicitCommit = true; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/storeOp.getSrc()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + storeOp.getSrc()); } if (auto gatherOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; info->pred = gatherOp.getPred(); info->barriers.push_back({gatherOp.getBarrier(), nullptr, 1}); - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/gatherOp.getResult()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + gatherOp.getResult()); } if (auto scatterOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::None; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/scatterOp.getSrc()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + scatterOp.getSrc()); } if (auto copyOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; info->commitKind = CommitKind::AsyncCp; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/copyOp.getResult()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + copyOp.getResult()); } if (auto loadOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/loadOp.getSrc()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + loadOp.getSrc()); } if (auto storeOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/storeOp.getDst()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + 
storeOp.getDst()); } if (auto allocOp = dyn_cast(op)) { if (allocOp.getSrc()) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->operandEffects.push_back( - {/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/allocOp.getResult()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + allocOp.getResult()); } } if (auto loadOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/loadOp.getSrc()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + loadOp.getSrc()); } if (auto storeOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/storeOp.getDst()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + storeOp.getDst()); } if (auto allocOp = dyn_cast(op)) { if (allocOp.getSrc()) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->operandEffects.push_back( - {/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/allocOp.getResult()}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + allocOp.getResult()); } } - if (auto mmav5Op = dyn_cast(op)) { + if (auto mmav5Op = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; - info->pred = mmav5Op.getPred(); + info->pred = mmav5Op.getPredicate(); for (auto [barrier, barrierPred] : - llvm::zip(mmav5Op.getBarriers(), mmav5Op.getBarrierPreds())) { + llvm::zip(mmav5Op.getCompletionBarriers(), + mmav5Op.getCompletionBarrierPreds())) { info->barriers.push_back({barrier, barrierPred, 1}); } - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/mmav5Op.getA(), - /*.operandName =*/"A"}); - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/mmav5Op.getB(), - /*.operandName =*/"B"}); - info->operandEffects.push_back({/*.rw =*/MemEffectsOpInfo::Effects::Write, - /*.buf =*/mmav5Op.getAccumulator(), - /*.operandName =*/"Acc"}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + mmav5Op.getA(), "A"); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + mmav5Op.getB(), "B"); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + mmav5Op.getAccumulator(), "Acc"); } if (auto commitOp = dyn_cast(op)) { info.emplace(); @@ -521,17 +542,13 @@ class ConcurrencySanitizerPass info->barriers = {}; if (isa( wgmmaOp.getA().getType().getEncoding())) { - info->operandEffects.emplace_back(MemEffectsOpInfo::Effects{ - /*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/wgmmaOp.getA(), - /*.operandName =*/"A"}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + wgmmaOp.getA(), "A"); } if (isa( wgmmaOp.getB().getType().getEncoding())) { - info->operandEffects.emplace_back(MemEffectsOpInfo::Effects{ - /*.rw =*/MemEffectsOpInfo::Effects::Read, - /*.buf =*/wgmmaOp.getB(), - /*.operandName =*/"B"}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + wgmmaOp.getB(), "B"); } } } diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index e104967cd753..de8f833d9c6b 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -529,6 +529,11 @@ void TCGen5MMAOp::setUseAccumulator(Value flag) { 
getUseDMutable().assign(flag); } +ValueRange TCGen5MMAOp::getCompletionBarriers() { return getBarriers(); } +ValueRange TCGen5MMAOp::getCompletionBarrierPreds() { + return getBarrierPreds(); +} + void TCGen5MMAOp::addCompletionBarrier(Value barrier, Value pred) { getBarrierPredsMutable().append(pred); getBarriersMutable().append(barrier); @@ -654,6 +659,11 @@ void TCGen5MMAScaledOp::setUseAccumulator(Value flag) { getUseDMutable().assign(flag); } +ValueRange TCGen5MMAScaledOp::getCompletionBarriers() { return getBarriers(); } +ValueRange TCGen5MMAScaledOp::getCompletionBarrierPreds() { + return getBarrierPreds(); +} + void TCGen5MMAScaledOp::addCompletionBarrier(Value barrier, Value pred) { getBarrierPredsMutable().append(pred); getBarriersMutable().append(barrier); diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 7d91ba1679bc..cdd6d567e6fa 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -1711,3 +1711,174 @@ def kernel(): ], [4], [32]) kernel[(1, )](num_warps=4) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper") +@pytest.mark.parametrize("MISSING_BAR", [True, False]) +@pytest.mark.parametrize("OVERLAP", [True, False]) +def test_aliasing_shared_visibility_outstanding_write(MISSING_BAR, OVERLAP, device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_aliasing_shared_visibility_outstanding_write, + (MISSING_BAR, OVERLAP, device, False, monkeypatch)) + if MISSING_BAR and OVERLAP: + assert "device-side assert" in str(result.exc) + assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + @gluon.jit + def writer(alias0: ttgl.constexpr, bar: ttgl.constexpr, OVERLAP: ttgl.constexpr, blocked_layout: ttgl.constexpr): + SIZE_N: ttgl.constexpr = XBLOCK * 2 if OVERLAP else XBLOCK + vals = ttgl.full([XBLOCK, SIZE_N], 42.0, ttgl.float16, blocked_layout) + alias0.store(vals) + mbarrier.arrive(bar.index(0), count=1) + + @gluon.jit + def reader(alias1: ttgl.constexpr, dummy: ttgl.constexpr, bar: ttgl.constexpr, MISSING_BAR: ttgl.constexpr, + blocked_layout: ttgl.constexpr): + if not MISSING_BAR: + mbarrier.wait(bar.index(0), phase=0) + val = alias1.load(blocked_layout) + dummy.store(val) # keep the load alive + + @gluon.jit + def kernel(MISSING_BAR: ttgl.constexpr, OVERLAP: ttgl.constexpr): + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[0, 1]) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1]) + smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK * 2], smem_layout) + smem2 = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout()) + mbarrier.init(bar.index(0), count=1) + alias0 = smem if OVERLAP else smem.slice(0, XBLOCK, dim=1) + alias1 = smem.slice(XBLOCK, XBLOCK, dim=1) + + ttgl.warp_specialize([(writer, (alias0, bar, OVERLAP, blocked_layout)), + (reader, (alias1, smem2, bar, 
MISSING_BAR, blocked_layout))], [4], [32]) + + kernel[(1, )](MISSING_BAR=MISSING_BAR, OVERLAP=OVERLAP, num_warps=4) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer") +@pytest.mark.parametrize("FAILURE", [True, False]) +def test_aliasing_tensor_visibility_outstanding_read(FAILURE, device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_aliasing_tensor_visibility_outstanding_read, (FAILURE, device, False, monkeypatch)) + if FAILURE: + assert "device-side assert" in str(result.exc) + # Whether outstanding reads or writes are reported depends on the timing of the operations. + assert "Buffer being accessed has outstanding" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + @gluon.jit + def reader(alias0: ttgl.constexpr, smem: ttgl.constexpr, bar: ttgl.constexpr, blocked_layout: ttgl.constexpr): + val = alias0.load(blocked_layout) + smem.store(val) # keep the load alive + mbarrier.arrive(bar.index(0), count=1) + + @gluon.jit + def writer(alias1: ttgl.constexpr, bar: ttgl.constexpr, FAILURE: ttgl.constexpr, blocked_layout: ttgl.constexpr): + if not FAILURE: + mbarrier.wait(bar.index(0), phase=0) + alias1.store(ttgl.zeros([XBLOCK, XBLOCK], ttgl.float32, blocked_layout)) + + @gluon.jit + def kernel(FAILURE: ttgl.constexpr): + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[0, 1]) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1]) + smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout) + tmem_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK * 2], col_stride=1) + tmem = blackwell.allocate_tensor_memory(ttgl.float32, [XBLOCK, XBLOCK * 2], tmem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout()) + mbarrier.init(bar.index(0), count=1) + alias0 = tmem.slice(0, XBLOCK) + alias1 = tmem.slice(XBLOCK // 2, XBLOCK) + + ttgl.warp_specialize([(reader, (alias0, smem, bar, blocked_layout)), + (writer, (alias1, bar, FAILURE, blocked_layout))], [4], [32]) + + kernel[(1, )](FAILURE=FAILURE, num_warps=4) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper") +@pytest.mark.parametrize("MISSING_WAIT", [True, False]) +@pytest.mark.parametrize("OVERLAP", [True, False]) +def test_aliasing_commit_tracking(MISSING_WAIT, OVERLAP, device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_aliasing_commit_tracking, (MISSING_WAIT, OVERLAP, device, False, monkeypatch)) + if MISSING_WAIT and OVERLAP: + assert "device-side assert" in str(result.exc) + assert "Accessing buffer with pending access. 
Pending access type: async_copy_global_to_shared" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + @gluon.jit + def producer(input, alias0, bar, MISSING_WAIT: ttgl.constexpr, OVERLAP: ttgl.constexpr, + blocked_layout: ttgl.constexpr): + SIZE_N: ttgl.constexpr = XBLOCK * 2 if OVERLAP else XBLOCK + offs_m = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(dim=1, parent=blocked_layout))[:, None] + offs_n = ttgl.arange(0, SIZE_N, layout=ttgl.SliceLayout(dim=0, parent=blocked_layout))[None, :] + offs = offs_m * XBLOCK + offs_n + ampere.async_copy.async_copy_global_to_shared(alias0, input + offs) + ampere.async_copy.commit_group() + if not MISSING_WAIT: + ampere.async_copy.wait_group(0) + mbarrier.arrive(bar.index(0), count=1) + + @gluon.jit + def consumer(alias1, bar, blocked_layout: ttgl.constexpr): + mbarrier.wait(bar.index(0), phase=0) + alias1.store(ttgl.zeros([XBLOCK, XBLOCK], ttgl.float32, blocked_layout)) + + @gluon.jit + def kernel(input, MISSING_WAIT: ttgl.constexpr, OVERLAP: ttgl.constexpr): + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[0, 1]) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1]) + smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK * 2], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout()) + mbarrier.init(bar.index(0), count=1) + + alias0 = smem if OVERLAP else smem.slice(0, XBLOCK, dim=1) + alias1 = smem.slice(XBLOCK, XBLOCK, dim=1) + + ttgl.warp_specialize([(producer, (input, alias0, bar, MISSING_WAIT, OVERLAP, blocked_layout)), + (consumer, (alias1, bar, blocked_layout))], [4], [32]) + + input = torch.randn((XBLOCK, ), device=device, dtype=torch.float32) + kernel[(1, )](input, MISSING_WAIT=MISSING_WAIT, OVERLAP=OVERLAP, num_warps=4) diff --git a/test/Conversion/tritoninstrument_to_llvm.mlir b/test/Conversion/tritoninstrument_to_llvm.mlir index f128d6859ad9..c4c98d743027 100644 --- a/test/Conversion/tritoninstrument_to_llvm.mlir +++ b/test/Conversion/tritoninstrument_to_llvm.mlir @@ -3,10 +3,12 @@ #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { -// CHECK-LABEL: @experimental_buffer_pointers_tmem -// CHECK:nvg.tensor_memory_base -tt.func private @experimental_buffer_pointers_tmem() { - tti.experimental_buffer_pointers [0, 42], tensor_mem : tensor<2xi64, #blocked> +// CHECK-LABEL: @experimental_buffer_descriptors_tmem +// CHECK: llvm.mlir.constant(4294967295 : i64) : i64 +// CHECK: llvm.mlir.constant(34359738368 : i64) : i64 +// CHECK: llvm.mlir.constant(68719476736 : i64) : i64 +tt.func private @experimental_buffer_descriptors_tmem() { + tti.experimental_buffer_descriptors [0, 42], [8, 16], tensor_mem : tensor<2xi64, #blocked> tt.return } } @@ -16,10 +18,12 @@ tt.func private @experimental_buffer_pointers_tmem() { #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, ttg.target = 
"cuda:90"} { -// CHECK-LABEL: @experimental_buffer_pointers_shared -// CHECK: llvm.ptrtoint %arg0 -tt.func private @experimental_buffer_pointers_shared() { - tti.experimental_buffer_pointers [0, 42], shared_mem : tensor<2xi64, #blocked> +// CHECK-LABEL: @experimental_buffer_descriptors_shared +// CHECK: llvm.mlir.constant(4294967295 : i64) : i64 +// CHECK: llvm.mlir.constant(17179869184 : i64) : i64 +// CHECK: llvm.mlir.constant(51539607552 : i64) : i64 +tt.func private @experimental_buffer_descriptors_shared() { + tti.experimental_buffer_descriptors [0, 42], [4, 12], shared_mem : tensor<2xi64, #blocked> tt.return } } @@ -107,3 +111,19 @@ tt.func private @experimental_lock_release( tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { +// CHECK-LABEL: @experimental_memdesc_to_i32 +// CHECK: llvm.ptrtoint %1 : !llvm.ptr<3> to i32 +tt.func private @experimental_memdesc_to_i32( + %memdesc: !ttg.memdesc<32x32xf32, #shared, #smem, mutable> +) { + tti.experimental_memdesc_to_i32 %memdesc : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + tt.return +} +} diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 80a7d1cafaa1..2ba1a9ddff7d 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -10,7 +10,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: #[[BUFS_BARS_L:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [0, 1]}> // CHECK: @single_local_alloc tt.func public @single_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64, #[[BUFS_L]]> + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64, #[[BUFS_L]]> // CHECK: %[[WRITE_VISIBILITY:.*]] = arith.constant dense<0> : tensor<1xi64, #[[BUFS_L]]> // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK: %[[READ_VISIBILITY:.*]] = arith.constant dense<0> : tensor<1x64xi64, #[[BUFS_THREADS_L]]> @@ -36,7 +36,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @two_local_alloc tt.func public @two_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0, 4096], shared_mem : tensor<2xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096], [{{.*}}], shared_mem : tensor<2xi64, // CHECK: %[[WRITE_VISIBILITY:.*]] = arith.constant dense<0> : tensor<2xi64, // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 16 : i32} : !tt.ptr // CHECK: %[[READ_VISIBILITY:.*]] = arith.constant dense<0> : tensor<2x64xi64, @@ -64,7 +64,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, 
"ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @three_local_alloc tt.func public @three_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0, 4096, 8192, 0], shared_mem : tensor<4xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<4xi64, // CHECK: %[[WRITE_VISIBILITY:.*]] = arith.constant dense<0> : tensor<4xi64, // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 32 : i32} : !tt.ptr // CHECK: %[[READ_VISIBILITY:.*]] = arith.constant dense<0> : tensor<4x64xi64, @@ -94,7 +94,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @three_sub_bufs tt.func public @three_sub_bufs() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0, 4096, 8192, 0], shared_mem : tensor<4xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<4xi64, // CHECK: %[[WRITE_VISIBILITY:.*]] = arith.constant dense<0> : tensor<4xi64, // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 32 : i32} : !tt.ptr // CHECK: %[[READ_VISIBILITY:.*]] = arith.constant dense<0> : tensor<4x64xi64, @@ -159,8 +159,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: #[[BUFS_L:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> // CHECK: @tmem_alloc tt.func public @tmem_alloc() { - // CHECK-DAG: %[[TMEM_BUFS:.*]] = tti.experimental_buffer_pointers [0], tensor_mem : tensor<1xi64, #[[BUFS_L]]> - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_pointers [4096], shared_mem : tensor<1xi64, #[[BUFS_L]]> + // CHECK-DAG: %[[TMEM_BUFS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], tensor_mem : tensor<1xi64, #[[BUFS_L]]> + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [4096], [{{.*}}], shared_mem : tensor<1xi64, #[[BUFS_L]]> %0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -177,12 +177,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_copy_global_to_local tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[WRITE_VISIBILITY:.*]] = arith.constant dense<0> : tensor<1xi64, // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK-DAG: %[[READ_VISIBILITY:.*]] = arith.constant 
dense<0> : tensor<1x64xi64, // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 512 : i32} : !tt.ptr - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_pointers [65536], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[WRITE_TRACKING:.*]] = arith.constant dense<0> : tensor<1x1xi8, // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 1 : i32} : !tt.ptr // CHECK-DAG: %[[READ_TRACKING:.*]] = arith.constant dense<0> : tensor<1x1xi64, @@ -328,12 +328,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @wait_barrier tt.func public @wait_barrier(%arg0: !tt.tensordesc>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64, #blocked> + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64, #blocked> // CHECK-DAG: %[[WRITE_VISIBILITY:.*]] = arith.constant dense<0> : tensor<1xi64, // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK-DAG: %[[READ_VISIBILITY:.*]] = arith.constant dense<0> : tensor<1x64xi64, // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 512 : i32} : !tt.ptr - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_pointers [65536], shared_mem : tensor<1xi64, #blocked> + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, #blocked> // CHECK-DAG: %[[WRITE_TRACKING:.*]] = arith.constant dense<0> : tensor<1x1xi8, // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 1 : i32} : !tt.ptr // CHECK-DAG: %[[READ_TRACKING:.*]] = arith.constant dense<0> : tensor<1x1xi64, @@ -395,13 +395,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_mma tt.func public @tcgen5_mma(%arg0: !tt.tensordesc>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_pointers [0, 32768], shared_mem : tensor<2xi64 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<2xi64 // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 16 : i32} : !tt.ptr // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 1024 : i32} : !tt.ptr - // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_pointers [0], tensor_mem : tensor<1xi64 + // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], tensor_mem : tensor<1xi64 // CHECK-DAG: %[[TM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK-DAG: %[[TM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 512 : i32} : 
!tt.ptr - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_pointers [65536], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[SM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 2 : i32} : !tt.ptr // CHECK-DAG: %[[SM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 16 : i32} : !tt.ptr @@ -409,43 +409,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK-DAG: %[[TM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A:.*]] : - // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[A_I64]], %true, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : - // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[A_I64]], %true, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : + // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %true, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i64 %[[B:.*]] : - // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[B_I64]], %true, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B:.*]] : + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 - // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i64 %[[B]] : - // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[B_I64]], %true, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] + // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B]] : + // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %true, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC:.*]] : - // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[ACC_I64]], %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC:.*]] : + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}%[[ACC_I64]], %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call 
@__triton_consan_verify_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[ACC_I64]], %true, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_write_tracking{{.*}}%[[ACC_I64]], %true, %[[TM_BUFS]], %[[TM_WRITE_TRACKING_GLOB]] - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], %true, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], %true, %[[TM_BUFS]], %[[TM_READ_TRACKING_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_clear_write_tracking{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TM_BUFS]], %[[TM_WRITE_TRACKING_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TM_BUFS]], %[[TM_READ_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR:.*]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR:.*]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_WRITE_VISIBILITY_GLOB]], %[[SM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %[[SM_READ_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_WRITE_VISIBILITY_GLOB]], %[[TM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %[[TM_READ_TRACKING_GLOB]] // CHECK: ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][], {{.*}}, {{.*}}, %[[BAR]] %c0_i32 = arith.constant 0 : i32 @@ -470,13 +470,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module 
attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_mma_lhs_in_tmem tt.func public @tcgen5_mma_lhs_in_tmem(%arg0: !tt.tensordesc>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_pointers [32768], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [32768], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 512 : i32} : !tt.ptr - // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_pointers [0, 128], tensor_mem : tensor<2xi64 + // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 128], [{{.*}}], tensor_mem : tensor<2xi64 // CHECK-DAG: %[[TM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 16 : i32} : !tt.ptr // CHECK-DAG: %[[TM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 1024 : i32} : !tt.ptr - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_pointers [65536], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[SM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 1 : i32} : !tt.ptr // CHECK-DAG: %[[SM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr @@ -484,43 +484,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK-DAG: %[[TM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 16 : i32} : !tt.ptr // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A:.*]] : - // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[A_I64]], %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : - // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[A_I64]], %true, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : + // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %true, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i64 %[[B:.*]] : - // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[B_I64]], %true, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B:.*]] : + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 - // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i64 %[[B]] : - // CHECK: 
tt.call @__triton_consan_set_read_visibility{{.*}}%[[B_I64]], %true, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] + // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B]] : + // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %true, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC:.*]] : - // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[ACC_I64]], %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC:.*]] : + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}%[[ACC_I64]], %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[ACC_I64]], %true, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_write_tracking{{.*}}%[[ACC_I64]], %true, %[[TM_BUFS]], %[[TM_WRITE_TRACKING_GLOB]] - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], %true, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] - // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i64 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], %true, %[[TM_BUFS]], %[[TM_READ_TRACKING_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_clear_write_tracking{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TM_BUFS]], %[[TM_WRITE_TRACKING_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] + // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : + // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{[^,]+}}, %true, %[[TM_BUFS]], %[[TM_READ_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR:.*]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR:.*]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_WRITE_VISIBILITY_GLOB]], %[[SM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 
%[[BAR]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %[[SM_READ_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_WRITE_VISIBILITY_GLOB]], %[[TM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 - // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BAR]] : + // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %[[TM_READ_TRACKING_GLOB]] // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state @@ -577,18 +577,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_copy_global_to_local tt.func public @async_copy_global_to_local(%ptr: tensor<128x128x!tt.ptr, #blocked>) { - // CHECK: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64 + // CHECK: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK: %[[WRITE_COMMITS:.*]] = arith.constant dense<0> : tensor<1x16xi8 // CHECK: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 16 : i32} : !tt.ptr - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A:.*]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility_nw1{{.*}}(%[[A_I64]] - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 // CHECK: tt.call @__triton_consan_check_outstanding_commits{{.*}}(%[[A_I64]], {{.*}}, %[[THREAD_BIT]], %[[BUFFERS]], %[[WRT_COMMITS_GLOB]] // CHECK: tt.call @__triton_consan_verify_read_visibility_nw1 // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: tt.call @__triton_consan_stage_access_for_commit_nw1{{.*}}(%[[A_I64]], {{.*}}, %[[THREAD_BIT]], %[[BUFFERS]], %[[WRT_COMMITS_GLOB]] // CHECK: ttg.async_copy_global_to_local %{{.*}}, %[[A]] @@ -608,7 +608,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_copy_global_to_local_with_barriers tt.func public @async_copy_global_to_local_with_barriers(%ptr: tensor<128x128x!tt.ptr, #blocked>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64 + // CHECK-DAG: 
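+    // Write path for cp.async: write visibility is verified first, then the
+    // buffer is checked for outstanding commits and staged for commit before
+    // the copy itself is issued.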
%[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 8 : i32} : !tt.ptr // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 512 : i32} : !tt.ptr // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 1 : i32} : !tt.ptr @@ -618,15 +618,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_init_barrier_state - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A:.*]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}(%[[A_I64]] - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 // CHECK: tt.call @__triton_consan_check_outstanding_commits{{.*}}(%[[A_I64]], {{.*}}, %[[THREAD_BIT]], %[[BUFFERS]], %[[WRT_COMMITS_GLOB]] - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}(%[[A_I64]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i64 %[[A]] : + // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: tt.call @__triton_consan_stage_access_for_commit{{.*}}(%[[A_I64]], {{.*}}, %[[THREAD_BIT]], %[[BUFFERS]], %[[WRT_COMMITS_GLOB]] // CHECK: ttg.async_copy_global_to_local %{{.*}}, %[[A]] %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -700,7 +700,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @warp_group_dot tt.func public @warp_group_dot(%acc: tensor<128x128xf16, #mma>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_pointers [0, 32768], shared_mem : tensor<2xi64 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<2xi64 // CHECK-DAG: %[[SM_WGMMA_READS:.*]] = arith.constant dense<0> : tensor<2x16xi8 // CHECK-DAG: %[[SM_WGMMA_WRITES_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 32 : i32} : !tt.ptr // CHECK: tt.call @__triton_consan_verify_write_visibility @@ -729,7 +729,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @warp_group_dot_sync tt.func public @warp_group_dot_sync(%acc: tensor<128x128xf16, #mma>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_pointers [0, 32768], shared_mem : tensor<2xi64 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<2xi64 // CHECK-DAG: %[[SM_WGMMA_READS:.*]] = arith.constant dense<0> : tensor<2x16xi8 // CHECK-DAG: 
%[[SM_WGMMA_WRITES_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 1 : i32, nbytes = 32 : i32} : !tt.ptr @@ -775,9 +775,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK-LABEL: @local_alloc_with_src tt.func public @local_alloc_with_src(%acc: tensor<128x128xf16, #mma>) { // CHECK: %[[BUF:.*]] = ttg.local_alloc - // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BUF:.*]] : + // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BUF:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}(%[[BUF_I64]] - // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BUF:.*]] : + // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BUF:.*]] : // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}(%[[BUF_I64]] %buf = ttg.local_alloc %acc {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -798,9 +798,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK-LABEL: @tmem_alloc_with_src tt.func public @tmem_alloc_with_src(%acc: tensor<128x128xf16, #blocked>) { // CHECK: %[[BUF:.*]] = ttng.tmem_alloc - // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BUF:.*]] : + // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BUF:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}(%[[BUF_I64]] - // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i64 %[[BUF:.*]] : + // CHECK: %[[BUF_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BUF:.*]] : // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}(%[[BUF_I64]] %buf = ttng.tmem_alloc %acc { tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32 } : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -818,7 +818,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @local_load_barriers tt.func public @local_load_barriers() { - // CHECK: tti.experimental_buffer_pointers + // CHECK: tti.experimental_buffer_descriptors %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -840,7 +840,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @local_load_barriers tt.func public @local_load_barriers_cp_async(%ptr: tensor<128x128x!tt.ptr, #blocked>) { - // CHECK: tti.experimental_buffer_pointers + // CHECK: tti.experimental_buffer_descriptors %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %shmem = 
ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -870,7 +870,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @local_store_barriers_cp_async_wgmma tt.func public @local_store_barriers_cp_async_wgmma(%ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { - // CHECK: tti.experimental_buffer_pointers + // CHECK: tti.experimental_buffer_descriptors %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %shmem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -904,8 +904,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_allocation tt.func public @ws_allocation(%arg0: !tt.tensordesc>) { - // CHECK-DAG: tti.experimental_buffer_pointers [65536], shared_mem : tensor<1xi64, - // CHECK-DAG: tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -927,8 +927,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar } partition0(%arg1: !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) num_warps(4) { // CHECK: partition0 - // CHECK-DAG: tti.experimental_buffer_pointers [65536], shared_mem : tensor<1xi64, - // CHECK-DAG: tti.experimental_buffer_pointers [0], shared_mem : tensor<1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK: tti.experimental_lock_acquire // CHECK: tt.call @__triton_consan_verify_write_visibility // CHECK: tt.call @__triton_consan_set_read_visibility @@ -949,8 +949,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_buf_ptrs_default tt.func public @ws_buf_ptrs_default(%arg0: !tt.tensordesc>) { - // CHECK-DAG: tti.experimental_buffer_pointers [0, 32768, 65536, 0], shared_mem - // CHECK-DAG: 
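+      // Each warp_specialize region materializes its own descriptor tensors,
+      // which is why the CHECK-DAGs above repeat inside partition0.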
@@ -949,8 +949,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} {
   // CHECK-LABEL: @ws_buf_ptrs_default
   tt.func public @ws_buf_ptrs_default(%arg0: !tt.tensordesc>) {
-    // CHECK-DAG: tti.experimental_buffer_pointers [0, 32768, 65536, 0], shared_mem
-    // CHECK-DAG: tti.experimental_buffer_pointers [65536], shared_mem
+    // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32768, 65536, 0], [{{.*}}], shared_mem
+    // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem
     %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x128x128xf16, #shared, #smem, mutable>
     %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
@@ -985,8 +985,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} {
   // CHECK-LABEL: @ws_buf_ptrs_partition0
  tt.func public @ws_buf_ptrs_partition0(%arg0: !tt.tensordesc>) {
-    // CHECK-DAG: tti.experimental_buffer_pointers [0, 32768, 65536, 0], shared_mem
-    // CHECK-DAG: tti.experimental_buffer_pointers [65536], shared_mem
+    // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32768, 65536, 0], [{{.*}}], shared_mem
+    // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem
     %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x128x128xf16, #shared, #smem, mutable>
     %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
@@ -1051,3 +1051,149 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
     tt.return
   }
 }
+
+// -----
+
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
+  // CHECK-LABEL: @alias_matrix_shared
+  tt.func public @alias_matrix_shared() {
+    // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<2xi64
+    // CHECK-DAG: arith.constant dense<true> : tensor<2x2xi1
+    %buf0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %buf1 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttg.local_load %buf0 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+    ttg.local_load %buf1 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+    tt.return
+  }
+}
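+// Overlap rule exercised by the checks above (a sketch of the expected
+// behavior, not the pass's exact code): two descriptors (o0, l0) and (o1, l1)
+// alias iff their byte intervals intersect, i.e. o0 < o1 + l1 && o1 < o0 + l0.
+// Here (0, 128) vs (16, 128): 0 < 144 and 16 < 128, so the whole 2x2 alias
+// matrix is true.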
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
+  // CHECK-LABEL: @alias_matrix_shared_indexed
+  tt.func public @alias_matrix_shared_indexed() {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [128, 128], shared_mem : tensor<2xi64
+    // CHECK-DAG: arith.constant dense<{{\[\[true, false\], \[false, true\]\]}}> : tensor<2x2xi1
+    %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32xf32, #shared, #smem, mutable>
+    %buf0 = ttg.memdesc_index %smem[%c0_i32] : !ttg.memdesc<2x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %buf1 = ttg.memdesc_index %smem[%c1_i32] : !ttg.memdesc<2x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttg.local_load %buf0 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+    ttg.local_load %buf1 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+    tt.return
+  }
+}
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
+  // CHECK-LABEL: @alias_matrix_shared_subslice
+  tt.func public @alias_matrix_shared_subslice() {
+    // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [256, 128], shared_mem : tensor<2xi64
+    // CHECK-DAG: arith.constant dense<true> : tensor<2x2xi1
+    %buf0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64xf32, #shared, #smem, mutable>
+    %buf1 = ttg.memdesc_subslice %buf0 [32] : !ttg.memdesc<64xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttg.local_load %buf0 : !ttg.memdesc<64xf32, #shared, #smem, mutable> -> tensor<64xf32>
+    ttg.local_load %buf1 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+    tt.return
+  }
+}
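+// Contrast with @alias_matrix_shared_indexed: memdesc_index carves disjoint
+// intervals (0, 128) and (128, 128) out of one alloc, leaving only the
+// diagonal true, while memdesc_subslice stays inside its parent interval
+// ((0, 256) vs (128, 128)), so here every entry is true.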
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding
+#tmem2 = #ttng.tensor_memory_encoding
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
+  // CHECK-LABEL: @alias_matrix_tensor
+  tt.func public @alias_matrix_tensor() {
+    // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32, 64, 0], [64, 32, 64, 0], tensor_mem : tensor<4xi64
+    // CHECK-DAG: arith.constant dense<{{\[\[true, true, false, false\], \[true, true, false, false\], \[false, false, true, false\], \[false, false, false, false\]\]}}> : tensor<4x4xi1
+    %buf0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %buf1 = ttng.tmem_alloc {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %buf3 = ttng.tmem_subslice %buf0 {N = 32 : i32} : !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x32xf32, #tmem2, #ttng.tensor_memory, mutable>
+    ttng.tmem_load %buf0 : !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<64x64xf32>
+    ttng.tmem_load %buf1 : !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<64x64xf32>
+    ttng.tmem_load %buf3 : !ttg.memdesc<64x32xf32, #tmem2, #ttng.tensor_memory, mutable> -> tensor<64x32xf32>
+    tt.return
+  }
+}
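+// Reading the 4x4 matrix above against the descriptors [0, 32, 64, 0] /
+// [64, 32, 64, 0] (tensor-memory columns): buf0 (0, 64) and its subslice
+// (32, 32) intersect, buf1 (64, 64) touches neither, and the zero-length
+// fourth slot aliases nothing, not even itself, hence its all-false row.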
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32} {
+  // CHECK-LABEL: @ws_alias_matrix
+  tt.func public @ws_alias_matrix() {
+    // We expect the alias matrix constant to appear once for the default region
+    // and once for partition0 when we lower warp_specialize.
+    // CHECK-DAG: arith.constant dense<true> : tensor<2x2xi1
+    %smem0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %smem1 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable>
+    %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable>
+    ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>
+
+    ttg.warp_specialize(%smem0, %smem1, %bar) attributes {actualRegisters = array, allocation.offset = 0 : i32, requestedRegisters = array, warpGroupStartIds = array}
+    default {
+      %c0 = arith.constant 0 : i32
+      ttg.local_load %smem0 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+      ttg.local_load %smem1 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+      ttg.warp_yield
+    }
+    partition0(%arg0: !ttg.memdesc<32xf32, #shared, #smem, mutable>, %arg1: !ttg.memdesc<32xf32, #shared, #smem, mutable>, %arg2: !ttg.memdesc<1xi64, #shared, #smem, mutable>) num_warps(1) {
+      // CHECK: arith.constant dense<true> : tensor<2x2xi1
+      %c0 = arith.constant 0 : i32
+      ttg.local_load %arg0 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+      ttg.local_load %arg1 : !ttg.memdesc<32xf32, #shared, #smem, mutable> -> tensor<32xf32>
+      ttg.warp_return
+    } : (!ttg.memdesc<32xf32, #shared, #smem, mutable>, !ttg.memdesc<32xf32, #shared, #smem, mutable>, !ttg.memdesc<1xi64, #shared, #smem, mutable>) -> ()
+    tt.return
+  }
+}