|
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize | FileCheck %s
2 | 2 |
|
// Checks that a signed-less-or-equal compare of two integer constants
// (0 sle 1024, which is statically true) is folded to `arith.constant true`
// and returned directly; the two original constants become dead and are
// removed by -canonicalize.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @cmpsle(%arg0: !tt.ptr<f32>) -> i1 {
    %c0 = arith.constant 0 : i32
    %c1024_i32 = arith.constant 1024 : i32
    %cmpsle = arith.cmpi sle, %c0, %c1024_i32 : i32
    tt.return %cmpsle: i1
  }
}

// CHECK-LABEL: tt.func @cmpsle(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> i1 {
// CHECK: %[[VAL_1:.*]] = arith.constant true
// CHECK: tt.return %[[VAL_1]] : i1
// CHECK: }
| 17 | + |
| 18 | +// ----- |
| 19 | + |
// Checks that compares on the program id that feed `llvm.intr.assume`
// (pid sle 1024 and pid sge 0 here) are folded to a single shared
// `arith.constant true` — per the CHECK lines, both assume ops end up
// consuming %[[VAL_1]] — while the surrounding pointer arithmetic
// (muli/addptr/splat) and the tt.load are left untouched.
// NOTE(review): the fold appears to rely on pid-range knowledge from the
// target pass rather than on the assumes themselves — confirm against the
// pass implementation.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @assumepid(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    %c0 = arith.constant 0 : i32
    %c1024_i32 = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %cmpsle = arith.cmpi sle, %pid, %c1024_i32 : i32
    llvm.intr.assume %cmpsle : i1
    %cmpsge = arith.cmpi sge, %pid, %c0 : i32
    llvm.intr.assume %cmpsge : i1
    %1 = arith.muli %pid, %c1024_i32 : i32
    %2 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
    %3 = tt.splat %2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %4 = tt.load %3 : tensor<1024x!tt.ptr<f32>>
    tt.return %4 : tensor<1024xf32>
  }
}

// CHECK-LABEL: tt.func @assumepid(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant true
// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: llvm.intr.assume %[[VAL_1]] : i1
// CHECK: llvm.intr.assume %[[VAL_1]] : i1
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_6:.*]] = tt.splat %[[VAL_5]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_7]] : tensor<1024xf32>
// CHECK: }
| 50 | + |
| 51 | +// ----- |
| 52 | + |
3 | 53 | #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> |
4 | 54 | #blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> |
5 | 55 | #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> |
|
0 commit comments