Merged
Commits (30)
ba17a9e
Basic lowering AsyncCommitGroup and AsyncWait
AlexAUT Jan 17, 2025
2327587
WIP lowering of AsyncCopy
AlexAUT Jan 17, 2025
1d4edf6
Added layout checks for asynccopy lowering
AlexAUT Jan 24, 2025
ead4915
Support direct to lds
AlexAUT Jan 27, 2025
3141ba4
Enable non working masking
AlexAUT Jan 27, 2025
644aa1e
Add support to enable disable direct to lds with env var AMDGCN_USE_D…
AlexAUT Jan 28, 2025
7c9bab1
Fix masking and others for direct to lds
AlexAUT Jan 28, 2025
cb823d0
Fix when AsycCopy is lowered without a mask
AlexAUT Jan 28, 2025
c097616
Use ROCDL instead of intrinsics
AlexAUT Jan 28, 2025
1a9f1e0
Cleanup and simplify AsyncCopy lowering
AlexAUT Jan 28, 2025
a20b686
CacheModifiers for AsyncCopy
AlexAUT Jan 28, 2025
97d677d
Add lit test for AsyncCopy
AlexAUT Jan 28, 2025
30352ad
Split AsyncCopy Lit for gfx950
AlexAUT Jan 28, 2025
fe8619d
Add const to getCtrlBitsForCacheModifierOnTarget
AlexAUT Jan 28, 2025
7941a30
Cleanup StreamPipeliner changes
AlexAUT Jan 28, 2025
def9313
Revert stream pipeline related changes
AlexAUT Jan 28, 2025
318caa2
Add missing CDNA1 to AsyncCopy support list
AlexAUT Jan 28, 2025
6600138
Cleanup
AlexAUT Jan 28, 2025
ea02c3c
Replace macros for llvm ops with TritonLLVMOpBuilder
AlexAUT Jan 29, 2025
13419bb
Fix wrong value in supported bit width for global.to.lds
AlexAUT Jan 30, 2025
ca8b441
Addressing review comments
AlexAUT Jan 31, 2025
6aa3554
Unified async ops lit tests
AlexAUT Jan 31, 2025
04fad93
Emit correct wmcnt wait instead of waiting on all cnts
AlexAUT Jan 31, 2025
f6cbe22
Add tests for AsyncWait/AsyncCommitGroup
AlexAUT Jan 31, 2025
3d30f43
Limit AsyncWait conversion to gfx9
AlexAUT Feb 3, 2025
0c382db
Add AsyncOpy lowering lit test with masking and other values
AlexAUT Feb 3, 2025
f560aeb
Added async copy lit tests with cache modifiers
AlexAUT Feb 5, 2025
d6b0d02
Merge branch 'main' into global_to_lds_lowering
AlexAUT Feb 5, 2025
d90ffbe
Adjust to shared encoding changes
AlexAUT Feb 5, 2025
5356802
Fix a few small issues
antiagainst Feb 5, 2025
43 changes: 43 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -294,3 +294,46 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_copy
tt.func public @async_copy(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// We need the splat to allow the AxisAnalysis to work during lowering
%1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
// CHECK: rocdl.global.load.lds
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_copy_vectorized
tt.func public @async_copy_vectorized(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// We need the index calculation so AxisAnalysis sees that we can vectorize the load
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

// Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds
// CHECK-COUNT-4: rocdl.global.load.lds
// CHECK-NOT: rocdl.global.load.lds
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}
24 changes: 24 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm_gfx950.mlir
@@ -0,0 +1,24 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-builtin-func-to-llvm | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: async_copy_vectorized
tt.func public @async_copy_vectorized(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
// We need the index calculation so AxisAnalysis sees that we can vectorize the load
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

// Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
// CHECK: rocdl.global.load.lds
// CHECK-NOT: rocdl.global.load.lds
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
tt.return
}
}
221 changes: 220 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -20,9 +20,12 @@ using namespace mlir::triton::gpu;

using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryBase;
using ::mlir::LLVM::AMD::getContiguity;
using mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget;
using ::mlir::LLVM::AMD::getVectorSize;
using ::mlir::LLVM::AMD::llLoad;
using ::mlir::LLVM::AMD::llStore;
using ::mlir::triton::AMD::ISAFamily;
using ::mlir::triton::gpu::getTotalElemsPerThread;

namespace {
@@ -396,6 +399,177 @@ struct BufferLoadOpConversion
}
};

struct AsyncCopyGlobalToLocalOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>,
public LoadStoreConversionBase {
using ConvertOpToLLVMPattern<
triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern;

AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter,
const AMD::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>(converter,
benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}

bool isLoadWidthSupported(unsigned bits,
const AMD::TargetInfo &targetInfo) const {
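// Supported global.load.lds widths: CDNA1-3 allow 8-, 16-, and 32-bit
// transfers; gfx950 additionally allows 96- and 128-bit transfers.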
llvm::SmallSetVector<unsigned, 10> supportedWidths;
switch (targetInfo.getISAFamily()) {
case mlir::triton::AMD::ISAFamily::CDNA1:
case mlir::triton::AMD::ISAFamily::CDNA2:
case mlir::triton::AMD::ISAFamily::CDNA3:
supportedWidths.insert(8);
supportedWidths.insert(16);
supportedWidths.insert(32);
if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) {
supportedWidths.insert(96);
supportedWidths.insert(128);
}
break;
default:
return false;
}

return supportedWidths.contains(bits);
}

LogicalResult
matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);

auto srcTy = op.getSrc().getType();
auto srcEncoding = srcTy.getEncoding();
assert((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
"Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion"));
assert(srcTy.getShape().size() <= 2 && "Async copy only supports 1d and 2d "
"tensors: Unexpected rank of %src");

auto dstTy = op.getResult().getType();
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());

Value llSrc = adaptor.getSrc();

auto srcElems = unpackLLElements(loc, llSrc, rewriter);

Value llDst = adaptor.getResult();
auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
loc, llDst, resElemTy, rewriter);

unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass);

Value mask = op.getMask();
if (mask) {
maxVec = std::min(maxVec, getMaskAlignment(mask));
}

// global.load.lds does not support per-lane offsets.
// We need to ensure that we write coalesced into shared memory.
// This means that the kLane dim needs to be contiguous based on the
// vectorization size.
auto shape = dstTy.getShape();
LinearLayout srcLayout =
triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
LinearLayout sharedLayout = triton::gpu::toLinearLayout(
shape, dstTy.getEncoding(), resElemTy.getIntOrFloatBitWidth());
LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);

StringAttr kLane = rewriter.getStringAttr("lane");
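// Each lane bit i must map to a shared-memory offset of maxVec * 2^i, i.e.
// consecutive lanes must write consecutive maxVec-sized chunks; otherwise
// the write into LDS would not be coalesced and we fail the conversion.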
for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) {
auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0];
unsigned expected = maxVec * (1 << inLane);
if (basis != expected) {
return emitError(loc, "Invalid layout in AsyncCopy: ")
<< "Lane: " << 1 + inLane << " is " << basis << " should be "
<< expected << "\n";
}
}

// Addresses to store into, one per `vecTy`.
VectorType vecTy;
SmallVector<Value> shmemAddrs;
bool ok = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
[&](VectorType vecTy_, Value shmemAddr) {
vecTy = vecTy_;
shmemAddrs.push_back(shmemAddr);
});
assert(ok);

int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
if (!isLoadWidthSupported(vecBits, targetInfo)) {
return emitError(loc, "Async copy does not support the required load "
"vectorization, got ")
<< vecBits << " bits";
}

int vecBytes = vecBits / 8;
assert(llvm::isPowerOf2_32(vecBytes));
Value vecBytesVal = b.i32_val(vecBytes);

Value cacheModifiers = b.i32_val(
getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo));

Value llMask = adaptor.getMask();
SmallVector<Value> maskElems;
if (llMask) {
maskElems = unpackLLElements(loc, llMask, rewriter);
assert(srcElems.size() == maskElems.size());
}

Value other = op.getOther();
SmallVector<Value> otherElems;
if (other) {
otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter);
assert(srcElems.size() == otherElems.size());
}

for (int i = 0; i < shmemAddrs.size(); i++) {
auto srcIdx = i * maxVec;
auto srcPtr = srcElems[srcIdx];

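// global.load.lds cannot be predicated per lane, so without a mask we emit
// it directly; with a mask we branch around the load and, for masked-off
// lanes, store the `other` value into shared memory instead (if provided).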
if (!mask) {
rewriter.create<ROCDL::GlobalLoadLDSOp>(
loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
cacheModifiers);
} else {
Block *currentBlock = rewriter.getInsertionBlock();
Block *afterLoad =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Block *loadBlock = rewriter.createBlock(afterLoad);
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, maskElems[srcIdx], loadBlock,
afterLoad);
rewriter.setInsertionPointToStart(loadBlock);
rewriter.create<ROCDL::GlobalLoadLDSOp>(
loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0),
cacheModifiers);

rewriter.create<LLVM::BrOp>(loc, afterLoad);
rewriter.setInsertionPointToStart(afterLoad);
if (other) {
Value storeVal =
packElementRangeIntoVector(rewriter, this->getTypeConverter(),
loc, vecTy, otherElems, srcIdx);
llStore(rewriter, loc, shmemAddrs[i], storeVal,
b.icmp_ne(maskElems[srcIdx], b.true_val()), 0, op.getCache());
}
}
}

// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), IntegerType::get(op.getContext(), 32),
rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
};

struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
public LoadStoreConversionBase {
using ConvertOpToLLVMPattern<triton::StoreOp>::ConvertOpToLLVMPattern;
@@ -1459,6 +1633,49 @@ struct AtomicRMWOpConversion
return endBlock->getArgument(0);
}
};

struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
using ConvertOpToLLVMPattern<AsyncWaitOp>::ConvertOpToLLVMPattern;

AsyncWaitConversion(LLVMTypeConverter &converter,
const AMD::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertOpToLLVMPattern<AsyncWaitOp>(converter, benefit) {}

LogicalResult
matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = op->getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
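// Lower to s_waitcnt with the count carried by the op; the async token
// result is replaced by a dummy i32 value.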
rewriter.create<ROCDL::WaitcntOp>(loc, op.getNum());
rewriter.replaceOp(op, b.i32_val(0));
return success();
}
};

struct AsyncCommitGroupConversion
: public ConvertOpToLLVMPattern<AsyncCommitGroupOp> {
using ConvertOpToLLVMPattern<AsyncCommitGroupOp>::ConvertOpToLLVMPattern;

AsyncCommitGroupConversion(LLVMTypeConverter &converter,
const AMD::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertOpToLLVMPattern<AsyncCommitGroupOp>(converter, benefit) {}

LogicalResult
matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Drop the result token
auto loc = op->getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
rewriter.replaceOp(op, b.i32_val(0));
return success();
}
};

} // namespace

namespace mlir::triton::AMD {
@@ -1470,7 +1687,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
PatternBenefit benefit) {
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
StoreOpConversion, BufferLoadOpConversion,
BufferStoreOpConversion, BufferAtomicRMWOpConversion>(
BufferStoreOpConversion, BufferAtomicRMWOpConversion,
AsyncCopyGlobalToLocalOpConversion, AsyncCommitGroupConversion,
AsyncWaitConversion>(
typeConverter, targetInfo, axisInfoAnalysis, benefit);
}
} // namespace mlir::triton::AMD
6 changes: 3 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
@@ -517,9 +517,9 @@ static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) {
// .cv: don't cache and fetch again
// .wb: write-back, writes back data at all cache levels
// .wt: write-through, write data directly to system memory
int32_t
getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier cm, bool isBufferLoad,
mlir::triton::AMD::TargetInfo &targetInfo) {
int32_t getCtrlBitsForCacheModifierOnTarget(
triton::CacheModifier cm, bool isBufferLoad,
const mlir::triton::AMD::TargetInfo &targetInfo) {
if (targetInfo.getGPUKind() == llvm::AMDGPU::GK_GFX942) // gfx942
return getCtrlBitsForCacheModifierOnGFX942(cm, isBufferLoad);
else
5 changes: 3 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
@@ -54,8 +54,9 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
// Get flags <volatile, nontemporal> for a predicated Load or Store
std::pair<bool, bool> getCacheModifierFlagsForPredicatedCall(LLVM::CallOp);
// Get the cachepolicy value for a cache modifier
int32_t getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool,
mlir::triton::AMD::TargetInfo &);
int32_t
getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool,
const mlir::triton::AMD::TargetInfo &);

// Get cache modifier information for buffer atomics
int32_t getCtrlBitsForBufferAtomicsOnGFX942(bool setSC0, bool setSC1,