// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops="arch-generation-name=gfx942" | FileCheck %s

// Test that tt.load with i64 offsets derived from provably bounded non-negative
// expressions is converted to amdgpu.buffer_load with an arith.trunci from i64 to i32.

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @load_i64_offset_bounded
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @load_i64_offset_bounded(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) -> tensor<256xf32, #blocked> {
    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
    %range_ext = arith.extsi %range : tensor<256xi32, #blocked> to tensor<256xi64, #blocked>
    %c1024_i64 = arith.constant 1024 : i64
    %stride = tt.splat %c1024_i64 : i64 -> tensor<256xi64, #blocked>
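    // The largest offset below is 255 * 1024 = 261120, which fits in i32, so
    // the pass can prove the inserted trunci is lossless.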
    %offset = arith.muli %range_ext, %stride : tensor<256xi64, #blocked>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked>
    %ptr = tt.addptr %base, %offset : tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xi64, #blocked>
    // CHECK: arith.trunci
    // CHECK-SAME: tensor<256xi64,
    // CHECK-SAME: to tensor<256xi32,
    // CHECK: amdgpu.buffer_load
    %val = tt.load %ptr : tensor<256x!tt.ptr<f32>, #blocked>
    tt.return %val : tensor<256xf32, #blocked>
  }
}

// -----

// Test that i64 offset loads are NOT converted when the offset may be negative.

#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @load_i64_offset_possibly_negative
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @load_i64_offset_possibly_negative(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i64) -> tensor<256xf32, #blocked1> {
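    // %arg1 is an arbitrary runtime value: nothing bounds it or proves it
    // non-negative, so the conversion must be rejected.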
    %splat_off = tt.splat %arg1 : i64 -> tensor<256xi64, #blocked1>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked1>
    %ptr = tt.addptr %base, %splat_off : tensor<256x!tt.ptr<f32>, #blocked1>, tensor<256xi64, #blocked1>
    // CHECK-NOT: amdgpu.buffer_load
    // CHECK: tt.load
    %val = tt.load %ptr : tensor<256x!tt.ptr<f32>, #blocked1>
    tt.return %val : tensor<256xf32, #blocked1>
  }
}

// -----

// Test that i64 offset stores are converted with trunci when the offset is bounded.

#blocked2 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @store_i64_offset_bounded
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @store_i64_offset_bounded(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %data: tensor<256xf32, #blocked2>) {
    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
    %range_ext = arith.extsi %range : tensor<256xi32, #blocked2> to tensor<256xi64, #blocked2>
    %c512_i64 = arith.constant 512 : i64
    %stride = tt.splat %c512_i64 : i64 -> tensor<256xi64, #blocked2>
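    // The largest offset below is 255 * 512 = 130560, which fits in i32.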
    %offset = arith.muli %range_ext, %stride : tensor<256xi64, #blocked2>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked2>
    %ptr = tt.addptr %base, %offset : tensor<256x!tt.ptr<f32>, #blocked2>, tensor<256xi64, #blocked2>
    // CHECK: arith.trunci
    // CHECK-SAME: tensor<256xi64,
    // CHECK-SAME: to tensor<256xi32,
    // CHECK: amdgpu.buffer_store
    tt.store %ptr, %data : tensor<256x!tt.ptr<f32>, #blocked2>
    tt.return
  }
}

// -----

// Test that i64 offset loads are converted when the base pointer carries the
// tt.pointer_range = 32 attribute, even though the offset value is otherwise unknown.

#blocked3 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @load_i64_offset_pointer_range_32
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @load_i64_offset_pointer_range_32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i64) -> tensor<256xf32, #blocked3> {
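    // Same unknown %arg1 as the rejected case above, but tt.pointer_range = 32
    // promises that offsets from this pointer fit in 32 bits, so truncation is
    // safe even without a provable bound on the value itself.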
    %splat_off = tt.splat %arg1 : i64 -> tensor<256xi64, #blocked3>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked3>
    %ptr = tt.addptr %base, %splat_off : tensor<256x!tt.ptr<f32>, #blocked3>, tensor<256xi64, #blocked3>
    // CHECK: arith.trunci
    // CHECK-SAME: tensor<256xi64,
    // CHECK-SAME: to tensor<256xi32,
    // CHECK: amdgpu.buffer_load
    %val = tt.load %ptr : tensor<256x!tt.ptr<f32>, #blocked3>
    tt.return %val : tensor<256xf32, #blocked3>
  }
}