Skip to content

Commit 55c3183

Browse files
committed
[AMD] CanonicalizePointers: Handle different base pointers and offsets
`scf.if` and `arith.select` ops can have different base pointers across branches. In that case, `FatPtrAttrs` can't keep `smallTensorBase` or pick one of the branches. Instead, replace `smallTensorBase` (Value) with `isSmallTensor` (bool) in `FatPtrAttrs` and add a conservative `merge` function for combining attrs. Also, handle different scalar offsets properly when handling `arith.select` ops.
1 parent d28db13 commit 55c3183

File tree

2 files changed

+77
-34
lines changed

2 files changed

+77
-34
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize | FileCheck %s
2+
3+
// CHECK-LABEL: tt.func @scf_if_different_bases
4+
// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32>
5+
// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32
6+
// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]]
7+
// CHECK: tt.load [[PTR]]
8+
module attributes {"ttg.num-warps" = 4 : i32} {
9+
tt.func @scf_if_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
10+
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
11+
%arg2: i1) -> f32 {
12+
%c16_i32 = arith.constant 16 : i32
13+
%c32_i32 = arith.constant 32 : i32
14+
%0 = scf.if %arg2 -> (!tt.ptr<f32>) {
15+
%2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32
16+
scf.yield %2 : !tt.ptr<f32>
17+
} else {
18+
%2 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32
19+
scf.yield %2 : !tt.ptr<f32>
20+
}
21+
%1 = tt.load %0 : !tt.ptr<f32>
22+
tt.return %1 : f32
23+
}
24+
}
25+
26+
// -----
27+
28+
// CHECK-LABEL: tt.func @select_different_bases
29+
// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32>
30+
// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32
31+
// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]]
32+
// CHECK: tt.load [[PTR]]
33+
module attributes {"ttg.num-warps" = 4 : i32} {
34+
tt.func @select_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
35+
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
36+
%arg2: i1) -> f32 {
37+
%c16_i32 = arith.constant 16 : i32
38+
%c32_i32 = arith.constant 32 : i32
39+
%2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32
40+
%3 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32
41+
%4 = arith.select %arg2, %2, %3 : !tt.ptr<f32>
42+
%5 = tt.load %4 : !tt.ptr<f32>
43+
tt.return %5 : f32
44+
}
45+
}

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -462,18 +462,28 @@ struct FatPointers {
462462

463463
friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
464464
return lhs.canNarrow == rhs.canNarrow &&
465-
lhs.attributes == rhs.attributes &&
466-
lhs.smallTensorBase == rhs.smallTensorBase;
465+
lhs.isSmallTensor == rhs.isSmallTensor &&
466+
lhs.attributes == rhs.attributes;
467467
}
468468

469469
friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
470470
return !(lhs == rhs);
471471
}
472472

473+
static FatPtrAttrs merge(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
474+
FatPtrAttrs result;
475+
result.canNarrow = lhs.canNarrow && rhs.canNarrow;
476+
result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor;
477+
for (const auto &attr : lhs.attributes) {
478+
auto it = rhs.attributes.find(attr.first);
479+
if (it != rhs.attributes.end() && it->second == attr.second)
480+
result.attributes[attr.first] = attr.second;
481+
}
482+
return result;
483+
}
484+
473485
llvm::DenseMap<StringRef, Attribute> attributes;
474-
// If the fat-pointer points to somewhere in a small-tensor, keep track the
475-
// base of the tensor.
476-
Value smallTensorBase;
486+
bool isSmallTensor = false;
477487
bool canNarrow = false;
478488
};
479489

@@ -745,7 +755,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
745755
RewriterBase::InsertionGuard guard(rewriter);
746756
rewriter.setInsertionPoint(addPtrOp);
747757

748-
if (fatPtrs.at({fatPtrBase, fatPtrOffset}).smallTensorBase)
758+
if (fatPtrs.at({fatPtrBase, fatPtrOffset}).isSmallTensor)
749759
return rewriteSmallTensorPtr(addPtrOp, adaptor, rewriter);
750760

751761
// Query all discardable attributes that we want to preserve
@@ -861,7 +871,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
861871
const auto &oldAttr = fatPtrs.at({fatPtrBase, fatPtrOffset});
862872

863873
LDBG("smal-tensor addPtr: " << addPtrOp);
864-
LDBG(" - with tensor-base: " << oldAttr.smallTensorBase);
874+
LDBG(" - isSmallTensor: " << oldAttr.isSmallTensor);
865875
LDBG(" - with originl offset: " << origOffset);
866876
LDBG(" - fatPtr base: " << fatPtrBase);
867877
LDBG(" - fatPtr offst: " << fatPtrOffset);
@@ -1362,17 +1372,6 @@ class ConvertArithSelectOp
13621372
// select of base and offset
13631373
ValueRange fatPtrFalse = adaptor.getFalseValue();
13641374
ValueRange fatPtrTrue = adaptor.getTrueValue();
1365-
// Simple case of a scalar select: update the base pointer
1366-
if (!isa<RankedTensorType>(selectOp.getType())) {
1367-
auto newSelectOp = arith::SelectOp::create(
1368-
rewriter, selectOp.getLoc(), selectOp.getType(),
1369-
selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue());
1370-
rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}});
1371-
fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}] =
1372-
fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]});
1373-
return success();
1374-
}
1375-
13761375
// Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF)
13771376
auto newBase = arith::SelectOp::create(rewriter, selectOp.getLoc(),
13781377
selectOp.getCondition(),
@@ -1381,12 +1380,10 @@ class ConvertArithSelectOp
13811380
selectOp.getCondition(),
13821381
fatPtrTrue[1], fatPtrFalse[1]);
13831382

1384-
assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}) ==
1385-
fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})) &&
1386-
"expected can narrow to be the same for both fatPtrT and fatPtrF");
1387-
13881383
rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}});
1389-
fatPtrs[{newBase, newOffset}] = fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]});
1384+
fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::merge(
1385+
fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}),
1386+
fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]}));
13901387

13911388
return success();
13921389
}
@@ -1434,14 +1431,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
14341431
assert(i < ifOp.thenYield().getNumOperands() &&
14351432
i + 1 < ifOp.thenYield().getNumOperands() &&
14361433
"expected idx to be within bounds of IfOp's results");
1437-
Value thenFatPtrBase = ifOp.thenYield().getOperand(i);
1438-
Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1);
1439-
Value elseFatPtrBase = ifOp.elseYield().getOperand(i);
1440-
Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1);
1441-
assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}) ==
1442-
fatPtrs.at({elseFatPtrBase, elseFatPtrOffset})) &&
1443-
"expected then fat ptr canNarrow and else fat ptr canNarrow "
1444-
"to be equal");
14451434
}
14461435
}
14471436
}
@@ -1467,8 +1456,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
14671456
for (int64_t idx : yieldPtrOffsets) {
14681457
Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx);
14691458
Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1);
1470-
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
1471-
fatPtrs.at({thenFatPtrBase, thenFatPtrOffset});
1459+
const auto &thenAttrs = fatPtrs.at({thenFatPtrBase, thenFatPtrOffset});
1460+
if (withElseRegion) {
1461+
Value elseFatPtrBase = newIfOp.elseYield().getOperand(idx);
1462+
Value elseFatPtrOffset = newIfOp.elseYield().getOperand(idx + 1);
1463+
const auto &elseAttrs = fatPtrs.at({elseFatPtrBase, elseFatPtrOffset});
1464+
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
1465+
FatPointers::FatPtrAttrs::merge(thenAttrs, elseAttrs);
1466+
} else {
1467+
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
1468+
thenAttrs;
1469+
}
14721470
}
14731471

14741472
ResultRange results = newIfOp.getResults();
@@ -1708,7 +1706,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
17081706
rewriter.replaceAllUsesExcept(arg, dummyCast.getResult(0), dummyCast);
17091707
fatPtrs[{arg, zeroOffset}].canNarrow = true;
17101708
if (bitness != 64)
1711-
fatPtrs[{arg, zeroOffset}].smallTensorBase = arg;
1709+
fatPtrs[{arg, zeroOffset}].isSmallTensor = true;
17121710
}
17131711

17141712
newOp->setDiscardableAttr(kInitFuncArgsRewritten, rewriter.getUnitAttr());

0 commit comments

Comments
 (0)