Skip to content

Commit f8d5d1e

Browse files
authored
[AMD] Use pointee type for buffer op alignment in AxisAnalysis (#6145)
For buffer ops we have to compute the contiguity based on the offsets, which are `i32`, but we need to get the alignment/divisibility based on the pointer type in order to get the correct vector size.
1 parent 1239887 commit f8d5d1e

File tree

5 files changed

+112
-12
lines changed

5 files changed

+112
-12
lines changed

include/triton/Analysis/AxisInfo.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,19 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
205205
unsigned getContiguity(Value value);
206206
unsigned getAlignment(Value value);
207207

208+
// Overloads of the above methods but have separated elementBitWidth to
209+
// calculate the contiguity. These are useful for computing axis info when
210+
// lowering to hardware intrinsics that require a scalar/warp-uniform base ptr
211+
// with separate per lane offsets like AMD buffer operations.
212+
//
213+
// As a concrete example, instead of a single tensor<128x64x!tt.ptr<f16>>
214+
// value, now we have two separate values: !tt.ptr<f16> for the base pointer
215+
// and tensor<128x64xi32> for the offset. For such cases, we want to compute
216+
// the contiguity on the offsets but use the pointee element type bit width
217+
// instead of the offset element type bit width for alignment
218+
unsigned getContiguity(Value offsetsValue, unsigned elementBitWidth);
219+
unsigned getAlignment(Value offsetsValue, unsigned elementBitWidth);
220+
208221
unsigned getMaskAlignment(Value mask);
209222

210223
private:

lib/Analysis/AxisInfo.cpp

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,13 +1222,24 @@ unsigned ModuleAxisInfoAnalysis::getContiguity(Value value) {
12221222
auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
12231223
if (!tensorTy)
12241224
return 1;
1225+
auto elemTy = tensorTy.getElementType();
1226+
// Get the pointee type if we have a tensor of ptrs to compute contiguity for
1227+
if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
1228+
elemTy = ptrTy.getPointeeType();
1229+
}
1230+
return getContiguity(value, elemTy.getIntOrFloatBitWidth());
1231+
}
1232+
1233+
unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue,
1234+
unsigned elementBitWidth) {
12251235
// FIXME: This is not as good as it could be, as we don't need to restrict
12261236
// the analysis to one dimension. We should determine contiguity on the
12271237
// flattenOuts() layout
1238+
auto tensorTy = cast<RankedTensorType>(offsetsValue.getType());
12281239
auto linAttr =
12291240
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
12301241
auto order = linAttr.getOrder();
1231-
unsigned align = getAlignment(value);
1242+
unsigned align = getAlignment(offsetsValue, elementBitWidth);
12321243

12331244
auto uniqueContigPerThread = linAttr.getContigPerThread();
12341245
assert(order[0] < uniqueContigPerThread.size() &&
@@ -1244,7 +1255,19 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value value) {
12441255
auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
12451256
if (!tensorTy)
12461257
return 1;
1247-
auto *axisInfo = getAxisInfo(value);
1258+
1259+
auto elemTy = tensorTy.getElementType();
1260+
// Get the pointee type if we have a tensor of ptrs to compute contiguity for
1261+
if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
1262+
elemTy = ptrTy.getPointeeType();
1263+
}
1264+
return getAlignment(value, elemTy.getIntOrFloatBitWidth());
1265+
}
1266+
1267+
unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue,
1268+
unsigned elementBitWidth) {
1269+
auto tensorTy = cast<RankedTensorType>(offsetsValue.getType());
1270+
auto *axisInfo = getAxisInfo(offsetsValue);
12481271
if (!axisInfo)
12491272
return 1;
12501273
auto linAttr =
@@ -1253,18 +1276,12 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value value) {
12531276
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
12541277
auto maxContig = axisInfo->getContiguity(order[0]);
12551278

1256-
auto elemTy = tensorTy.getElementType();
1257-
// Get the pointee type if we have a tensor of ptrs to compute contiguity for
1258-
if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
1259-
elemTy = ptrTy.getPointeeType();
1260-
}
1261-
auto elemNumBits = elemTy.getIntOrFloatBitWidth();
1262-
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
1279+
auto elemNumBytes = std::max<unsigned>(elementBitWidth / 8, 1);
12631280
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
12641281
unsigned alignment = std::min(maxMultiple, maxContig);
12651282
LDBG("getAlignment order[0] "
12661283
<< order[0] << " maxMultipleBytes = " << maxMultipleBytes
1267-
<< " maxContig = " << maxContig << " elemNumBits = " << elemNumBits
1284+
<< " maxContig = " << maxContig << " elemNumBits = " << elementBitWidth
12681285
<< " maxMultiple = " << maxMultiple << " alignment " << alignment);
12691286
LLVM_DEBUG({
12701287
std::string axisStr;

python/test/unit/language/test_core.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7213,6 +7213,56 @@ def aliasing_kernel(buffer, buffer2):
72137213
assert buffer[0] == 1
72147214

72157215

7216+
@pytest.mark.interpreter
7217+
@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"])
7218+
def test_strided_load(dtype, device):
7219+
7220+
@triton.jit
7221+
def take_every_second_element(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
7222+
strided_offsets = tl.arange(0, BLOCK_SIZE) * 2
7223+
linear_offsets = tl.arange(0, BLOCK_SIZE)
7224+
x = tl.load(x_ptr + strided_offsets)
7225+
tl.store(output_ptr + linear_offsets, x)
7226+
7227+
STRIDE = 2
7228+
SIZE = 512
7229+
OUT_SIZE = SIZE // STRIDE
7230+
7231+
x = numpy_random(SIZE, dtype_str=dtype)
7232+
x_tri = to_triton(x, device)
7233+
out_tri = torch.empty(OUT_SIZE, device=device)
7234+
take_every_second_element[(1, 1)](x_tri, out_tri, OUT_SIZE)
7235+
7236+
# Test that every second element (starting from [0]) from x is stored in out_tri
7237+
np.testing.assert_allclose(x[::2], to_numpy(out_tri))
7238+
7239+
7240+
@pytest.mark.interpreter
7241+
@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"])
7242+
def test_strided_store(dtype, device):
7243+
7244+
@triton.jit
7245+
def store_into_every_second(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
7246+
strided_offsets = tl.arange(0, BLOCK_SIZE) * 2
7247+
linear_offsets = tl.arange(0, BLOCK_SIZE)
7248+
x = tl.load(x_ptr + linear_offsets)
7249+
tl.store(output_ptr + strided_offsets, x)
7250+
7251+
STRIDE = 2
7252+
SIZE = 512
7253+
OUT_SIZE = SIZE * STRIDE
7254+
7255+
x = numpy_random(SIZE, dtype_str=dtype)
7256+
x_tri = to_triton(x, device)
7257+
out_tri = torch.zeros(OUT_SIZE, device=device)
7258+
store_into_every_second[(1, 1)](x_tri, out_tri, SIZE)
7259+
7260+
# Test that every second element (starting from [0]) is the same as in x
7261+
np.testing.assert_allclose(x, to_numpy(out_tri)[::2])
7262+
# Test that every second element (starting from [1]) is still zero
7263+
np.testing.assert_allclose(np.zeros_like(x), to_numpy(out_tri)[1::2])
7264+
7265+
72167266
@pytest.mark.interpreter
72177267
@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"])
72187268
def test_indirect_load(dtype, device):

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
128128

129129
// -----
130130

131+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
132+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
133+
// CHECK-LABEL: buffer_load_8xf16
134+
tt.func public @buffer_load_8xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
135+
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
136+
%1 = tt.splat %arg2 : i32 -> tensor<256x64xi32, #blocked>
137+
%2 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
138+
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
139+
%4 = arith.addi %3, %1 : tensor<256x64xi32, #blocked>
140+
// Load 16 f16 elements check for correct vector size of instruction (4xi32 = 8xf16)
141+
// CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xi32>
142+
%5 = amdgpu.buffer_load %arg0[%4] : tensor<256x64xf16, #blocked>
143+
// CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}} : vector<4xi32>
144+
amdgpu.buffer_store %5, %arg0[%4] : tensor<256x64xf16, #blocked>
145+
tt.return
146+
}
147+
}
148+
149+
// -----
150+
131151
#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
132152
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
133153
// CHECK-LABEL: buffer_load_store_vec1

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,12 +539,12 @@ unsigned getContiguity(Value ptr, Value offset,
539539

540540
// To compute the contiguity of the scalar/warp-uniform ptr and offset pair we
541541
// need to look at the contiguity of the offsets and the alignment of the ptr
542-
auto contiguity = axisAnalysisPass.getContiguity(offset);
542+
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
543+
auto contiguity = axisAnalysisPass.getContiguity(offset, elemNumBits);
543544

544545
// To get the alignment of the scalar ptr we need to look at the divisibility
545546
auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr);
546547
auto maxMultipleBytes = axisInfo->getDivisibility(0);
547-
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
548548
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
549549
auto align = std::max<unsigned>(maxMultipleBytes / elemNumBytes, 1);
550550

0 commit comments

Comments
 (0)