Skip to content

Commit 3d0e00c

Browse files
[AMD] Enable buffer ops for i64 offsets in ConvertToBufferOps (#9619)
Previously, `ConvertToBufferOps` unconditionally rejected any load/store with 64-bit offsets. This prevented buffer_load/buffer_store from being used for kernels (e.g. flex attention) that use 64-bit pointer arithmetic. This patch allows 64-bit offsets through when they can be proved safe, truncates them to 32-bit, and uses the faster buffer instructions instead. --------- Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
1 parent 60db4a5 commit 3d0e00c

File tree

3 files changed

+193
-104
lines changed

3 files changed

+193
-104
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops="arch-generation-name=gfx942" | FileCheck %s

// Test that tt.load with i64 offsets derived from provably bounded non-negative
// expressions is converted to amdgpu.buffer_load with an arith.trunci from i64 to i32.
// The offset here is extsi(make_range(0, 256)) * 1024, which is provably
// non-negative and fits in 32 bits, so the pass may safely truncate it.

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @load_i64_offset_bounded
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @load_i64_offset_bounded(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) -> tensor<256xf32, #blocked> {
    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
    %range_ext = arith.extsi %range : tensor<256xi32, #blocked> to tensor<256xi64, #blocked>
    %c1024_i64 = arith.constant 1024 : i64
    %stride = tt.splat %c1024_i64 : i64 -> tensor<256xi64, #blocked>
    %offset = arith.muli %range_ext, %stride : tensor<256xi64, #blocked>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked>
    %ptr = tt.addptr %base, %offset : tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xi64, #blocked>
    // The i64 offset must be narrowed to i32 before feeding the buffer op.
    // CHECK: arith.trunci
    // CHECK-SAME: tensor<256xi64,
    // CHECK-SAME: to tensor<256xi32,
    // CHECK: amdgpu.buffer_load
    %val = tt.load %ptr : tensor<256x!tt.ptr<f32>, #blocked>
    tt.return %val : tensor<256xf32, #blocked>
  }
}
// -----

// Test that i64 offset loads are NOT converted when the offset may be negative.
// %arg1 is an arbitrary runtime i64 with no known bounds, so truncating it to
// i32 would be unsound and the plain tt.load must be kept.

#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @load_i64_offset_possibly_negative
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @load_i64_offset_possibly_negative(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i64) -> tensor<256xf32, #blocked1> {
    %splat_off = tt.splat %arg1 : i64 -> tensor<256xi64, #blocked1>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked1>
    %ptr = tt.addptr %base, %splat_off : tensor<256x!tt.ptr<f32>, #blocked1>, tensor<256xi64, #blocked1>
    // NOTE: the pattern must spell the full op name "amdgpu.buffer_load";
    // the truncated "amdg.buffer_load" is not a substring of it, which would
    // make this CHECK-NOT pass vacuously even if the conversion fired.
    // CHECK-NOT: amdgpu.buffer_load
    // CHECK: tt.load
    %val = tt.load %ptr : tensor<256x!tt.ptr<f32>, #blocked1>
    tt.return %val : tensor<256xf32, #blocked1>
  }
}
// -----

// Test that i64 offset stores are converted with trunci when offset is bounded.
// Mirrors @load_i64_offset_bounded but exercises the tt.store ->
// amdgpu.buffer_store path with a 512-element stride.

#blocked2 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @store_i64_offset_bounded
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @store_i64_offset_bounded(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %data: tensor<256xf32, #blocked2>) {
    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
    %range_ext = arith.extsi %range : tensor<256xi32, #blocked2> to tensor<256xi64, #blocked2>
    %c512_i64 = arith.constant 512 : i64
    %stride = tt.splat %c512_i64 : i64 -> tensor<256xi64, #blocked2>
    %offset = arith.muli %range_ext, %stride : tensor<256xi64, #blocked2>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked2>
    %ptr = tt.addptr %base, %offset : tensor<256x!tt.ptr<f32>, #blocked2>, tensor<256xi64, #blocked2>
    // The i64 offset must be narrowed to i32 before feeding the buffer op.
    // CHECK: arith.trunci
    // CHECK-SAME: tensor<256xi64,
    // CHECK-SAME: to tensor<256xi32,
    // CHECK: amdgpu.buffer_store
    tt.store %ptr, %data : tensor<256x!tt.ptr<f32>, #blocked2>
    tt.return
  }
}
// -----

// Test that i64 offset loads with tt.pointer_range=32 attribute are converted.
// The attribute asserts the pointer arithmetic stays within a 32-bit range, so
// even an unbounded runtime i64 offset (%arg1) may be truncated safely.

#blocked3 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>

// CHECK-LABEL: @load_i64_offset_pointer_range_32
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @load_i64_offset_pointer_range_32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i64) -> tensor<256xf32, #blocked3> {
    %splat_off = tt.splat %arg1 : i64 -> tensor<256xi64, #blocked3>
    %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked3>
    %ptr = tt.addptr %base, %splat_off : tensor<256x!tt.ptr<f32>, #blocked3>, tensor<256xi64, #blocked3>
    // The i64 offset must be narrowed to i32 before feeding the buffer op.
    // CHECK: arith.trunci
    // CHECK-SAME: tensor<256xi64,
    // CHECK-SAME: to tensor<256xi32,
    // CHECK: amdgpu.buffer_load
    %val = tt.load %ptr : tensor<256x!tt.ptr<f32>, #blocked3>
    tt.return %val : tensor<256xf32, #blocked3>
  }
}

0 commit comments

Comments
 (0)