Skip to content

Commit ff383bf

Browse files
committed
add tests
1 parent 3c41aa1 commit ff383bf

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

test/TritonGPU/amd/amd-fold-true-cmpi.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,55 @@
11
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize | FileCheck %s
22

3+
module attributes {"ttg.num-warps" = 4 : i32} {
4+
tt.func @cmpsle(%arg0: !tt.ptr<f32>) -> i1 {
5+
%c0 = arith.constant 0 : i32
6+
%c1024_i32 = arith.constant 1024 : i32
7+
%cmpsle = arith.cmpi sle, %c0, %c1024_i32 : i32
8+
tt.return %cmpsle: i1
9+
}
10+
}
11+
12+
// CHECK-LABEL: tt.func @cmpsle(
13+
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> i1 {
14+
// CHECK: %[[VAL_1:.*]] = arith.constant true
15+
// CHECK: tt.return %[[VAL_1]] : i1
16+
// CHECK: }
17+
18+
// -----
19+
20+
module attributes {"ttg.num-warps" = 4 : i32} {
21+
tt.func @assumepid(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
22+
%c0 = arith.constant 0 : i32
23+
%c1024_i32 = arith.constant 1024 : i32
24+
%pid = tt.get_program_id x : i32
25+
%cmpsle = arith.cmpi sle, %pid, %c1024_i32 : i32
26+
llvm.intr.assume %cmpsle : i1
27+
%cmpsge = arith.cmpi sge, %pid, %c0 : i32
28+
llvm.intr.assume %cmpsge : i1
29+
%1 = arith.muli %pid, %c1024_i32 : i32
30+
%2 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
31+
%3 = tt.splat %2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
32+
%4 = tt.load %3 : tensor<1024x!tt.ptr<f32>>
33+
tt.return %4 : tensor<1024xf32>
34+
}
35+
}
36+
37+
// CHECK-LABEL: tt.func @assumepid(
38+
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
39+
// CHECK: %[[VAL_1:.*]] = arith.constant true
40+
// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32
41+
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
42+
// CHECK: llvm.intr.assume %[[VAL_1]] : i1
43+
// CHECK: llvm.intr.assume %[[VAL_1]] : i1
44+
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
45+
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr<f32>, i32
46+
// CHECK: %[[VAL_6:.*]] = tt.splat %[[VAL_5]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
47+
// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : tensor<1024x!tt.ptr<f32>>
48+
// CHECK: tt.return %[[VAL_7]] : tensor<1024xf32>
49+
// CHECK: }
50+
51+
// -----
52+
353
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
454
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
555
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>

third_party/amd/include/Analysis/RangeAnalysis.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ collectRanges(const DataFlowSolver &solver, ValueRange values);
122122

123123
bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp);
124124

125-
void populateFoldTrueCmpIOpPatterns(
126-
RewritePatternSet &patterns, std::shared_ptr<DataFlowSolver> solver);
125+
bool isEmptyInitializedRange(ConstantIntRanges rv);
126+
127+
void populateFoldTrueCmpIOpPatterns(RewritePatternSet &patterns,
128+
std::shared_ptr<DataFlowSolver> solver);
127129

128130
} // namespace mlir::triton::AMD
129131

0 commit comments

Comments
 (0)