-
Notifications
You must be signed in to change notification settings - Fork 2.6k
[AMD] CanonicalizePointers: Handle different base pointers and offsets #9541
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| // RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: tt.func @scf_if_different_bases | ||
| // CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32> | ||
| // CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32 | ||
| // CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]] | ||
| // CHECK: tt.load [[PTR]] | ||
| module attributes {"ttg.num-warps" = 4 : i32} { | ||
| tt.func @scf_if_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, | ||
| %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, | ||
| %arg2: i1) -> f32 { | ||
| %c16_i32 = arith.constant 16 : i32 | ||
| %c32_i32 = arith.constant 32 : i32 | ||
| %0 = scf.if %arg2 -> (!tt.ptr<f32>) { | ||
| %2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32 | ||
| scf.yield %2 : !tt.ptr<f32> | ||
| } else { | ||
| %2 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32 | ||
| scf.yield %2 : !tt.ptr<f32> | ||
| } | ||
| %1 = tt.load %0 : !tt.ptr<f32> | ||
| tt.return %1 : f32 | ||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: tt.func @select_different_bases | ||
| // CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32> | ||
| // CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32 | ||
| // CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]] | ||
| // CHECK: tt.load [[PTR]] | ||
| module attributes {"ttg.num-warps" = 4 : i32} { | ||
| tt.func @select_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, | ||
| %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, | ||
| %arg2: i1) -> f32 { | ||
| %c16_i32 = arith.constant 16 : i32 | ||
| %c32_i32 = arith.constant 32 : i32 | ||
| %2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32 | ||
| %3 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32 | ||
| %4 = arith.select %arg2, %2, %3 : !tt.ptr<f32> | ||
| %5 = tt.load %4 : !tt.ptr<f32> | ||
| tt.return %5 : f32 | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -462,18 +462,28 @@ struct FatPointers { | |
|
|
||
| friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { | ||
| return lhs.canNarrow == rhs.canNarrow && | ||
| lhs.attributes == rhs.attributes && | ||
| lhs.smallTensorBase == rhs.smallTensorBase; | ||
| lhs.isSmallTensor == rhs.isSmallTensor && | ||
| lhs.attributes == rhs.attributes; | ||
| } | ||
|
|
||
| friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { | ||
| return !(lhs == rhs); | ||
| } | ||
|
|
||
| static FatPtrAttrs merge(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would |
||
| FatPtrAttrs result; | ||
| result.canNarrow = lhs.canNarrow && rhs.canNarrow; | ||
| result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor; | ||
| for (const auto &attr : lhs.attributes) { | ||
| auto it = rhs.attributes.find(attr.first); | ||
| if (it != rhs.attributes.end() && it->second == attr.second) | ||
| result.attributes[attr.first] = attr.second; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe you can replace the data type with a small vector map? |
||
| } | ||
| return result; | ||
| } | ||
|
|
||
kelesvol marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| llvm::DenseMap<StringRef, Attribute> attributes; | ||
| // If the fat-pointer points to somewhere in a small-tensor, keep track the | ||
| // base of the tensor. | ||
| Value smallTensorBase; | ||
| bool isSmallTensor = false; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nope, we cannot say ptr-to-a == ptr-to-b.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above, this is a part of the attributes, it doesn't mean pointers are the same. |
||
| bool canNarrow = false; | ||
| }; | ||
|
|
||
|
|
@@ -745,7 +755,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> { | |
| RewriterBase::InsertionGuard guard(rewriter); | ||
| rewriter.setInsertionPoint(addPtrOp); | ||
|
|
||
| if (fatPtrs.at({fatPtrBase, fatPtrOffset}).smallTensorBase) | ||
| if (fatPtrs.at({fatPtrBase, fatPtrOffset}).isSmallTensor) | ||
| return rewriteSmallTensorPtr(addPtrOp, adaptor, rewriter); | ||
|
|
||
| // Query all discardable attributes that we want to preserve | ||
|
|
@@ -861,7 +871,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> { | |
| const auto &oldAttr = fatPtrs.at({fatPtrBase, fatPtrOffset}); | ||
|
|
||
| LDBG("smal-tensor addPtr: " << addPtrOp); | ||
| LDBG(" - with tensor-base: " << oldAttr.smallTensorBase); | ||
| LDBG(" - isSmallTensor: " << oldAttr.isSmallTensor); | ||
| LDBG(" - with originl offset: " << origOffset); | ||
| LDBG(" - fatPtr base: " << fatPtrBase); | ||
| LDBG(" - fatPtr offst: " << fatPtrOffset); | ||
|
|
@@ -1362,17 +1372,6 @@ class ConvertArithSelectOp | |
| // select of base and offset | ||
| ValueRange fatPtrFalse = adaptor.getFalseValue(); | ||
| ValueRange fatPtrTrue = adaptor.getTrueValue(); | ||
| // Simple case of a scalar select: update the base pointer | ||
| if (!isa<RankedTensorType>(selectOp.getType())) { | ||
| auto newSelectOp = arith::SelectOp::create( | ||
| rewriter, selectOp.getLoc(), selectOp.getType(), | ||
| selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue()); | ||
| rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}}); | ||
| fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}] = | ||
| fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}); | ||
| return success(); | ||
| } | ||
|
|
||
| // Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF) | ||
| auto newBase = arith::SelectOp::create(rewriter, selectOp.getLoc(), | ||
| selectOp.getCondition(), | ||
|
|
@@ -1381,12 +1380,10 @@ class ConvertArithSelectOp | |
| selectOp.getCondition(), | ||
| fatPtrTrue[1], fatPtrFalse[1]); | ||
|
|
||
| assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}) == | ||
| fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})) && | ||
| "expected can narrow to be the same for both fatPtrT and fatPtrF"); | ||
|
|
||
| rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); | ||
| fatPtrs[{newBase, newOffset}] = fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}); | ||
| fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::merge( | ||
| fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}), | ||
| fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})); | ||
|
|
||
| return success(); | ||
| } | ||
|
|
@@ -1434,14 +1431,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> { | |
| assert(i < ifOp.thenYield().getNumOperands() && | ||
| i + 1 < ifOp.thenYield().getNumOperands() && | ||
| "expected idx to be within bounds of IfOp's results"); | ||
| Value thenFatPtrBase = ifOp.thenYield().getOperand(i); | ||
| Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1); | ||
| Value elseFatPtrBase = ifOp.elseYield().getOperand(i); | ||
| Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1); | ||
| assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}) == | ||
| fatPtrs.at({elseFatPtrBase, elseFatPtrOffset})) && | ||
| "expected then fat ptr canNarrow and else fat ptr canNarrow " | ||
| "to be equal"); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1467,8 +1456,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> { | |
| for (int64_t idx : yieldPtrOffsets) { | ||
| Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx); | ||
| Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1); | ||
| fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = | ||
| fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}); | ||
| const auto &thenAttrs = fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}); | ||
| if (withElseRegion) { | ||
| Value elseFatPtrBase = newIfOp.elseYield().getOperand(idx); | ||
| Value elseFatPtrOffset = newIfOp.elseYield().getOperand(idx + 1); | ||
| const auto &elseAttrs = fatPtrs.at({elseFatPtrBase, elseFatPtrOffset}); | ||
| fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = | ||
| FatPointers::FatPtrAttrs::merge(thenAttrs, elseAttrs); | ||
| } else { | ||
| fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = | ||
| thenAttrs; | ||
| } | ||
| } | ||
|
|
||
| ResultRange results = newIfOp.getResults(); | ||
|
|
@@ -1708,7 +1706,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> { | |
| rewriter.replaceAllUsesExcept(arg, dummyCast.getResult(0), dummyCast); | ||
| fatPtrs[{arg, zeroOffset}].canNarrow = true; | ||
| if (bitness != 64) | ||
| fatPtrs[{arg, zeroOffset}].smallTensorBase = arg; | ||
| fatPtrs[{arg, zeroOffset}].isSmallTensor = true; | ||
| } | ||
|
|
||
| newOp->setDiscardableAttr(kInitFuncArgsRewritten, rewriter.getUnitAttr()); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If `ptr1` and `ptr2` are from small tensors a and b respectively, and they share the same attributes, then `==` returns true? I think this is fundamentally wrong, although replacing smallTensorBase with a boolean does not expose a bug.
Maybe you can add another field to indicate the case where the base is dynamic, i.e. it is a combination of multiple real tensors; say, introduce a `dynamic-base` for that purpose. When merging from different bases, clobber the smallTensorBase and set the dynamic-base.
Or, conservatively, just clobber the tensorBase to indicate it is too complicated to handle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, we compare attributes here, not the actual pointers.
Where will we use the `dynamic base` in this case? Currently, we keep `smallTensorBase`, but don't use it other than checking if it's a small tensor. With `isSmallTensor` added, we'll no longer need that either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then why is it valid? Two pointers are equal if they share the same attributes even if they don't have the same base?
While I cannot give you examples as to why this will incur a problem at this moment (I need to check the code), doesn't your change sound very dangerous? You claim two pointers are equal without comparing their base pointers. It might work for now because, in the current implementation, when two pointers are involved in a comparison, they may happen to have the same bases. It has huge potential for problems.
We cannot weaken the condition at a huge risk to correctness just to make dealing with a dynamic base slightly easier.
We cannot weaken the condition at the huge risk of correctness just to make dealing dynamic-base slightly easier.