Skip to content

Commit 505cd44

Browse files
[Gluon] Expose 3d Dot FMA (#9501)
Enables batched (3D) FMA dots in Gluon.

Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
1 parent e2f77ae commit 505cd44

File tree

3 files changed

+41
-7
lines changed

3 files changed

+41
-7
lines changed

lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
103103
Value llA = adaptor.getA();
104104
Value llB = adaptor.getB();
105105

106-
auto sizePerThread = getContigPerThread(dTensorTy);
106+
llvm::SmallVector<unsigned> sizePerThread{dLayout.getSizePerThread()};
107107
auto numElemsPerThread = product(sizePerThread);
108108
SmallVector<unsigned> shapePerCTATile;
109109
for (auto [reg, thread, warp] :

python/test/gluon/test_core.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1915,13 +1915,44 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr):
19151915
ttgl.store(out_ptr + offs, out)
19161916

19171917
a = torch.rand((B, B), dtype=torch.float32, device="cuda")
1918-
b = torch.ones((B, B), dtype=torch.float32, device="cuda")
1918+
b = torch.rand((B, B), dtype=torch.float32, device="cuda")
19191919
c = torch.rand((B, B), dtype=torch.float32, device="cuda")
19201920
out = torch.empty((B, B), dtype=torch.float32, device="cuda")
19211921
kernel[(1, )](a, b, c, out)
19221922
torch.testing.assert_close(out, torch.addmm(c, a, b), atol=1e-2, rtol=1e-2)
19231923

19241924

1925+
def test_dot3d_fma():
    """Test a batched (3D) FMA dot in Gluon: out = a @ b + c per batch slice.

    Launches a single program that loads full [BATCH, B, B] operands, runs
    ``ttgl.dot_fma`` on rank-3 tensors, and compares against
    ``torch.matmul(a, b) + c`` on the host.
    """
    torch.manual_seed(42)
    # Problem sizes are constexpr so the kernel specializes on them.
    B = ttgl.constexpr(32)
    BATCH = ttgl.constexpr(8)
    threads_per_warp = ttgl.constexpr(THREADS_PER_WARP)

    @gluon.jit
    def kernel(a_ptr, b_ptr, c_ptr, out_ptr):
        # Rank-3 blocked layout: 1 element per thread, threads spread along
        # dim 1, warps along dim 0 (the batch dim), order [2, 1, 0].
        layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 1], [1, threads_per_warp, 1], [ttgl.num_warps(), 1, 1],
                                                    [2, 1, 0])
        # Operand layouts for the FMA dot share the accumulator's parent
        # layout.  NOTE(review): k_width=0 appears to be the convention for
        # FMA (non-MMA) dot operands — confirm against the layout docs.
        lhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=0, k_width=0)
        rhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=1, k_width=0)

        # Per-axis index ranges; each uses `layout` sliced along the two
        # dims it does not span, then broadcasts back to rank 3.
        offs_b = ttgl.arange(0, BATCH, layout=ttgl.SliceLayout(1, ttgl.SliceLayout(2, layout)))[:, None, None]
        offs_m = ttgl.arange(0, B, layout=ttgl.SliceLayout(0, ttgl.SliceLayout(2, layout)))[None, :, None]
        offs_n = ttgl.arange(0, B, layout=ttgl.SliceLayout(0, ttgl.SliceLayout(1, layout)))[None, None, :]
        # Linearized offsets into the contiguous [BATCH, B, B] buffers.
        offs = offs_b * B * B + offs_m * B + offs_n
        a = ttgl.convert_layout(ttgl.load(a_ptr + offs), lhs_layout)
        b = ttgl.convert_layout(ttgl.load(b_ptr + offs), rhs_layout)
        c = ttgl.load(c_ptr + offs)
        # Batched dot via scalar FMAs: out = a @ b + c for each batch slice.
        out = ttgl.dot_fma(a, b, c)
        ttgl.store(out_ptr + offs, out)

    a = torch.rand((BATCH, B, B), dtype=torch.float32, device="cuda")
    b = torch.rand((BATCH, B, B), dtype=torch.float32, device="cuda")
    c = torch.rand((BATCH, B, B), dtype=torch.float32, device="cuda")
    out = torch.empty((BATCH, B, B), dtype=torch.float32, device="cuda")
    # One program instance computes the entire batch.
    kernel[(1, )](a, b, c, out)
    # Loose fp32 tolerances: FMA accumulation order differs from torch's.
    torch.testing.assert_close(out, torch.matmul(a, b) + c, atol=1e-2, rtol=1e-2)
1954+
1955+
19251956
@gluon.jit
19261957
def kernel_auto_layout_constant(threads_per_warp: ttgl.constexpr):
19271958
BLOCK: ttgl.constexpr = 16

python/triton/experimental/gluon/language/_core.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -635,11 +635,14 @@ def dot_fma(a, b, acc, _semantic=None):
635635
assert b.type.layout.parent == mma_layout, "b's parent layout must be the same as acc's layout"
636636
assert a.type.layout.operand_index == 0, "a's operand index must be 0"
637637
assert b.type.layout.operand_index == 1, "b's operand index must be 1"
638-
639-
M, N = acc.shape
640-
K = a.shape[1]
641-
if M * N * K > 2**19:
642-
warnings.warn(f"Large dot FMA instruction size {M}x{N}x{K} may have slow compile times")
638+
assert len(acc.shape) == 2 or len(acc.shape) == 3
639+
assert len(acc.shape) == len(a.shape) == len(b.shape)
640+
641+
unified_dot_shape = acc.shape + a.shape[-1:] # join batch/M/N and K in one list
642+
if math.prod(unified_dot_shape) > 2**19:
643+
dot_name = "batched dot" if len(acc.shape) == 3 else "dot"
644+
shape_str = "x".join([str(x) for x in unified_dot_shape])
645+
warnings.warn(f"Large {dot_name} FMA instruction size {shape_str} may have slow compile times")
643646

644647
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
645648
return tensor(handle, acc.type)

0 commit comments

Comments
 (0)