Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 50 additions & 33 deletions lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ constexpr uint64_t getUnaryOpId(UnaryOpId opId) {
// Scratch memory management
// ------------------------------------------------------------

// Build a blocked encoding for `shape`/`elemType` that starts from the
// target's default blocked layout but widens sizePerThread along the
// fastest-varying dimension so each thread covers up to 128 bits worth of
// elements. Warp/CTA geometry is looked up from the rewriter's current
// insertion point.
static ttg::BlockedEncodingAttr
getOptimizedBlockedEncoding(PatternRewriter &rewriter, ArrayRef<int64_t> shape,
                            Type elemType) {
  MLIRContext *ctx = rewriter.getContext();
  Block *insertionBlock = rewriter.getInsertionBlock();
  int numWarps = ttg::lookupNumWarps(insertionBlock->getParent());
  int threadsPerWarp = ttg::lookupThreadsPerWarp(rewriter);
  int numCTAs = ttg::lookupNumCTAs(insertionBlock->getParentOp());

  auto base = ttg::getDefaultBlockedEncoding(ctx, shape, numWarps,
                                             threadsPerWarp, numCTAs);

  // Start with one element per thread everywhere, then widen the contiguous
  // (first-in-order) dimension, capped both by the dim size and by a 128-bit
  // per-thread budget.
  SmallVector<unsigned> order = llvm::to_vector(base.getOrder());
  SmallVector<unsigned> sizePerThread(shape.size(), 1);
  unsigned elemBits = elemType.getIntOrFloatBitWidth();
  unsigned budgetElems = std::max(128u / elemBits, 1u);
  if (!order.empty()) {
    unsigned fastestDim = order.front();
    int64_t widened = std::min<int64_t>(shape[fastestDim], budgetElems);
    sizePerThread[fastestDim] = static_cast<unsigned>(widened);
  }

  return ttg::BlockedEncodingAttr::get(ctx, sizePerThread,
                                       base.getThreadsPerWarp(),
                                       base.getWarpsPerCTA(), order,
                                       base.getCGALayout());
}

struct ScratchInfo {
Value ptr;
RankedTensorType tensorType;
Expand Down Expand Up @@ -95,30 +117,6 @@ class TmemScratchManager {
return Value();
}

static ttg::BlockedEncodingAttr
getOptimizedBlockedEncoding(PatternRewriter &rewriter,
ArrayRef<int64_t> shape, Type elemType) {
int numWarps =
ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent());
int threadsPerWarp = ttg::lookupThreadsPerWarp(rewriter);
int numCTAs =
ttg::lookupNumCTAs(rewriter.getInsertionBlock()->getParentOp());
auto base = ttg::getDefaultBlockedEncoding(
rewriter.getContext(), shape, numWarps, threadsPerWarp, numCTAs);
SmallVector<unsigned> order = llvm::to_vector(base.getOrder());
SmallVector<unsigned> sizePerThread(shape.size(), 1);
unsigned elemBits = elemType.getIntOrFloatBitWidth();
unsigned maxElems = std::max(128u / elemBits, 1u);
if (!order.empty()) {
unsigned dim = order.front();
sizePerThread[dim] =
static_cast<unsigned>(std::min<int64_t>(shape[dim], maxElems));
}
return ttg::BlockedEncodingAttr::get(
rewriter.getContext(), sizePerThread, base.getThreadsPerWarp(),
base.getWarpsPerCTA(), order, base.getCGALayout());
}

std::optional<ScratchInfo>
getOrCreate(Value memdesc, PatternRewriter &rewriter, Region *scope) {
if (auto arg = dyn_cast<BlockArgument>(memdesc)) {
Expand Down Expand Up @@ -830,9 +828,16 @@ struct DotPattern : public OpRewritePattern<tt::DotOp> {
int64_t tileM = std::min<int64_t>(kTileM, m);
int64_t tileN = std::min<int64_t>(kTileN, n);

auto accLayout = cast<ttg::DistributedEncodingTrait>(cTy.getEncoding());
auto aLayout = cast<ttg::DistributedEncodingTrait>(aTy.getEncoding());
auto bLayout = cast<ttg::DistributedEncodingTrait>(bTy.getEncoding());
// Use optimized blocked layouts for emulation tiles instead of the
// original dot encodings. Encodings like AMDWmmaEncodingAttr impose
// minimum shape requirements (e.g. >= 16x16) that the small emulation
// tiles (kTileM x kTileN = 8x8) cannot satisfy.
auto accLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN},
cTy.getElementType());
auto aLayout =
getOptimizedBlockedEncoding(rewriter, {tileM, k}, aTy.getElementType());
auto bLayout =
getOptimizedBlockedEncoding(rewriter, {k, tileN}, bTy.getElementType());

auto accTileTy =
RankedTensorType::get({tileM, tileN}, cTy.getElementType(), accLayout);
Expand All @@ -845,6 +850,12 @@ struct DotPattern : public OpRewritePattern<tt::DotOp> {
Value bPtr = createScratchAndStore(rewriter, loc, op.getB(), bTy);
Value dPtr = createScratchAndStore(rewriter, loc, op.getC(), cTy);

// Each warp may only store a subset of each tile's rows, so a barrier is
// needed to make all scratch stores visible before the loops read them.
ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);

auto mLoop = emitMmaEmulationLoops(
rewriter, loc, aPtr, bPtr, dPtr, m, n, k, tileM, tileN, aTileTy,
bTileTy, accTileTy, accLayout, accElem, useDInt, predInt,
Expand All @@ -853,6 +864,12 @@ struct DotPattern : public OpRewritePattern<tt::DotOp> {
return failure();
rewriter.setInsertionPointAfter(*mLoop);

// Same reason: each warp may only write a subset of D's rows in the loop,
// so synchronize before the final load.
ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);

Value out = loadScratchStrided2D(rewriter, loc, dPtr, cTy, /*stride1=*/m);
if (!out)
return failure();
Expand Down Expand Up @@ -1035,16 +1052,16 @@ struct TCGen5MMAPattern : public OpRewritePattern<ttng::TCGen5MMAOp> {
int64_t tileM = std::min<int64_t>(kTileM, m);
int64_t tileN = std::min<int64_t>(kTileN, n);

auto accTileLayout = TmemScratchManager::getOptimizedBlockedEncoding(
rewriter, {tileM, tileN}, dMemTy.getElementType());
auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN},
dMemTy.getElementType());
auto accTileTy = RankedTensorType::get(
{tileM, tileN}, dMemTy.getElementType(), accTileLayout);
auto aTileLayout = TmemScratchManager::getOptimizedBlockedEncoding(
rewriter, {tileM, k}, aMemTy.getElementType());
auto aTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, k},
aMemTy.getElementType());
auto aTileTy =
RankedTensorType::get({tileM, k}, aMemTy.getElementType(), aTileLayout);
auto bTileLayout = TmemScratchManager::getOptimizedBlockedEncoding(
rewriter, {k, tileN}, bMemTy.getElementType());
auto bTileLayout = getOptimizedBlockedEncoding(rewriter, {k, tileN},
bMemTy.getElementType());
auto bTileTy =
RankedTensorType::get({k, tileN}, bMemTy.getElementType(), bTileLayout);

Expand Down
113 changes: 113 additions & 0 deletions python/test/gluon/test_fpsan.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,116 @@ def reduce_kernel(a_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, stride_am: tl.
reduce_kernel[(1, )](a, c1, M=M, N=N, stride_am=a.stride(0), stride_ak=a.stride(1), ORDER=0)
reduce_kernel[(1, )](a, c2, M=M, N=N, stride_am=a.stride(0), stride_ak=a.stride(1), ORDER=1)
assert _payload_equal(c1, c2)


@pytest.mark.skipif(not (is_hip_cdna3() or is_hip_cdna4()), reason="Requires CDNA3 or CDNA4")
def test_mfma_dot(device, fresh_knobs):
    """fpsan emulation of an MFMA dot on CDNA3/CDNA4.

    int32 bit patterns are wrapped as f32 tensors, pushed through a Gluon
    kernel that performs an MFMA-layout dot, and the resulting payload bits
    are compared against the reference from _mm_payload_u32.
    """
    _require_cuda_backend(device)

    M, N, K = 16, 16, 32

    # Compile through the FP sanitizer instrumentation pass under test.
    fresh_knobs.compilation.instrumentation_mode = "fpsan"

    # CDNA3 and CDNA4 MFMA differ in K instruction dim and operand k-width.
    cdna_version = 3 if is_hip_cdna3() else 4
    nonkdim = 32
    kdim = 8 if cdna_version == 3 else 16
    k_width_val = 4 if cdna_version == 3 else 8

    blocked = gl.BlockedLayout([4, 4], [4, 16], [4, 1], [1, 0])
    mfma_layout = gl.amd.AMDMFMALayout(cdna_version, [nonkdim, nonkdim, kdim], True, [4, 1])

    @gluon.jit
    def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, BLOCK_K: gl.constexpr,
               blocked: gl.constexpr, k_width: gl.constexpr, mfma_layout: gl.constexpr):
        dot_a_layout: gl.constexpr = gl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=k_width)
        dot_b_layout: gl.constexpr = gl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=k_width)

        # Row/column offsets in the plain blocked layout used for loads/stores.
        offs_am = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, blocked))
        offs_bn = gl.arange(0, BLOCK_N, layout=gl.SliceLayout(0, blocked))
        offs_ak = gl.arange(0, BLOCK_K, layout=gl.SliceLayout(0, blocked))
        offs_bk = gl.arange(0, BLOCK_K, layout=gl.SliceLayout(1, blocked))

        a = gl.load(a_ptr + offs_am[:, None] * BLOCK_K + offs_ak[None, :])
        b = gl.load(b_ptr + offs_bk[:, None] * BLOCK_N + offs_bn[None, :])
        c = gl.load(c_ptr + offs_am[:, None] * BLOCK_N + offs_bn[None, :])

        # Convert operands/accumulator into MFMA layouts before the dot so the
        # sanitizer sees an MFMA-encoded tt.dot to emulate.
        a1 = gl.convert_layout(a, layout=dot_a_layout)
        b1 = gl.convert_layout(b, layout=dot_b_layout)
        c_acc = gl.convert_layout(c, layout=mfma_layout)

        result = gl.amd.cdna3.mfma(a1, b1, c_acc)
        result = gl.convert_layout(result, layout=blocked)
        gl.store(out_ptr + offs_am[:, None] * BLOCK_N + offs_bn[None, :], result)

    # Fixed-seed random int32 bit patterns (numpy's high bound is exclusive,
    # so 2**31 - 1 itself is never drawn).
    rs = np.random.RandomState(0)
    a_bits = rs.randint(-(2**31), 2**31 - 1, size=(M, K), dtype=np.int32)
    b_bits = rs.randint(-(2**31), 2**31 - 1, size=(K, N), dtype=np.int32)
    c_bits = rs.randint(-(2**31), 2**31 - 1, size=(M, N), dtype=np.int32)
    exp_bits = _mm_payload_u32(a_bits, b_bits, c_bits)

    # NOTE(review): tensors are placed on "cuda" directly rather than the
    # `device` fixture argument — confirm this matches the rest of the file.
    a = torch.tensor(a_bits, device="cuda", dtype=torch.int32)
    b = torch.tensor(b_bits, device="cuda", dtype=torch.int32)
    c = torch.tensor(c_bits, device="cuda", dtype=torch.int32)
    out = torch.empty((M, N), device="cuda", dtype=torch.int32)

    # Present the int32 buffers to the kernel as f32 (presumably a dtype
    # reinterpret, keeping the bit payloads intact — the kernel loads floats).
    aw = triton.TensorWrapper(a, dtype=torch.float32)
    bw = triton.TensorWrapper(b, dtype=torch.float32)
    cw = triton.TensorWrapper(c, dtype=torch.float32)
    outw = triton.TensorWrapper(out, dtype=torch.float32)

    kernel[(1, )](aw, bw, cw, outw, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, blocked=blocked, k_width=k_width_val,
                  mfma_layout=mfma_layout)

    _assert_payload_equal(out, exp_bits)


@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250")
def test_wmma_dot(device, fresh_knobs):
    """fpsan emulation of a WMMA dot on gfx1250.

    Same scheme as the MFMA test: int32 bit patterns wrapped as f32, a Gluon
    WMMA dot under fpsan instrumentation, payload bits compared against
    _mm_payload_u32.
    """
    _require_cuda_backend(device)

    B = 32
    # Compile through the FP sanitizer instrumentation pass under test.
    fresh_knobs.compilation.instrumentation_mode = "fpsan"

    @gluon.jit
    def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: gl.constexpr, INSTR_SHAPE_K: gl.constexpr, K_WIDTH: gl.constexpr):
        blocked: gl.constexpr = gl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
        wmma: gl.constexpr = gl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, INSTR_SHAPE_K])

        # Offsets in the blocked layout used for the initial loads.
        offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, blocked))[:, None]
        offs_k = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, blocked))[None, :]
        offs_bk = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, blocked))[:, None]
        offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, blocked))[None, :]

        a = gl.load(a_ptr + offs_m * BLOCK + offs_k)
        b = gl.load(b_ptr + offs_bk * BLOCK + offs_n)
        c = gl.load(c_ptr + offs_m * BLOCK + offs_n)
        c = gl.convert_layout(c, wmma)

        # WMMA operand layouts so the sanitizer sees a WMMA-encoded dot.
        a = gl.convert_layout(a, gl.DotOperandLayout(0, wmma, K_WIDTH))
        b = gl.convert_layout(b, gl.DotOperandLayout(1, wmma, K_WIDTH))
        acc = gl.amd.gfx1250.wmma(a, b, c)

        # Store straight from the WMMA result layout.
        out_layout: gl.constexpr = gl.SliceLayout(1, wmma)
        offs_cm = gl.arange(0, BLOCK, layout=out_layout)[:, None]
        offs_cn = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, wmma))[None, :]
        gl.store(out_ptr + offs_cm * BLOCK + offs_cn, acc)

    # Fixed-seed random int32 bit patterns (numpy's high bound is exclusive).
    rs = np.random.RandomState(0)
    a_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32)
    b_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32)
    c_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32)
    exp_bits = _mm_payload_u32(a_bits, b_bits, c_bits)

    # NOTE(review): tensors are placed on "cuda" directly rather than the
    # `device` fixture argument — confirm this matches the rest of the file.
    a = torch.tensor(a_bits, device="cuda", dtype=torch.int32)
    b = torch.tensor(b_bits, device="cuda", dtype=torch.int32)
    c = torch.tensor(c_bits, device="cuda", dtype=torch.int32)
    out = torch.empty((B, B), device="cuda", dtype=torch.int32)

    # Present the int32 buffers to the kernel as f32 (presumably a dtype
    # reinterpret, keeping the bit payloads intact).
    aw = triton.TensorWrapper(a, dtype=torch.float32)
    bw = triton.TensorWrapper(b, dtype=torch.float32)
    cw = triton.TensorWrapper(c, dtype=torch.float32)
    outw = triton.TensorWrapper(out, dtype=torch.float32)

    kernel[(1, )](aw, bw, cw, outw, BLOCK=B, INSTR_SHAPE_K=4, K_WIDTH=2)

    _assert_payload_equal(out, exp_bits)
68 changes: 68 additions & 0 deletions test/TritonGPU/amd/amd-fpsan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: @dot_emulation
tt.func public @dot_emulation() -> tensor<16x16xf32, #blocked> {
// CHECK: ttg.barrier global_read|global_write
// CHECK: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK-NOT: tt.dot
// CHECK-NOT: ttg.convert_layout
%cst = arith.constant 1.000000e+00 : f16
Expand All @@ -20,6 +22,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ

// -----

// Verifies that fpsan rewrites an MFMA-encoded (CDNA3, gfx942) tt.dot into an
// emulation loop bracketed by global-memory barriers, leaving no residual
// tt.dot or layout conversion.
#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 2], instrShape = [16, 16, 16], isTransposed = true}>
#mfma_dot_a = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>
#mfma_dot_b = #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: @mfma_cdna3_dot_emulation
tt.func public @mfma_cdna3_dot_emulation() -> tensor<32x32xf32, #mfma> {
// CHECK: ttg.barrier global_read|global_write
// CHECK: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK-NOT: tt.dot
// CHECK-NOT: ttg.convert_layout
%cst = arith.constant 1.000000e+00 : f16
%zero = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mfma>
%a = tt.splat %cst : f16 -> tensor<32x32xf16, #mfma_dot_a>
%b = tt.splat %cst : f16 -> tensor<32x32xf16, #mfma_dot_b>
%out = tt.dot %a, %b, %zero : tensor<32x32xf16, #mfma_dot_a> * tensor<32x32xf16, #mfma_dot_b> -> tensor<32x32xf32, #mfma>
tt.return %out : tensor<32x32xf32, #mfma>
}
}

// -----

// Same as the CDNA3 case but for CDNA4 (gfx950): version 4 MFMA with a wider
// K instruction dim (32) and kWidth = 8. fpsan must emulate the dot with a
// barrier-bracketed loop and leave no tt.dot or layout conversion behind.
#mfma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [16, 16, 32], isTransposed = true}>
#mfma_dot_a = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>
#mfma_dot_b = #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: @mfma_cdna4_dot_emulation
tt.func public @mfma_cdna4_dot_emulation() -> tensor<32x32xf32, #mfma> {
// CHECK: ttg.barrier global_read|global_write
// CHECK: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK-NOT: tt.dot
// CHECK-NOT: ttg.convert_layout
%cst = arith.constant 1.000000e+00 : f16
%zero = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mfma>
%a = tt.splat %cst : f16 -> tensor<32x32xf16, #mfma_dot_a>
%b = tt.splat %cst : f16 -> tensor<32x32xf16, #mfma_dot_b>
%out = tt.dot %a, %b, %zero : tensor<32x32xf16, #mfma_dot_a> * tensor<32x32xf16, #mfma_dot_b> -> tensor<32x32xf32, #mfma>
tt.return %out : tensor<32x32xf32, #mfma>
}
}

// -----

// WMMA variant (gfx1250, 32 threads per warp): fpsan must emulate a
// WMMA-encoded tt.dot the same way — a loop bracketed by global-memory
// barriers, with no residual tt.dot or layout conversion.
#wmma = #ttg.amd_wmma<{version = 3, isTranspose = true, ctaLayout = {warp = [[0, 1], [1, 0]]}, instrShape = [16, 16, 32]}>
#wmma_dot_a = #ttg.dot_op<{opIdx = 0, parent = #wmma, kWidth = 8}>
#wmma_dot_b = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @wmma_dot_emulation
tt.func public @wmma_dot_emulation() -> tensor<32x32xf32, #wmma> {
// CHECK: ttg.barrier global_read|global_write
// CHECK: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK-NOT: tt.dot
// CHECK-NOT: ttg.convert_layout
%cst = arith.constant 1.000000e+00 : f16
%zero = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #wmma>
%a = tt.splat %cst : f16 -> tensor<32x32xf16, #wmma_dot_a>
%b = tt.splat %cst : f16 -> tensor<32x32xf16, #wmma_dot_b>
%out = tt.dot %a, %b, %zero : tensor<32x32xf16, #wmma_dot_a> * tensor<32x32xf16, #wmma_dot_b> -> tensor<32x32xf32, #wmma>
tt.return %out : tensor<32x32xf32, #wmma>
}
}

// -----

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: @binary_ops
tt.func public @binary_ops(%a: tensor<4xf32>, %b: tensor<4xf32>) -> tensor<4xf32> {
Expand Down
2 changes: 2 additions & 0 deletions test/TritonGPU/fpsan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
// CHECK-LABEL: @dot_emulation
tt.func public @dot_emulation() -> tensor<16x16xf32, #blocked> {
// CHECK: ttg.barrier global_read|global_write
// CHECK: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK-NOT: tt.dot
// CHECK-NOT: ttg.convert_layout
%cst = arith.constant 1.000000e+00 : f16
Expand Down
Loading