Merged
Changes from 24 commits
Commits (30)
ba17a9e
Basic lowering AsyncCommitGroup and AsyncWait
AlexAUT Jan 17, 2025
2327587
WIP lowering of AsyncCopy
AlexAUT Jan 17, 2025
1d4edf6
Added layout checks for asynccopy lowering
AlexAUT Jan 24, 2025
ead4915
Support direct to lds
AlexAUT Jan 27, 2025
3141ba4
Enable non working masking
AlexAUT Jan 27, 2025
644aa1e
Add support to enable disable direct to lds with env var AMDGCN_USE_D…
AlexAUT Jan 28, 2025
7c9bab1
Fix masking and others for direct to lds
AlexAUT Jan 28, 2025
cb823d0
Fix when AsycCopy is lowered without a mask
AlexAUT Jan 28, 2025
c097616
Use ROCDL instead of intrinsics
AlexAUT Jan 28, 2025
1a9f1e0
Cleanup and simplify AsyncCopy lowering
AlexAUT Jan 28, 2025
a20b686
CacheModifiers for AsyncCopy
AlexAUT Jan 28, 2025
97d677d
Add lit test for AsyncCopy
AlexAUT Jan 28, 2025
30352ad
Split AsyncCopy Lit for gfx950
AlexAUT Jan 28, 2025
fe8619d
Add const to getCtrlBitsForCacheModifierOnTarget
AlexAUT Jan 28, 2025
7941a30
Cleanup StreamPipeliner changes
AlexAUT Jan 28, 2025
def9313
Revert stream pipeline related changes
AlexAUT Jan 28, 2025
318caa2
Add missing CDNA1 to AsyncCopy support list
AlexAUT Jan 28, 2025
6600138
Cleanup
AlexAUT Jan 28, 2025
ea02c3c
Replace macros for llvm ops with TritonLLVMOpBuilder
AlexAUT Jan 29, 2025
13419bb
Fix wrong value in supported bit width for global.to.lds
AlexAUT Jan 30, 2025
ca8b441
Addressing review comments
AlexAUT Jan 31, 2025
6aa3554
Unified async ops lit tests
AlexAUT Jan 31, 2025
04fad93
Emit correct wmcnt wait instead of waiting on all cnts
AlexAUT Jan 31, 2025
f6cbe22
Add tests for AsyncWait/AsyncCommitGroup
AlexAUT Jan 31, 2025
3d30f43
Limit AsyncWait conversion to gfx9
AlexAUT Feb 3, 2025
0c382db
Add AsyncOpy lowering lit test with masking and other values
AlexAUT Feb 3, 2025
f560aeb
Added async copy lit tests with cache modifiers
AlexAUT Feb 5, 2025
d6b0d02
Merge branch 'main' into global_to_lds_lowering
AlexAUT Feb 5, 2025
d90ffbe
Adjust to shared encoding changes
AlexAUT Feb 5, 2025
5356802
Fix a few small issues
antiagainst Feb 5, 2025
119 changes: 119 additions & 0 deletions test/Conversion/amd/async_ops_to_llvm.mlir
@@ -0,0 +1,119 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_copy
tt.func public @async_copy(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// We need the splat to allow the AxisAnalysis to work during lowering
%1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
// Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
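// (32x64 elements) / (4 warps x 64 lanes) = 8 elements per thread, so with a
// vector width of 1 x f16 we expect 8 global.load.lds instructions per thread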
// CHECK-COUNT-8: rocdl.global.load.lds
// CHECK-NOT: rocdl.global.load.lds
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_copy_vectorized_2xf16
tt.func public @async_copy_vectorized_2xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// We need the index calculation so AxisAnalysis sees that we can vectorize the load
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

// Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds
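// 8 elements per thread / 2 x f16 (32 bits) per instruction = 4 global.load.lds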
// CHECK-COUNT-4: rocdl.global.load.lds
// CHECK-NOT: rocdl.global.load.lds
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
// GFX950-LABEL: async_copy_vectorized_8xf16
tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// We need the index calculation so AxisAnalysis sees that we can vectorize the load
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

// Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
// GFX950: rocdl.global.load.lds
// GFX950-NOT: rocdl.global.load.lds

// GFX942 does not support vectorization > 4 bytes
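// 8 x f16 = 128 bits; supportsLoadWidth only accepts 96/128-bit loads on
// gfx950, so the gfx942 run must fail to legalize this op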
// expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_wait
tt.func public @async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// The expected rocdl.waitcnt values are explained in the lowering of async_wait

// CHECK: rocdl.waitcnt -49168
// CHECK: rocdl.barrier
ttg.async_wait {num = 0 : i32}
// CHECK: rocdl.waitcnt -49167
// CHECK: rocdl.barrier
ttg.async_wait {num = 1 : i32}
// CHECK: rocdl.waitcnt -2
// CHECK: rocdl.barrier
ttg.async_wait {num = 62 : i32}
// CHECK: rocdl.waitcnt -1
// CHECK: rocdl.barrier
ttg.async_wait {num = 63 : i32}
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_commit_group
tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// CHECK: llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.return
ttg.async_commit_group
tt.return
}
}
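A quick way to read the CHECK-COUNT values above: each thread owns (32*64) / (4 warps * 64 lanes) = 8 elements, and one rocdl.global.load.lds is emitted per vectorized chunk, which in these tests equals the contiguous sizePerThread. A minimal standalone sketch of that arithmetic (expectedLoadOps is a hypothetical helper, not part of this PR):

```cpp
#include <cassert>

// Expected rocdl.global.load.lds ops per thread for a 2-D blocked layout,
// assuming each op moves `contigSizePerThread` contiguous elements.
int expectedLoadOps(int rows, int cols, int contigSizePerThread,
                    int threadsPerCTA) {
  int elemsPerThread = (rows * cols) / threadsPerCTA;
  return elemsPerThread / contigSizePerThread;
}

int main() {
  const int threads = 4 * 64; // 4 warps x 64 lanes
  assert(expectedLoadOps(32, 64, /*f16 x1*/ 1, threads) == 8); // CHECK-COUNT-8
  assert(expectedLoadOps(32, 64, /*f16 x2*/ 2, threads) == 4); // CHECK-COUNT-4
  assert(expectedLoadOps(32, 64, /*f16 x8*/ 8, threads) == 1); // gfx950 only
  return 0;
}
```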
241 changes: 237 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -396,6 +396,185 @@ struct BufferLoadOpConversion
}
};

struct AsyncCopyGlobalToLocalOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>,
public LoadStoreConversionBase {
AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter,
const AMD::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertOpToLLVMPattern(converter, benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}

bool supportsLoadWidth(unsigned bits,
const AMD::TargetInfo &targetInfo) const {
llvm::SmallSetVector<unsigned, 10> supportedWidths;
using mlir::triton::AMD::ISAFamily;
switch (targetInfo.getISAFamily()) {
case ISAFamily::CDNA1:
case ISAFamily::CDNA2:
case ISAFamily::CDNA3:
supportedWidths.insert(8);
supportedWidths.insert(16);
supportedWidths.insert(32);
if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) {
supportedWidths.insert(96);
supportedWidths.insert(128);
}
break;
default:
return false;
}

return supportedWidths.contains(bits);
}

LogicalResult
matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);

auto srcTy = op.getSrc().getType();
auto srcEncoding = srcTy.getEncoding();

if (!isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding))
return rewriter.notifyMatchFailure(
op, "requires Blocked or Slice encoding for src");
if (srcTy.getShape().size() != 2)
return rewriter.notifyMatchFailure(op, "only supports 2d tensors");

auto dstTy = op.getResult().getType();
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());

Value llSrc = adaptor.getSrc();

auto srcElems = unpackLLElements(loc, llSrc, rewriter);

Value llDst = adaptor.getResult();
auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
loc, llDst, resElemTy, rewriter);

// We can load N elements at a time if:
// 1. Every group of N source pointers are contiguous. For example, if
// N=2, then the pointers should be [x, x+1, y, y+1, ...].
// 2. The mask (if present) has "alignment" N, meaning that each group of N
// mask bits are the same. For example if N=2, the mask must be
// [x, x, y, y, ...].
unsigned maxVec =
mlir::LLVM::AMD::getContiguity(op.getSrc(), axisAnalysisPass);
Value mask = op.getMask();
if (mask) {
maxVec = std::min(maxVec, getMaskAlignment(mask));
}

// global.load.lds does not support per-lane offsets, so we need to ensure
// that we write coalesced into shared memory. This means the kLane dim
// needs to be contiguous with respect to the vectorization size.
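// For example, with maxVec = 2 the lane bases of the register/lane -> offset
// mapping must be 2, 4, 8, ..., i.e. consecutive lanes have to write
// consecutive maxVec-sized chunks of LDS.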
auto shape = dstTy.getShape();
LinearLayout srcLayout =
triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
LinearLayout sharedLayout = triton::gpu::toLinearLayout(
shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth());
LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);

StringAttr kLane = rewriter.getStringAttr("lane");
for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) {
auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0];
unsigned expected = maxVec * (1 << inLane);
if (basis != expected) {
LDBG("detected uncoalesced layout from blocked to shared in async copy "
"for lane "
<< 1 + inLane << "; given " << basis << " but expected "
<< expected);
return rewriter.notifyMatchFailure(op,
"does not write coalesced into LDS");
}
}

// Addresses to store into, one per `vecTy`.
VectorType vecTy;
SmallVector<Value> shmemAddrs;
bool ok = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
[&](VectorType vecTy_, Value shmemAddr) {
vecTy = vecTy_;
shmemAddrs.push_back(shmemAddr);
});
assert(ok);

int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
if (!supportsLoadWidth(vecBits, targetInfo)) {
return rewriter.notifyMatchFailure(
op, "Async copy does not support the required load vectorization");
}

int vecBytes = vecBits / 8;
assert(llvm::isPowerOf2_32(vecBytes));
Value vecBytesVal = b.i32_val(vecBytes);

Value cacheModifiers =
b.i32_val(mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget(
op.getCache(), false, targetInfo));

Value llMask = adaptor.getMask();
SmallVector<Value> maskElems;
if (llMask) {
maskElems = unpackLLElements(loc, llMask, rewriter);
assert(srcElems.size() == maskElems.size());
}

Value other = op.getOther();
SmallVector<Value> otherElems;
if (other) {
otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter);
assert(srcElems.size() == otherElems.size());
}

for (int i = 0; i < shmemAddrs.size(); i++) {
auto srcIdx = i * maxVec;
auto srcPtr = srcElems[srcIdx];

if (!mask) {
rewriter.create<ROCDL::GlobalLoadLDSOp>(
loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
cacheModifiers);
continue;
}

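// Guard the load with the per-lane mask: branch to a block that issues the
// global.load.lds only when the mask bit is set, then fall through;
// masked-off lanes optionally get `other` written to LDS below.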
Block *currentBlock = rewriter.getInsertionBlock();
Block *afterLoad =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Block *loadBlock = rewriter.createBlock(afterLoad);
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, maskElems[srcIdx], loadBlock,
afterLoad);
rewriter.setInsertionPointToStart(loadBlock);
rewriter.create<ROCDL::GlobalLoadLDSOp>(
loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
cacheModifiers);

rewriter.create<LLVM::BrOp>(loc, afterLoad);
rewriter.setInsertionPointToStart(afterLoad);
if (other) {
Value storeVal = packElementRangeIntoVector(
rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx);
llStore(rewriter, loc, shmemAddrs[i], storeVal,
b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache());
}
}

// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), IntegerType::get(op.getContext(), 32),
rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
};

struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
public LoadStoreConversionBase {
using ConvertOpToLLVMPattern<triton::StoreOp>::ConvertOpToLLVMPattern;
@@ -1459,6 +1638,57 @@ struct AtomicRMWOpConversion
return endBlock->getArgument(0);
}
};

struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
using ConvertOpToLLVMPattern<AsyncWaitOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = op->getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);

// global.load.lds uses vmcnt to synchronize.
// The rocdl op stores all possible counters in a single int32 value (v).
// The vmcnt (6 bits) is split into a lower part (bits 3:0) and a higher part
// (bits 5:4). The lower part is stored in bits 3:0 of v and the higher part
// in bits 15:14. We have to set all other bits in v to 1 to signal that we
// are not interested in those counters.

int vmCnt = op.getNum();
if (vmCnt >= 64) {
return emitError(loc, "AsyncWait does not support values >= 64");
}

// Extract low and high bits and combine while setting all other bits to 1
unsigned lowBits = vmCnt & 0xF;
unsigned highBits = vmCnt >> 4 << 14;
unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set
unsigned waitValue = lowBits | highBits | otherCnts;
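// Examples (as signed i32, matching the lit test in async_ops_to_llvm.mlir):
//   vmCnt = 0  -> 0xFFFF3FF0 = -49168
//   vmCnt = 63 -> 0xFFFFFFFF = -1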

rewriter.create<ROCDL::WaitcntOp>(loc, waitValue);

// Drop the result AsyncToken
rewriter.replaceOp(op, b.i32_val(0));
return success();
}
};

struct AsyncCommitGroupConversion
: public ConvertOpToLLVMPattern<AsyncCommitGroupOp> {
using ConvertOpToLLVMPattern<AsyncCommitGroupOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Drop the result AsyncToken
auto loc = op->getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
rewriter.replaceOp(op, b.i32_val(0));
return success();
}
};

} // namespace

namespace mlir::triton::AMD {
@@ -1468,9 +1698,12 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit) {
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
StoreOpConversion, BufferLoadOpConversion,
BufferStoreOpConversion, BufferAtomicRMWOpConversion>(
typeConverter, targetInfo, axisInfoAnalysis, benefit);
patterns
.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
StoreOpConversion, BufferLoadOpConversion, BufferStoreOpConversion,
BufferAtomicRMWOpConversion, AsyncCopyGlobalToLocalOpConversion>(
typeConverter, targetInfo, axisInfoAnalysis, benefit);
patterns.add<AsyncWaitConversion, AsyncCommitGroupConversion>(typeConverter,
benefit);
}
} // namespace mlir::triton::AMD
6 changes: 3 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
@@ -517,9 +517,9 @@ static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) {
// .cv: don't cache and fetch again
// .wb: write-back, writes back data at all cache levels
// .wt: write-through, write data directly to system memory
int32_t
getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier cm, bool isBufferLoad,
mlir::triton::AMD::TargetInfo &targetInfo) {
int32_t getCtrlBitsForCacheModifierOnTarget(
triton::CacheModifier cm, bool isBufferLoad,
const mlir::triton::AMD::TargetInfo &targetInfo) {
if (targetInfo.getGPUKind() == llvm::AMDGPU::GK_GFX942) // gfx942
return getCtrlBitsForCacheModifierOnGFX942(cm, isBufferLoad);
else