@@ -36,13 +36,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
3636 %8 = tt.addptr %7 , %4 : tensor <256 x!tt.ptr <f32 >, #blocked0 >, tensor <256 xi32 , #blocked0 >
3737 // COMMON: buffer_load %arg0[%[[offset]]]
3838 %9 = tt.load %6 : tensor <256 x!tt.ptr <f32 >, #blocked0 >
39- // COMMON: buffer_load %arg1[%[[offset]]]
39+ // Note: offset = pid * 256 + arange(0, 256); the byte offset "offset * sizeof(i32)" may not fall within the 2G range.
40+ // COMMON-NOT: buffer_load %arg1[%[[offset]]]
4041 %10 = tt.load %8 : tensor <256 x!tt.ptr <f32 >, #blocked0 >
4142 // COMMON: %[[data:.*]] = arith.addf
4243 %11 = arith.addf %9 , %10 : tensor <256 xf32 , #blocked0 >
4344 %12 = tt.splat %arg2 : !tt.ptr <f32 > -> tensor <256 x!tt.ptr <f32 >, #blocked0 >
4445 %13 = tt.addptr %12 , %4 : tensor <256 x!tt.ptr <f32 >, #blocked0 >, tensor <256 xi32 , #blocked0 >
45- // COMMON: buffer_store %[[data]], %arg2[%[[offset]]]
46+ // Note: see the explanation above
47+ // COMMON-NOT: buffer_store %[[data]], %arg2[%[[offset]]]
4648 tt.store %13 , %11 : tensor <256 x!tt.ptr <f32 >, #blocked0 >
4749 tt.return
4850 }
@@ -70,7 +72,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
7072 %5 = tt.addptr %arg0 , %1 : !tt.ptr <f32 >, i32
7173 %8 = tt.splat %5 : !tt.ptr <f32 > -> tensor <1024 x!tt.ptr <f32 >, #blocked >
7274 %9 = tt.addptr %8 , %4 : tensor <1024 x!tt.ptr <f32 >, #blocked >, tensor <1024 xi32 , #blocked >
73- // COMMON: buffer_load %[[scalar_ptr]][%[[offset]]]
75+ // Note: the base "scalar_ptr" points to arg0, which is a large tensor.
76+ // The offset is "%sub + arange(0,1024)", where "%sub=pid*1024-128".
77+ // We can prove "offset > 0", but cannot prove that the byte offset is < 2G.
78+ // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]]
7479 %10 = tt.load %9 : tensor <1024 x!tt.ptr <f32 >, #blocked >
7580 tt.return %10 : tensor <1024 xf32 , #blocked >
7681 }
@@ -122,7 +127,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
122127 // COMMON: %[[offset_32_bit:.*]] = arith.trunci
123128 %narrow4 = arith.trunci %4 : tensor <1024 xi64 , #blocked > to tensor <1024 xi32 , #blocked >
124129 %9 = tt.addptr %8 , %narrow4 : tensor <1024 x!tt.ptr <f32 >, #blocked >, tensor <1024 xi32 , #blocked >
125- // COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
130+ // Note: the base is arg0, which is a large tensor; the offset=int(long(pid*1024) * long(arange(0, 1024)))
131+ // offset is in [0, i32-max].
132+ // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
126133 %10 = tt.load %9 : tensor <1024 x!tt.ptr <f32 >, #blocked >
127134 tt.return %10 : tensor <1024 xf32 , #blocked >
128135 }
@@ -555,7 +562,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
555562 %5 = tt.addptr %arg0 , %1 : !tt.ptr <f32 >, i32
556563 %6 = tt.splat %5 : !tt.ptr <f32 > -> tensor <1024 x!tt.ptr <f32 >, #blocked >
557564 %7 = tt.addptr %6 , %4 : tensor <1024 x!tt.ptr <f32 >, #blocked >, tensor <1024 xi32 , #blocked >
558- // COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
565+ // Note: the large tensor is accessed; the offset is in the range [0, smax].
566+ // Without tl.assume the range would be [-128, smax].
567+ // COMMON-NOT: amdgpu.buffer_atomic_rmw
559568 %8 = tt.atomic_rmw fadd , acq_rel , gpu , %7 , %arg1 : (tensor <1024 x!tt.ptr <f32 >, #blocked >, tensor <1024 xf32 , #blocked >) -> tensor <1024 xf32 , #blocked >
560569 tt.return %8 : tensor <1024 xf32 , #blocked >
561570 }
0 commit comments