Commit 0315d72

[AMD] DCE/canonicalize true epilogue conditionals (#6314)
This PR adds a pattern that folds "true" `arith.cmpi` operations to `arith.constant true`; e.g.,

```mlir
%c0 = arith.constant 0 : i32
%c1024_i32 = arith.constant 1024 : i32
%cmpsge = arith.cmpi sge, %c1024_i32, %c0 : i32
```

becomes (after DCE)

```mlir
%cmpsge = arith.constant true
```

The specific use case is "unguarding" the epilogue in pipelined loops (e.g., as produced by `tritonamdgpu-stream-pipeline`). So, e.g.,

```mlir
tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) -> tensor<128x128xf32, #mma> {
  ...
  %20:6 = scf.for ... {
    scf.yield ...
  }
  ...
  %27 = arith.cmpi sge, %26, %c1 : index
  %31 = scf.if %27 -> (tensor<128x128xf32, #mma>) {
    %33 = tt.dot %28, %30, %20#2
    scf.yield %33 : tensor<128x128xf32, #mma>
  } else {
    scf.yield %20#2 : tensor<128x128xf32, #mma>
  }
  %32 = arith.select %27, %31, %20#2 : tensor<128x128xf32, #mma>
  ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
  ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
  tt.return %32 : tensor<128x128xf32, #mma>
}
```

becomes

```mlir
tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) -> tensor<128x128xf32, #mma> {
  ...
  %20:6 = scf.for ... {
    scf.yield ...
  }
  %21 = ttg.local_load %20#4
  %22 = ttg.local_load %20#5
  %23 = arith.mulf %22, %cst
  %24 = tt.dot %21, %23, %20#2
  ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
  ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
  tt.return %24 : tensor<128x128xf32, #mma>
}
```

Notice that both the `scf.if` and the `arith.select` are canonicalized away.

**Note**: this _usually_ requires the use of `tl.assume` to hint/constrain the operands of the `arith.cmpi`; specifically, with respect to the original loop bounds, something like `%stop // %step >= 1` (or whatever the arithmetic on the loop bounds needs to be...).
1 parent 0e3492c commit 0315d72
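For concreteness, here is a minimal, hypothetical Triton sketch of the kind of `tl.assume` hint described in the note above. The kernel, its arguments, and the exact `K // BLOCK_K >= 1` constraint are illustrative assumptions, not code from this PR:

```python
# Hypothetical sketch (not from this PR): hint that the K loop runs at least
# once, so the stream pipeliner's epilogue guard (an `arith.cmpi sge` on the
# trip count) can fold to true and be DCE'd, as described above.
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    # i.e. `%stop // %step >= 1` in terms of the note above
    tl.assume(K // BLOCK_K >= 1)

    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Row-major A (BLOCK_M x K) and B (K x N) tiles; strides are illustrative.
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_k[:, None] * N + offs_n[None, :]

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K
        b_ptrs += BLOCK_K * N

    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    tl.store(c_ptrs, acc)
```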

File tree

10 files changed: +398, -6 lines


bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -34,6 +34,7 @@ void registerTestAlignmentPass();
 void registerTestAllocationPass();
 void registerTestMembarPass();
 void registerTestTritonAMDGPURangeAnalysis();
+void registerTestTritonAMDGPUFoldTrueCmpIOp();
 } // namespace test
 } // namespace mlir
 
@@ -47,6 +48,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestAllocationPass();
   mlir::test::registerTestMembarPass();
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
+  mlir::test::registerTestTritonAMDGPUFoldTrueCmpIOp();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::gpu::registerAllocateSharedMemoryPass();
   mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
```
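(For orientation: `registerTestTritonAMDGPUFoldTrueCmpIOp()` exposes the new folding pattern as a standalone test pass, which the new lit test further down drives via `triton-opt -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize`, per its RUN line.)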

python/test/unit/language/test_core.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -4666,7 +4666,10 @@ def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
     if is_interpreter():
         return
 
-    assert 'llvm.assume' in pgm.asm['llir']
+    assert 'llvm.intr.assume' in pgm.asm['ttgir']
+    # stream pipeliner on AMD folds true cmpi ops to %true (Which llvm itself then dces)
+    if not is_hip():
+        assert 'llvm.assume' in pgm.asm['llir']
 
 
 # ---------------
```
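As a rough illustration of what that assertion exercises, compiling a kernel that calls `tl.assume` and inspecting the compiled artifact might look like the sketch below. The kernel name, its body, and the launch are hypothetical, not the actual `test_core.py` test:

```python
# Hypothetical sketch, not the real test_core.py test: compile a kernel that
# uses tl.assume and look for the assume in the compiler's intermediate dumps.
import torch
import triton
import triton.language as tl


@triton.jit
def _assume_demo(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    tl.assume(pid >= 0)  # surfaces as `llvm.intr.assume` at the TTGIR level
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    tl.store(out_ptr + offs, offs.to(tl.float32), mask=offs < N)


out = torch.empty(1024, device="cuda", dtype=torch.float32)  # "cuda" also covers ROCm builds
pgm = _assume_demo[(1,)](out, N=1024, BLOCK_N=1024)

# The assume is still visible in TTGIR on every backend ...
assert "llvm.intr.assume" in pgm.asm["ttgir"]
# ... but on AMD the stream pipeliner folds true cmpi conditions to %true,
# after which LLVM itself DCEs the llvm.assume, so it may be absent from
# pgm.asm["llir"] on HIP (hence the is_hip() guard in the diff above).
```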
Lines changed: 158 additions & 0 deletions
```mlir
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize | FileCheck %s

module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @cmpsle(%arg0: !tt.ptr<f32>) -> i1 {
    %c0 = arith.constant 0 : i32
    %c1024_i32 = arith.constant 1024 : i32
    %cmpsle = arith.cmpi sle, %c0, %c1024_i32 : i32
    tt.return %cmpsle : i1
  }
}

// CHECK-LABEL: tt.func @cmpsle(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> i1 {
// CHECK: %[[VAL_1:.*]] = arith.constant true
// CHECK: tt.return %[[VAL_1]] : i1
// CHECK: }

// -----

module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @assumepid(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    %c0 = arith.constant 0 : i32
    %c1024_i32 = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %cmpsle = arith.cmpi sle, %pid, %c1024_i32 : i32
    llvm.intr.assume %cmpsle : i1
    %cmpsge = arith.cmpi sge, %pid, %c0 : i32
    llvm.intr.assume %cmpsge : i1
    %1 = arith.muli %pid, %c1024_i32 : i32
    %2 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
    %3 = tt.splat %2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %4 = tt.load %3 : tensor<1024x!tt.ptr<f32>>
    tt.return %4 : tensor<1024xf32>
  }
}

// CHECK-LABEL: tt.func @assumepid(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant true
// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: llvm.intr.assume %[[VAL_1]] : i1
// CHECK: llvm.intr.assume %[[VAL_1]] : i1
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_6:.*]] = tt.splat %[[VAL_5]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_7]] : tensor<1024xf32>
// CHECK: }

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) -> tensor<128x128xf32, #mma> {
    %c-1 = arith.constant -1 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %true = arith.constant true
    %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked>
    %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked>
    %0 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #blocked1>
    %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
    %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1>
    %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
    %5 = tt.splat %arg4 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
    %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
    %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
    %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
    %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
    %12 = arith.cmpi slt, %arg0, %arg1 : index
    %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
    %14 = tt.load %4, %13 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr<f16>, #blocked1>
    %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked>
    %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
    %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
    ttg.local_store %14, %17 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
    %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
    ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
    %19 = arith.subi %arg1, %arg2 : index
    %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) {
      %33 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
      %34 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
      llvm.intr.assume %true : i1
      %35 = tt.load %33 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr<f16>, #blocked1>
      %36 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %37 = tt.load %34 : tensor<32x128x!tt.ptr<f16>, #blocked>
      %38 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %39 = arith.mulf %38, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %40 = tt.dot %36, %39, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      %41 = arith.addi %arg9, %c1_i32 : i32
      %42 = arith.cmpi slt, %41, %c1_i32 : i32
      %43 = arith.select %42, %41, %c0_i32 : i32
      %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
      ttg.local_store %35, %44 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
      %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
      ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
      scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
    }
    %21 = arith.cmpi slt, %arg2, %c0 : index
    %22 = arith.select %21, %c1, %c-1 : index
    %23 = arith.subi %arg1, %arg0 : index
    %24 = arith.addi %23, %arg2 : index
    %25 = arith.addi %24, %22 : index
    %26 = arith.divsi %25, %arg2 : index
    %28 = ttg.local_load %20#4 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
    %29 = ttg.local_load %20#5 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %30 = arith.mulf %29, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %27 = arith.cmpi sge, %26, %c1 : index
    llvm.intr.assume %27 : i1
    %31 = scf.if %27 -> (tensor<128x128xf32, #mma>) {
      %33 = tt.dot %28, %30, %20#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      scf.yield %33 : tensor<128x128xf32, #mma>
    } else {
      scf.yield %20#2 : tensor<128x128xf32, #mma>
    }
    %32 = arith.select %27, %31, %20#2 : tensor<128x128xf32, #mma>
    ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
    ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
    tt.return %32 : tensor<128x128xf32, #mma>
  }
}

// CHECK: #[[$ATTR_2:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>
// CHECK: #[[$ATTR_3:.+]] = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
// CHECK: #[[$ATTR_4:.+]] = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory

// CHECK-LABEL: tt.func @assume_matmul(
// CHECK: %[[VAL_7:.*]] = arith.constant true
// CHECK: %[[VAL_8:.*]] = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>>
// CHECK: %[[VAL_23:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable>
// CHECK: %[[VAL_24:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable>
// CHECK: %[[VAL_33:.*]]:6 = scf.for
// CHECK: scf.yield
// CHECK: }
// CHECK-NEXT: %[[VAL_54:.*]] = ttg.local_load %[[VAL_55:.*]]#4 : !ttg.memdesc<128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$ATTR_2]], kWidth = 2}>>
// CHECK-NEXT: %[[VAL_56:.*]] = ttg.local_load %[[VAL_55]]#5 : !ttg.memdesc<32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>>
// CHECK-NEXT: %[[VAL_57:.*]] = arith.mulf %[[VAL_56]], %[[VAL_8]] : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>>
// CHECK-NEXT: llvm.intr.assume %[[VAL_7]] : i1
// CHECK-NEXT: %[[VAL_58:.*]] = tt.dot %[[VAL_54]], %[[VAL_57]], %[[VAL_55]]#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$ATTR_2]], kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>> -> tensor<128x128xf32, #[[$ATTR_2]]>
// CHECK-NEXT: ttg.local_dealloc %[[VAL_23]] : !ttg.memdesc<1x128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable>
// CHECK-NEXT: ttg.local_dealloc %[[VAL_24]] : !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable>
// CHECK-NEXT: tt.return %[[VAL_58]] : tensor<128x128xf32, #[[$ATTR_2]]>
// CHECK-NEXT: }
```
