Commit 23b0072

Mogball authored and zwu-2025 committed
[Warp Specialization] Final couple of fixes (triton-lang#6917)
* Put the critical path in the def partition, not the sink partition, to keep the correction partition small.
* Don't assign a latency to MMA ops that have a read-modify-write on the accumulator, since they won't get double buffered.
1 parent b577cbd commit 23b0072

File tree

4 files changed: +176 −7 lines


lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 6 additions & 2 deletions
@@ -301,8 +301,12 @@ class AssignMMALatencies {
       // overlap. WS does not have this problem because the MMA is placed in
       // a different partition than the MMA, so we can correctly set the
       // latency.
-      if (forOp->hasAttr(kWarpSpecializeAttrName))
-        opLatency[&op] += 1;
+      if (forOp->hasAttr(kWarpSpecializeAttrName)) {
+        if (ttng::hasAccReadModifyWrite(mma, forOp))
+          opLatency.erase(&op); // can't pipeline the MMA
+        else
+          opLatency[&op] += 1;
+      }
     }
   }
 }

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 2 additions & 5 deletions
@@ -245,9 +245,6 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
   for (auto [mmaOp, userPartition] : llvm::zip(mmas, userPartitions)) {
     scheduleUsers(loop, schedule, userPartition, mmaOp);
   }
-  for (ttng::MMAv5OpInterface mmaOp : mmas) {
-    scheduleDependencies(loop, schedule, defaultPartition, mmaOp);
-  }
 
   return schedule;
 }
@@ -444,10 +441,10 @@ void propagatePartitions(scf::ForOp loop, WarpSchedule &schedule) {
       });
     }
 
-    // If all ops are on the critical path, assign them to the sink partition.
+    // If all ops are on the critical path, assign them to the def partition.
     if (critPath.size() == cluster.ops.size()) {
       for (Operation *op : cluster.ops)
-        schedule.insert(sinkPartition, op);
+        schedule.insert(defPartition, op);
       continue;
     }
 
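
The partition-scheduling side of the fix is exercised by the new test added below: when every op in a cluster is on the critical path, the cluster now stays in the def partition rather than moving to the sink partition. The relevant FileCheck expectations, reproduced from that test (the concrete partition numbers are simply what the pass assigns in this example):

    // CHECK: [[SOFTMAX:%.*]] = math.exp2 {{.*}} {ttg.partition = 0 : i32} : tensor<256x64xf32
    // CHECK: [[X:%.*]] = arith.addf [[SOFTMAX]], [[SOFTMAX]] {ttg.partition = 0 : i32}
    // CHECK-NEXT: [[ACC_X:%.*]] = arith.addf %{{.*}}, [[X]] {ttg.partition = 3 : i32}

That is, %x stays in the same partition as the softmax chain it depends on, and only %acc_x, which also reads the accumulator, crosses into a different partition.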

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+// RUN: triton-opt %s --tritongpu-hoist-tmem-alloc --tritongpu-partition-scheduling -allow-unregistered-dialect | FileCheck %s
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#load_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared_T = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+
+#smem = #ttg.shared_memory
+#tmem_acc = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+#tmem_lhs = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = false>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
+
+// CHECK-LABEL: @attention_forward
+tt.func public @attention_forward(
+  %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>,
+  %K_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
+  %V_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
+  %qk_scale: f32,
+  %n_tiles: i32
+) {
+  %true = arith.constant true
+  %false = arith.constant false
+  %c0_i32 = arith.constant 0 : i32
+  %c64_i32 = arith.constant 64 : i32
+
+  %neg_inf = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  %zero = arith.constant dense<0.0> : tensor<256x64xf32, #blocked>
+  %one = arith.constant dense<1.0> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+
+  %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
+
+  %loop_outs:4 = scf.for %i = %c0_i32 to %n_tiles step %c64_i32 iter_args(
+    %l_i = %one,
+    %acc = %zero,
+    %m_i = %neg_inf,
+    %e_i = %one
+  ) -> (
+    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>,
+    tensor<256x64xf32, #blocked>,
+    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>,
+    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  ) : i32 {
+
+    %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #load_blocked>
+    %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
+
+    %K_trans = ttg.memdesc_trans %K_shared {order = array<i32: 1, 0>} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem>
+    %QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_trans, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>
+
+    %QK, %QK_load_tok = ttng.tmem_load %QK_tmem[%QK_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
+    %row_max = "compute_row_max"(%QK, %qk_scale) : (tensor<256x64xf32, #blocked>, f32) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %QK_adj = "sub_row_max"(%QK, %row_max, %qk_scale) : (tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, f32) -> tensor<256x64xf32, #blocked>
+    // CHECK: [[SOFTMAX:%.*]] = math.exp2 {{.*}} {ttg.partition = 0 : i32} : tensor<256x64xf32
+    %softmax = math.exp2 %QK_adj : tensor<256x64xf32, #blocked>
+    %diff = arith.subf %m_i, %row_max : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %alpha = math.exp2 %diff : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+
+    %l_ij = "tt.reduce"(%softmax) <{axis = 1 : i32}> ({
+    ^bb0(%arg29: f32, %arg30: f32):
+      %68 = arith.addf %arg29, %arg30 : f32
+      tt.reduce.return %68 : f32
+    }) : (tensor<256x64xf32, #blocked>) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %l_i_scaled = arith.mulf %l_i, %alpha : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %next_l_i = arith.addf %l_i_scaled, %l_ij : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+
+    %alpha_0 = tt.expand_dims %alpha {axis = 1 : i32} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xf32, #blocked>
+    %alpha_1 = tt.broadcast %alpha_0 : tensor<256x1xf32, #blocked> -> tensor<256x64xf32, #blocked>
+
+    %acc_corrected = arith.mulf %acc, %alpha_1 : tensor<256x64xf32, #blocked>
+
+    // CHECK: [[X:%.*]] = arith.addf [[SOFTMAX]], [[SOFTMAX]] {ttg.partition = 0 : i32}
+    %x = arith.addf %softmax, %softmax : tensor<256x64xf32, #blocked>
+    // CHECK-NEXT: [[ACC_X:%.*]] = arith.addf %{{.*}}, [[X]] {ttg.partition = 3 : i32}
+    %acc_x = arith.addf %acc, %x : tensor<256x64xf32, #blocked>
+    %e = "sum"(%acc_x) : (tensor<256x64xf32, #blocked>) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %next_e_i = arith.addf %e_i, %e : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+
+    %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #load_blocked>
+    %V_shared = ttg.local_alloc %V : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
+    %P = arith.truncf %softmax : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked>
+
+    %P_tmem = ttng.tmem_alloc %P : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #tmem_lhs, #ttng.tensor_memory>
+    %acc_tmem, %acc_tok = ttng.tmem_alloc %acc_corrected : (tensor<256x64xf32, #blocked>) -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %PV_mma_tok = ttng.tc_gen5_mma %P_tmem, %V_shared, %acc_tmem[%acc_tok], %true, %true : !ttg.memdesc<256x64xf16, #tmem_lhs, #ttng.tensor_memory>, !ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>
+    %O, %O_tok = ttng.tmem_load %acc_tmem[%PV_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
+
+    scf.yield %next_l_i, %O, %row_max, %next_e_i : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  } {tt.warp_specialize}
+
+  "use"(%loop_outs#0, %loop_outs#1, %loop_outs#2) : (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> ()
+
+  tt.return
+}
+
+}

test/TritonGPU/pipeline-assign-latencies.mlir

Lines changed: 72 additions & 0 deletions
@@ -1014,3 +1014,75 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   tt.return %2 : tensor<128x128xf16, #blocked1>
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#load_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared_T = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+
+#smem = #ttg.shared_memory
+#tmem_acc = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+#tmem_lhs = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = false>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
+
+// CHECK-LABEL: @attention_forward
+tt.func public @attention_forward(
+  %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>,
+  %K_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
+  %V_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
+  %qk_scale: f32,
+  %n_tiles: i32
+) {
+  %true = arith.constant true
+  %false = arith.constant false
+  %c0_i32 = arith.constant 0 : i32
+  %c64_i32 = arith.constant 64 : i32
+
+  %neg_inf = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  %zero = arith.constant dense<0.0> : tensor<256x64xf32, #blocked>
+  %one = arith.constant dense<1.0> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+
+  %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
+
+  %loop_outs:3 = scf.for %i = %c0_i32 to %n_tiles step %c64_i32 iter_args(
+    %l_i = %one,
+    %acc = %zero,
+    %m_i = %neg_inf
+  ) -> (
+    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>,
+    tensor<256x64xf32, #blocked>,
+    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  ) : i32 {
+    // CHECK: descriptor_load {{.*}} {tt.latency = 2 : i32}
+    %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #load_blocked>
+    %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
+    %K_trans = ttg.memdesc_trans %K_shared {order = array<i32: 1, 0>} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem>
+    // CHECK: tc_gen5_mma {{.*}} {tt.latency = 2 : i32, tt.self_latency = 1 : i32}
+    %QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_trans, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>
+    %QK, %QK_load_tok = ttng.tmem_load %QK_tmem[%QK_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
+
+    %alpha_1, %P, %next_l_i, %row_max = "softmax_work"(%QK, %l_i, %m_i, %qk_scale) : (tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, f32) -> (tensor<256x64xf32, #blocked>, tensor<256x64xf16, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
+
+    %acc_corrected = arith.mulf %acc, %alpha_1 : tensor<256x64xf32, #blocked>
+
+    // CHECK: descriptor_load {{.*}} {tt.latency = 2 : i32}
+    %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #load_blocked>
+    %V_shared = ttg.local_alloc %V : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
+    %P_tmem = ttng.tmem_alloc %P : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #tmem_lhs, #ttng.tensor_memory>
+    %acc_tmem, %acc_tok = ttng.tmem_alloc %acc_corrected : (tensor<256x64xf32, #blocked>) -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    // CHECK: tc_gen5_mma {{.*}} {tt.self_latency = 1 : i32}
+    %PV_mma_tok = ttng.tc_gen5_mma %P_tmem, %V_shared, %acc_tmem[%acc_tok], %true, %true : !ttg.memdesc<256x64xf16, #tmem_lhs, #ttng.tensor_memory>, !ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>
+    %O, %O_tok = ttng.tmem_load %acc_tmem[%PV_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
+
+    scf.yield %next_l_i, %O, %row_max : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  } {tt.warp_specialize}
+
+  "use"(%loop_outs#0, %loop_outs#1, %loop_outs#2) : (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> ()
+
+  tt.return
+}
+
+}
