Commit 089b8bc

[Blackwell] Set mutable attribute in InjectTMemCopy (triton-lang#5875)
Make sure the alloc we copy into is mutable, and add a verifier for it.
1 parent f29d8c7
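
In IR terms, the change is the following before/after sketch (shapes and the #tmem_scales encoding are illustrative, borrowed from the updated lit tests below):

  // Before: InjectTMemCopy allocated the copy destination as immutable.
  %t = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>

  // After: the uninitialized alloc that tmem_copy writes into carries `mutable`.
  %t = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>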

File tree: 6 files changed (+66 lines, -12 lines)

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Lines changed: 4 additions & 2 deletions

@@ -195,9 +195,11 @@ class InjectTMemCopy
     if (!localLoad || !isTmemCopyCompatible(localLoad.getSrc().getType())) {
       return failure();
     }
-
+    MemDescType newType = MemDescType::get(
+        dstType.getShape(), dstType.getElementType(), dstType.getEncoding(),
+        dstType.getMemorySpace(), /*mutableMemory=*/true);
     Value newTmemAlloc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
-        tmemAlloc.getLoc(), dstType, Value());
+        tmemAlloc.getLoc(), newType, Value());
 
     // Since tcgen05.cp followed by tcgen05.mma is guaranteed to execute in that
     // order, we do not need to wait for the completion of the copy before MMA.
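For context, the rewritten sequence now looks like this (a sketch mirroring the updated mma-pipeline test below; %src stands for an existing shared-memory buffer): the destination alloc is mutable, and, per the comment above, no wait is inserted between the copy and the MMA that consumes it.

  %dst = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>
  ttng.tmem_copy %src, %dst, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>) -> ()
  // The tcgen05.mma that reads %dst may follow immediately; tcgen05.cp and
  // tcgen05.mma are guaranteed to execute in issue order.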

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Lines changed: 10 additions & 0 deletions

@@ -301,6 +301,9 @@ LogicalResult TMEMStoreOp::verify() {
   if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
            TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding()))
     return emitOpError("should use tensor memory encoding.");
+  if (!getDst().getType().getMutableMemory()) {
+    return emitOpError("Cannot store into an immutable alloc");
+  }
   return success();
 }
 
@@ -378,6 +381,10 @@ LogicalResult TMEMAllocOp::verify() {
   if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
            TensorMemoryScalesEncodingAttr>(getType().getEncoding()))
     return emitOpError("should use tensor memory encoding.");
+  if (!getSrc()) {
+    if (!getType().getMutableMemory())
+      return emitError("uninitialized alloc must have a mutable memdesc type");
+  }
   return success();
 }
 
@@ -404,6 +411,9 @@ LogicalResult TMEMCopyOp::verify() {
       getBarrier().getType().getMemorySpace())) {
     return emitOpError("The optional barrier should be a shared memory buffer");
   }
+  if (!getDst().getType().getMutableMemory()) {
+    return emitOpError("Cannot copy into an immutable alloc");
+  }
 
   auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
   auto sharedEnc =
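The failing forms of these checks are exercised in test/TritonNvidiaGPU/invalid.mlir below; for reference, the accepted counterpart simply carries `mutable` on the destination memdesc (a sketch, reusing the encodings from those tests):

  %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // OK: the destination type is mutable, so TMEMStoreOp::verify() succeeds.
  ttng.tmem_store %cst, %0, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>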

test/TritonGPU/dot-operands.mlir
Lines changed: 1 addition & 0 deletions

@@ -105,6 +105,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 #tmem_scales = #ttng.tensor_memory_scales_encoding<>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @inject_tmem_copy
+// CHECK: ttng.tmem_alloc {{.*}}, mutable
 // CHECK: ttng.tmem_copy
 
 tt.func public @inject_tmem_copy(%scale: tensor<2x512x!tt.ptr<i8>, #blocked4> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) attributes {noinline = false} {

test/TritonGPU/loop-pipeline-blackwell.mlir
Lines changed: 5 additions & 5 deletions

@@ -378,16 +378,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %122 = tt.load %arg19 : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>
 
 %137 = ttg.local_alloc %121 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
-%130 = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
-ttng.tmem_copy %137, %130, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>) -> ()
+%130 = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>
+ttng.tmem_copy %137, %130, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>) -> ()
 
 %139 = ttg.local_alloc %122 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
-%131 = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
-ttng.tmem_copy %139, %131, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>) -> ()
+%131 = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>
+ttng.tmem_copy %139, %131, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>) -> ()
 
 %127 = ttng.tmem_alloc %arg15 : (tensor<128x128xf32, #blocked4>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
 
-ttng.tc_gen5_mma_scaled %118, %120, %127, %130, %131, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, i1, i1) -> ()
+ttng.tc_gen5_mma_scaled %118, %120, %127, %130, %131, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory, mutable>, i1, i1) -> ()
 %132 = ttng.tmem_load %127 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked4>
 
 %133 = tt.addptr %arg16, %incr_A : tensor<128x256x!tt.ptr<f8E5M2>, #blocked>, tensor<128x256xi32, #blocked>

test/TritonGPU/mma-pipeline-blackwell.mlir
Lines changed: 5 additions & 5 deletions

@@ -999,14 +999,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %21 = ttg.async_wait %arg7 {num = 0 : i32}
 %22 = ttg.memdesc_subview %2[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>
 
-%127 = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>
-ttng.tmem_copy %arg8, %127, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>) -> ()
-%128 = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>
-ttng.tmem_copy %arg9, %128, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>) -> ()
+%127 = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>
+ttng.tmem_copy %arg8, %127, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>) -> ()
+%128 = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>
+ttng.tmem_copy %arg9, %128, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>) -> ()
 
 %tmem = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
 
-ttng.tc_gen5_mma_scaled %20, %22, %tmem, %127, %128, %true, %true lhs = e5m2 rhs = e5m2: (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>, i1, i1) -> ()
+ttng.tc_gen5_mma_scaled %20, %22, %tmem, %127, %128, %true, %true lhs = e5m2 rhs = e5m2: (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory, mutable>, i1, i1) -> ()
 %acc_res = ttng.tmem_load %tmem : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
 %23 = arith.addi %arg4, %c1_i32 : i32
 %24 = arith.cmpi slt, %23, %c2_i32 : i32

test/TritonNvidiaGPU/invalid.mlir
Lines changed: 41 additions & 0 deletions

@@ -13,6 +13,47 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @alloc_tensor_memory() {
+    // expected-error @+1 {{uninitialized alloc must have a mutable memdesc type}}
+    %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @alloc_tensor_memory() {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %true = arith.constant true
+    %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
+    // expected-error @+1 {{Cannot store into an immutable alloc}}
+    ttng.tmem_store %cst, %0, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
+    tt.return
+  }
+}
+
+// -----
+
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_scales_encoding<>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @alloc_tensor_memory(%arg: !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>) {
+    %cst = arith.constant dense<0> : tensor<128x4xi8, #blocked>
+    %0 = ttng.tmem_alloc %cst : (tensor<128x4xi8, #blocked>) -> !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory>
+    // expected-error @+1 {{Cannot copy into an immutable alloc}}
+    ttng.tmem_copy %arg, %0, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory>) -> ()
+    tt.return
+  }
+}
+
+// -----
+
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
