Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize | FileCheck %s

// CHECK-LABEL: tt.func @scf_if_different_bases
// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32>
// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32
// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]]
// CHECK: tt.load [[PTR]]
// Regression test: the two branches of the scf.if yield pointers derived from
// *different* base pointers (%arg0 vs %arg1). The canonicalizer must merge the
// fat pointers into a select of the bases plus a select of the offsets rather
// than assuming both branches share one base (see the CHECK lines above).
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @scf_if_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: i1) -> f32 {
%c16_i32 = arith.constant 16 : i32
%c32_i32 = arith.constant 32 : i32
%0 = scf.if %arg2 -> (!tt.ptr<f32>) {
// Then branch: pointer is %arg0 advanced by 16 elements.
%2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32
scf.yield %2 : !tt.ptr<f32>
} else {
// Else branch: pointer is %arg1 advanced by 32 elements.
%2 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32
scf.yield %2 : !tt.ptr<f32>
}
%1 = tt.load %0 : !tt.ptr<f32>
tt.return %1 : f32
}
}

// -----

// CHECK-LABEL: tt.func @select_different_bases
// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32>
// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32
// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]]
// CHECK: tt.load [[PTR]]
// Regression test: arith.select chooses between two pointers derived from
// *different* base pointers (%arg0 vs %arg1). The canonicalizer must rewrite
// this into a select of the bases plus a select of the offsets rather than
// assuming the two operands share one base (see the CHECK lines above).
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @select_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: i1) -> f32 {
%c16_i32 = arith.constant 16 : i32
%c32_i32 = arith.constant 32 : i32
// Two candidate pointers with distinct bases and distinct offsets.
%2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32
%3 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32
%4 = arith.select %arg2, %2, %3 : !tt.ptr<f32>
%5 = tt.load %4 : !tt.ptr<f32>
tt.return %5 : f32
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,28 @@ struct FatPointers {

friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
return lhs.canNarrow == rhs.canNarrow &&
lhs.attributes == rhs.attributes &&
lhs.smallTensorBase == rhs.smallTensorBase;
lhs.isSmallTensor == rhs.isSmallTensor &&
lhs.attributes == rhs.attributes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ptr1 and ptr2 come from small tensors a and b respectively, and they share the same attributes, then == returns true?

I think this is fundamentally wrong, although replacing smallTensorBase with a boolean does not expose a bug right now.

Maybe you can add another field to indicate the case where the base is dynamic, i.e. it is a combination of multiple real tensor bases — say, introduce a dynamic-base flag for that purpose.

When merging from different bases, clobber the smallTensorBase and set the dynamic-base flag.

Or, conservatively, just clobber the tensorBase to indicate it is too complicated to handle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if ptr1 and ptr2 is from small tensor a and b respectively, they share the same attribute, then == return true?

No, we compare attributes here, not the actual pointers.

Maybe you can add another field to indicate the case where the base is dynamic, it is a combination of multiple real tensors. say, introduce dynamic-base for that purpose.

Where will we use dynamic base in this case? Currently, we keep smallTensorBase, but don't use it other than checking if it's a small tensor. With isSmallTensor added, we'll no longer need that either.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we compare attributes here, not the actual pointers.

Then why is it valid? Two pointers are equal if they share the same attributes even if they don't have the same base?

While I cannot give you examples of why this will incur problems at this moment (I need to check the code), doesn't your change sound very dangerous? You claim two pointers are equal without comparing their base pointers. It might work for now because, in the current implementation, the two pointers involved in a comparison may happen to have the same base. It has huge potential for problems.

We cannot weaken the condition at huge risk to correctness just to make dealing with dynamic bases slightly easier.

}

friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
return !(lhs == rhs);
}

static FatPtrAttrs merge(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would intersect be a better name? merge sounds a bit like we take the union of all attributes.

FatPtrAttrs result;
result.canNarrow = lhs.canNarrow && rhs.canNarrow;
result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor;
for (const auto &attr : lhs.attributes) {
auto it = rhs.attributes.find(attr.first);
if (it != rhs.attributes.end() && it->second == attr.second)
result.attributes[attr.first] = attr.second;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can replace the data type with a small-vector-backed map?

}
return result;
}

llvm::DenseMap<StringRef, Attribute> attributes;
// If the fat-pointer points to somewhere in a small tensor, keep track of
// the base of the tensor.
Value smallTensorBase;
bool isSmallTensor = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, we cannot say ptr-to-a == ptr-to-b.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, this is a part of the attributes, it doesn't mean pointers are the same.

bool canNarrow = false;
};

Expand Down Expand Up @@ -745,7 +755,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
RewriterBase::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(addPtrOp);

if (fatPtrs.at({fatPtrBase, fatPtrOffset}).smallTensorBase)
if (fatPtrs.at({fatPtrBase, fatPtrOffset}).isSmallTensor)
return rewriteSmallTensorPtr(addPtrOp, adaptor, rewriter);

// Query all discardable attributes that we want to preserve
Expand Down Expand Up @@ -861,7 +871,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
const auto &oldAttr = fatPtrs.at({fatPtrBase, fatPtrOffset});

LDBG("smal-tensor addPtr: " << addPtrOp);
LDBG(" - with tensor-base: " << oldAttr.smallTensorBase);
LDBG(" - isSmallTensor: " << oldAttr.isSmallTensor);
LDBG(" - with originl offset: " << origOffset);
LDBG(" - fatPtr base: " << fatPtrBase);
LDBG(" - fatPtr offst: " << fatPtrOffset);
Expand Down Expand Up @@ -1362,17 +1372,6 @@ class ConvertArithSelectOp
// select of base and offset
ValueRange fatPtrFalse = adaptor.getFalseValue();
ValueRange fatPtrTrue = adaptor.getTrueValue();
// Simple case of a scalar select: update the base pointer
if (!isa<RankedTensorType>(selectOp.getType())) {
auto newSelectOp = arith::SelectOp::create(
rewriter, selectOp.getLoc(), selectOp.getType(),
selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue());
rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}});
fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}] =
fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]});
return success();
}

// Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF)
auto newBase = arith::SelectOp::create(rewriter, selectOp.getLoc(),
selectOp.getCondition(),
Expand All @@ -1381,12 +1380,10 @@ class ConvertArithSelectOp
selectOp.getCondition(),
fatPtrTrue[1], fatPtrFalse[1]);

assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}) ==
fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})) &&
"expected can narrow to be the same for both fatPtrT and fatPtrF");

rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}});
fatPtrs[{newBase, newOffset}] = fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]});
fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::merge(
fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}),
fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]}));

return success();
}
Expand Down Expand Up @@ -1434,14 +1431,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
assert(i < ifOp.thenYield().getNumOperands() &&
i + 1 < ifOp.thenYield().getNumOperands() &&
"expected idx to be within bounds of IfOp's results");
Value thenFatPtrBase = ifOp.thenYield().getOperand(i);
Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1);
Value elseFatPtrBase = ifOp.elseYield().getOperand(i);
Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1);
assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}) ==
fatPtrs.at({elseFatPtrBase, elseFatPtrOffset})) &&
"expected then fat ptr canNarrow and else fat ptr canNarrow "
"to be equal");
}
}
}
Expand All @@ -1467,8 +1456,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
for (int64_t idx : yieldPtrOffsets) {
Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx);
Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1);
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
fatPtrs.at({thenFatPtrBase, thenFatPtrOffset});
const auto &thenAttrs = fatPtrs.at({thenFatPtrBase, thenFatPtrOffset});
if (withElseRegion) {
Value elseFatPtrBase = newIfOp.elseYield().getOperand(idx);
Value elseFatPtrOffset = newIfOp.elseYield().getOperand(idx + 1);
const auto &elseAttrs = fatPtrs.at({elseFatPtrBase, elseFatPtrOffset});
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
FatPointers::FatPtrAttrs::merge(thenAttrs, elseAttrs);
} else {
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
thenAttrs;
}
}

ResultRange results = newIfOp.getResults();
Expand Down Expand Up @@ -1708,7 +1706,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
rewriter.replaceAllUsesExcept(arg, dummyCast.getResult(0), dummyCast);
fatPtrs[{arg, zeroOffset}].canNarrow = true;
if (bitness != 64)
fatPtrs[{arg, zeroOffset}].smallTensorBase = arg;
fatPtrs[{arg, zeroOffset}].isSmallTensor = true;
}

newOp->setDiscardableAttr(kInitFuncArgsRewritten, rewriter.getUnitAttr());
Expand Down
Loading