Skip to content

Commit fa8b7bb

Browse files
authored
[AMD] Support ConvertLayout in CanonicalizePointers (#6142)
The `CanonicalizePointers` pass was missing a rewrite pattern for `ConvertLayout`, so it could not rewrite that op to operate on the fat pointer's offset. Without this change the pass fails whenever the pointer of a `tt.load` has been transformed by a `ConvertLayout`. Added a lit test for this case and a general correctness test for indirect loads and stores.
1 parent 9ca8bd3 commit fa8b7bb

File tree

3 files changed

+122
-1
lines changed

3 files changed

+122
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7211,3 +7211,47 @@ def aliasing_kernel(buffer, buffer2):
72117211
buffer = torch.zeros(1, device=device)
72127212
aliasing_kernel[(1, )](buffer, buffer)
72137213
assert buffer[0] == 1
7214+
7215+
7216+
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"])
def test_indirect_load(dtype, device):
    """Gather pattern: each loaded element's source position comes from a
    separately loaded offset tensor instead of the linear program index."""

    @triton.jit
    def indirect_load(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr):
        idx = tl.arange(0, SIZE)
        gather_idx = tl.load(offset_ptr + idx)
        vals = tl.load(x_ptr + gather_idx)
        tl.store(output_ptr + idx, vals)

    SIZE = 512
    x = numpy_random(SIZE, dtype_str=dtype)
    x_tri = to_triton(x, device)
    # A reversed index range makes the kernel read the input back-to-front,
    # so the expected output is simply the flipped input.
    indices = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0)
    out_tri = torch.empty(SIZE, device=device)
    indirect_load[(1, 1)](indices, x_tri, out_tri, SIZE)

    np.testing.assert_allclose(np.flip(x), to_numpy(out_tri))
7236+
7237+
7238+
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"])
def test_indirect_store(dtype, device):
    """Scatter pattern: each stored element's destination position comes from
    a separately loaded offset tensor instead of the linear program index."""

    @triton.jit
    def indirect_store(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr):
        idx = tl.arange(0, SIZE)
        scatter_idx = tl.load(offset_ptr + idx)
        vals = tl.load(x_ptr + idx)
        tl.store(output_ptr + scatter_idx, vals)

    SIZE = 512
    x = numpy_random(SIZE, dtype_str=dtype)
    x_tri = to_triton(x, device)
    # A reversed index range makes the kernel write the input back-to-front,
    # so the expected output is simply the flipped input.
    indices = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0)
    out_tri = torch.empty(SIZE, device=device)
    indirect_store[(1, 1)](indices, x_tri, out_tri, SIZE)

    np.testing.assert_allclose(np.flip(x), to_numpy(out_tri))

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,40 @@ module attributes {"ttg.num-warps" = 4 : i32} {
132132

133133
// -----
134134

135+
#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // Indirect load whose pointer passes through a ttg.convert_layout before
  // the final tt.load. The CHECK lines below verify that the pass converts
  // the integer offset tensor instead of the fat pointer itself.
  tt.func public @convertLayoutOp(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<i32>, %arg2: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked1> {
    %0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
    %1 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
    %2 = tt.addptr %1, %arg2 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
    %3 = tt.load %2 : tensor<1024x!tt.ptr<i32>, #blocked>
    %4 = tt.addptr %0, %3 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
    %5 = ttg.convert_layout %4 : tensor<1024x!tt.ptr<f32>, #blocked> -> tensor<1024x!tt.ptr<f32>, #blocked1>
    %6 = tt.load %5 : tensor<1024x!tt.ptr<f32>, #blocked1>
    tt.return %6 : tensor<1024xf32, #blocked1>
  }
}

// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>

// CHECK-LABEL: tt.func public @convertLayoutOp(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: !tt.ptr<i32>, %[[VAL_2:.*]]: tensor<1024xi32, #[[$ATTR_0]]>) -> tensor<1024xf32, #[[$ATTR_1]]> {
// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_1]] : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_3]], %[[VAL_2]] : tensor<1024x!tt.ptr<i32>, #[[$ATTR_0]]>, tensor<1024xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : tensor<1024x!tt.ptr<i32>, #[[$ATTR_0]]>
// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_5]] : tensor<1024xi32, #[[$ATTR_0]]> to tensor<1024xi64, #[[$ATTR_0]]>
// CHECK: %[[VAL_7:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<1024xi64, #[[$ATTR_0]]> -> tensor<1024xi64, #[[$ATTR_1]]>
// CHECK: %[[VAL_8:.*]] = arith.trunci %[[VAL_7]] : tensor<1024xi64, #[[$ATTR_1]]> to tensor<1024xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #[[$ATTR_1]]>
// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_9]], %[[VAL_8]] : tensor<1024x!tt.ptr<f32>, #[[$ATTR_1]]>, tensor<1024xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_11:.*]] = tt.load %[[VAL_10]] : tensor<1024x!tt.ptr<f32>, #[[$ATTR_1]]>
// CHECK: tt.return %[[VAL_11]] : tensor<1024xf32, #[[$ATTR_1]]>
// CHECK: }
166+
167+
// -----
168+
135169
module attributes {"ttg.num-warps" = 4 : i32} {
136170
tt.func @forOp(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
137171
%c1024_i32 = arith.constant 1024 : i32

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,48 @@ class ConvertExpandDims
10791079
}
10801080
};
10811081

1082+
/// convert integer offset, keep base
1083+
class ConvertConvertLayoutOp
1084+
: public PointerCanonicalizationPattern<tt::gpu::ConvertLayoutOp> {
1085+
public:
1086+
using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
1087+
1088+
LogicalResult
1089+
matchAndRewrite_(tt::gpu::ConvertLayoutOp cvtOp, OneToNOpAdaptor adaptor,
1090+
ConversionPatternRewriter &rewriter) const override {
1091+
ValueRange remappedOperands = adaptor.getSrc();
1092+
if (remappedOperands.size() != 2) {
1093+
// some prior op materialized the fat ptr, e.g.:
1094+
// %3 = tt.bitcast %2
1095+
// %4 = tt.splat %3
1096+
return success();
1097+
}
1098+
Value fatPtrBase = remappedOperands[0];
1099+
Value fatPtrOffset = remappedOperands[1];
1100+
if (!llvm::isa<tt::PointerType>(fatPtrBase.getType())) {
1101+
return rewriter.notifyMatchFailure(cvtOp,
1102+
"non tt.ptr base unimplemented");
1103+
}
1104+
auto offsetTensorTy = dyn_cast<RankedTensorType>(fatPtrOffset.getType());
1105+
if (!offsetTensorTy) {
1106+
return rewriter.notifyMatchFailure(
1107+
cvtOp, "non RankedTensorType offset unimplemented");
1108+
}
1109+
1110+
RankedTensorType outType = cvtOp.getResult().getType();
1111+
auto newOffsetType = RankedTensorType::get(outType.getShape(),
1112+
offsetTensorTy.getElementType(),
1113+
outType.getEncoding());
1114+
tt::gpu::ConvertLayoutOp cvtOffset =
1115+
rewriter.create<tt::gpu::ConvertLayoutOp>(cvtOp.getLoc(), newOffsetType,
1116+
fatPtrOffset);
1117+
rewriter.replaceOpWithMultiple(cvtOp, {{fatPtrBase, cvtOffset}});
1118+
fatPtrs[{fatPtrBase, cvtOffset}] = fatPtrs.at({fatPtrBase, fatPtrOffset});
1119+
1120+
return success();
1121+
}
1122+
};
1123+
10821124
template <typename SourceOp, int PtrLikeIdx = 0>
10831125
class MaterializeFatPointer : public PointerCanonicalizationPattern<SourceOp> {
10841126
public:
@@ -1452,7 +1494,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
14521494
RewritePatternSet patterns(&getContext());
14531495
patterns.add<
14541496
ConvertFuncOpArgsUnrealizedCasts, ConvertBroadcastOp, ConvertSplatOp,
1455-
ConvertAddPtrOp, MaterializeFatPointer<tt::AtomicCASOp>,
1497+
ConvertConvertLayoutOp, ConvertAddPtrOp,
1498+
MaterializeFatPointer<tt::AtomicCASOp>,
14561499
MaterializeFatPointer<tt::AtomicRMWOp>,
14571500
MaterializeFatPointer<tt::BitcastOp>, MaterializeFatPointer<tt::LoadOp>,
14581501
MaterializeFatPointer<triton::gpu::AsyncCopyGlobalToLocalOp>,

0 commit comments

Comments
 (0)