diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir
new file mode 100644
index 000000000000..afea228c3e67
--- /dev/null
+++ b/test/Conversion/amd/async_ops_to_llvm.mlir
@@ -0,0 +1,217 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefix=GFX942
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: async_copy
+  tt.func public @async_copy(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                             %arg1: i32 {tt.divisibility = 16 : i32},
+                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+    // We need the splat to allow the AxisAnalysis to work during lowering
+    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+    // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
+    // CHECK-COUNT-8: rocdl.global.load.lds
+    // CHECK-NOT: rocdl.global.load.lds
+    %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: async_copy_vectorized_2xf16
+  tt.func public @async_copy_vectorized_2xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                              %arg1: i32 {tt.divisibility = 16 : i32},
+                                              %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+    // We need the index calculation so AxisAnalysis sees that we can vectorize the load
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
+
+    // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds
+    // CHECK-COUNT-4: rocdl.global.load.lds
+    // CHECK-NOT: rocdl.global.load.lds
+    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  // GFX950-LABEL: async_copy_vectorized_8xf16
+  tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                              %arg1: i32 {tt.divisibility = 16 : i32},
+                                              %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+    // We need the index calculation so AxisAnalysis sees that we can vectorize the load
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
+
+    // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
+    // GFX950: rocdl.global.load.lds
+    // GFX950-NEXT: llvm.return
+
+    // GFX942 does not support vectorization > 4 bytes
+    // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
+    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: async_wait
+  tt.func public @async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                             %arg1: i32 {tt.divisibility = 16 : i32},
+                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+    // The waitcnt stores all counters in one i32; bits 15:14 and 3:0 store the vmcnt we have to wait on
+    // CHECK: rocdl.waitcnt -49168
+    // CHECK: rocdl.barrier
+    ttg.async_wait {num = 0 : i32}
+    // CHECK: rocdl.waitcnt -49167
+    // CHECK: rocdl.barrier
+    ttg.async_wait {num = 1 : i32}
+    // CHECK: rocdl.waitcnt -2
+    // CHECK: rocdl.barrier
+    ttg.async_wait {num = 62 : i32}
+    // CHECK: rocdl.waitcnt -1
+    // CHECK: rocdl.barrier
+    ttg.async_wait {num = 63 : i32}
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: async_commit_group
+  tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                     %arg1: i32 {tt.divisibility = 16 : i32},
+                                     %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+    // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT: llvm.return
+    ttg.async_commit_group
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_mask_other + tt.func public @async_copy_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>, + %arg3: i32 {tt.divisibility = 16 : i32}) { + // We need the splat to allow the AxisAnalysis to work during lowering + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %29 = arith.addi %arg3, %c31_i32 : i32 + %30 = arith.divsi %29, %c32_i32 : i32 + %31 = arith.cmpi sgt, %30, %c0_i32 : i32 + + %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %65 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked> + %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + + %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked> + %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked> + + // Each thread needs to load 4 elements and we load 1 (sizePerThread) per global.load.lds + // Note that mask/other alignment is 1 so we need 4 conditionals + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + %2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // GFX942-LABEL: async_copy_cache_mods + tt.func public @async_copy_cache_mods(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + // Each thread needs to load 1 element and we load 1 (sizePerThread) per global.load.lds + + // GFX942: llvm.getelementptr + // GFX942: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + %2 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_cg:.*]] = llvm.mlir.constant(0 : i32) : i32 + // 
GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]] + %3 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_cs:.*]] = llvm.mlir.constant(3 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cs]] + %5 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cs: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_cv:.*]] = llvm.mlir.constant(9 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]] + %6 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_wb:.*]] = llvm.mlir.constant(0 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wb]] + %7 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wb: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // GFX942: llvm.getelementptr + // GFX942: %[[aux_wt:.*]] = llvm.mlir.constant(8 : i32) : i32 + // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wt]] + %8 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wt: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 18f5cfc68abe..757802cb912a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -396,6 +396,185 @@ struct BufferLoadOpConversion } }; +struct AsyncCopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + bool supportsLoadWidth(unsigned bits, + const AMD::TargetInfo &targetInfo) const { + llvm::SmallSetVector supportedWidths; + using mlir::triton::AMD::ISAFamily; + switch (targetInfo.getISAFamily()) { + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + supportedWidths.insert(8); + supportedWidths.insert(16); + supportedWidths.insert(32); + if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) { + supportedWidths.insert(96); + supportedWidths.insert(128); + } + break; + default: + return false; + } + + return supportedWidths.contains(bits); + } + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto srcTy = op.getSrc().getType(); + auto srcEncoding = srcTy.getEncoding(); + + if (!isa(srcEncoding)) + return rewriter.notifyMatchFailure( + op, "requires Blocked or Slice encoding for src"); + if (srcTy.getShape().size() != 2) + return rewriter.notifyMatchFailure(op, "only supports 2d tensors"); + + auto dstTy = op.getResult().getType(); + auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + + Value llSrc = adaptor.getSrc(); + + auto srcElems = 
+
+    Value llDst = adaptor.getResult();
+    auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
+        loc, llDst, resElemTy, rewriter);
+
+    // We can load N elements at a time if:
+    // 1. Every group of N source pointers are contiguous. For example, if
+    //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
+    // 2. The mask (if present) has "alignment" N, meaning that each group of N
+    //    mask bits are the same. For example if N=2, the mask must be
+    //    [x, x, y, y, ...].
+    unsigned maxVec =
+        mlir::LLVM::AMD::getContiguity(op.getSrc(), axisAnalysisPass);
+    Value mask = op.getMask();
+    if (mask) {
+      maxVec = std::min(maxVec, getMaskAlignment(mask));
+    }
+
+    // global.load.lds does not support per lane offsets.
+    // We need to ensure that we write coalesced into shared memory. This means
+    // that the kLane dim needs to be contiguous based on the vectorization
+    // size.
+    auto shape = dstTy.getShape();
+    LinearLayout srcLayout =
+        triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
+    LinearLayout sharedLayout =
+        triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
+    LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
+
+    StringAttr kLane = rewriter.getStringAttr("lane");
+    for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) {
+      auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0];
+      unsigned expected = maxVec * (1 << inLane);
+      if (basis != expected) {
+        LDBG("detected uncoalesced layout from blocked to shared in async copy "
+             "for lane "
+             << 1 + inLane << "; given " << basis << " but expected "
+             << expected);
+        return rewriter.notifyMatchFailure(op,
+                                           "does not write coalesced into LDS");
+      }
+    }
+
+    // Addresses to store into, one per `vecTy`.
+    VectorType vecTy;
+    SmallVector<Value> shmemAddrs;
+    bool ok = emitTransferBetweenRegistersAndShared(
+        srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
+        [&](VectorType vecTy_, Value shmemAddr) {
+          vecTy = vecTy_;
+          shmemAddrs.push_back(shmemAddr);
+        });
+    assert(ok);
+
+    int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+    if (!supportsLoadWidth(vecBits, targetInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "Async copy does not support the required load vectorization");
+    }
+
+    int vecBytes = vecBits / 8;
+    assert(llvm::isPowerOf2_32(vecBytes));
+    Value vecBytesVal = b.i32_val(vecBytes);
+
+    Value cacheModifiers =
+        b.i32_val(mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget(
+            op.getCache(), false, targetInfo));
+
+    Value llMask = adaptor.getMask();
+    SmallVector<Value> maskElems;
+    if (llMask) {
+      maskElems = unpackLLElements(loc, llMask, rewriter);
+      assert(srcElems.size() == maskElems.size());
+    }
+
+    Value other = op.getOther();
+    SmallVector<Value> otherElems;
+    if (other) {
+      otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter);
+      assert(srcElems.size() == otherElems.size());
+    }
+
+    for (int i = 0; i < shmemAddrs.size(); i++) {
+      auto srcIdx = i * maxVec;
+      auto srcPtr = srcElems[srcIdx];
+
+      if (!mask) {
+        rewriter.create<ROCDL::GlobalLoadLDSOp>(
+            loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
+            cacheModifiers);
+        continue;
+      }
+
+      Block *currentBlock = rewriter.getInsertionBlock();
+      Block *afterLoad =
+          rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+      Block *loadBlock = rewriter.createBlock(afterLoad);
+      rewriter.setInsertionPointToEnd(currentBlock);
+      rewriter.create<LLVM::CondBrOp>(loc, maskElems[srcIdx], loadBlock,
+                                      afterLoad);
+      rewriter.setInsertionPointToStart(loadBlock);
+      rewriter.create<ROCDL::GlobalLoadLDSOp>(
+          loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
+          cacheModifiers);
+
+      rewriter.create<LLVM::BrOp>(loc, afterLoad);
+      rewriter.setInsertionPointToStart(afterLoad);
+      if (other) {
+        Value storeVal = packElementRangeIntoVector(
+            rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx);
+        llStore(rewriter, loc, shmemAddrs[i], storeVal,
+                b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache());
+      }
+    }
+
+    // Drop the result token.
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        op.getLoc(), IntegerType::get(op.getContext(), 32),
+        rewriter.getI32IntegerAttr(0));
+    rewriter.replaceOp(op, zero);
+    return success();
+  }
+};
+
 struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
                            public LoadStoreConversionBase {
   using ConvertOpToLLVMPattern<triton::StoreOp>::ConvertOpToLLVMPattern;
@@ -1459,6 +1638,76 @@ struct AtomicRMWOpConversion
     return endBlock->getArgument(0);
   }
 };
+
+struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
+  AsyncWaitOpConversion(LLVMTypeConverter &converter,
+                        const AMD::TargetInfo &targetInfo,
+                        PatternBenefit benefit)
+      : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {}
+
+  LogicalResult
+  matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    using mlir::triton::AMD::ISAFamily;
+
+    switch (targetInfo.getISAFamily()) {
+    case ISAFamily::CDNA1:
+    case ISAFamily::CDNA2:
+    case ISAFamily::CDNA3:
+      break;
+    default:
+      return rewriter.notifyMatchFailure(
+          op, "Only supported on target architecture");
+    }
+
+    auto loc = op->getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+
+    // global.load.lds uses vmcnt to synchronize
+    // The rocdl op stores all available counters in a single int32 value (v).
+    // The vmcnt (6 bits) is split into a lower part (bits 3:0) and a higher
+    // part (bits 5:4). The lower part is stored in bits 3:0 of v and the
+    // higher part in bits 15:14. We have to set all other bits in v to 1 to
+    // signal we are not interested in those.
+
+    int vmCnt = op.getNum();
+    if (vmCnt >= 64) {
+      return emitError(loc, "AsyncWait does not support values >= 64");
+    }
+
+    // Extract low and high bits and combine while setting all other bits to 1
+    unsigned lowBits = vmCnt & 0xF;
+    unsigned highBits = vmCnt >> 4 << 14;
+    unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set
+    unsigned waitValue = lowBits | highBits | otherCnts;
+
+    rewriter.create(loc, waitValue);
+
+    // Drop the result AsyncToken
+    rewriter.replaceOp(op, b.i32_val(0));
+    return success();
+  }
+
+private:
+  const AMD::TargetInfo &targetInfo;
+};
+
+struct AsyncCommitGroupOpConversion
+    : public ConvertOpToLLVMPattern<AsyncCommitGroupOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Drop the result AsyncToken
+    auto loc = op->getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    rewriter.replaceOp(op, b.i32_val(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir::triton::AMD {
@@ -1468,9 +1717,12 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                        int numWarps,
                                        ModuleAxisInfoAnalysis &axisInfoAnalysis,
                                        PatternBenefit benefit) {
-  patterns.add(
-      typeConverter, targetInfo, axisInfoAnalysis, benefit);
+  patterns
+      .add(
+          typeConverter, targetInfo, axisInfoAnalysis, benefit);
+  patterns.add<AsyncWaitOpConversion>(typeConverter, targetInfo, benefit);
+  patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
 }
 
 } // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
index d4b8d7abe01f..84c58fb824fe 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
@@ -517,9 +517,9 @@ static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) {
 // .cv: don't cache and fetch again
 // .wb: write-back, writes back data at all cache levels
 // .wt: write-through, write data directly to system memory
-int32_t
-getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier cm, bool isBufferLoad,
-                                    mlir::triton::AMD::TargetInfo &targetInfo) {
+int32_t getCtrlBitsForCacheModifierOnTarget(
+    triton::CacheModifier cm, bool isBufferLoad,
+    const mlir::triton::AMD::TargetInfo &targetInfo) {
   if (targetInfo.getGPUKind() == llvm::AMDGPU::GK_GFX942) // gfx942
     return getCtrlBitsForCacheModifierOnGFX942(cm, isBufferLoad);
   else
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
index 1dabe31db2d9..bd02cce16cc9 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
@@ -54,8 +54,9 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
 // Get flags for a predicated Load or Store
 std::pair getCacheModifierFlagsForPredicatedCall(LLVM::CallOp);
 // Get the cachepolicy value for a cache modifier
-int32_t getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool,
-                                            mlir::triton::AMD::TargetInfo &);
+int32_t
+getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool,
+                                    const mlir::triton::AMD::TargetInfo &);
 // Get cache modifier information for buffer atomics
 int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1,