diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers-different-bases.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers-different-bases.mlir new file mode 100644 index 000000000000..5f4fad124efc --- /dev/null +++ b/test/TritonGPU/amd/amd-canonicalize-pointers-different-bases.mlir @@ -0,0 +1,45 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize | FileCheck %s + +// CHECK-LABEL: tt.func @scf_if_different_bases +// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr +// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32 +// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]] +// CHECK: tt.load [[PTR]] +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scf_if_different_bases(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg2: i1) -> f32 { + %c16_i32 = arith.constant 16 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = scf.if %arg2 -> (!tt.ptr) { + %2 = tt.addptr %arg0, %c16_i32 : !tt.ptr, i32 + scf.yield %2 : !tt.ptr + } else { + %2 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 + scf.yield %2 : !tt.ptr + } + %1 = tt.load %0 : !tt.ptr + tt.return %1 : f32 + } +} + +// ----- + +// CHECK-LABEL: tt.func @select_different_bases +// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr +// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32 +// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]] +// CHECK: tt.load [[PTR]] +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @select_different_bases(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg2: i1) -> f32 { + %c16_i32 = arith.constant 16 : i32 + %c32_i32 = arith.constant 32 : i32 + %2 = tt.addptr %arg0, %c16_i32 : !tt.ptr, i32 + %3 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 + %4 = arith.select %arg2, %2, %3 : !tt.ptr + %5 = tt.load %4 : !tt.ptr + tt.return %5 : f32 + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 33468f08c67b..3ab5511b10ea 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -462,18 +462,28 @@ struct FatPointers { friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { return lhs.canNarrow == rhs.canNarrow && - lhs.attributes == rhs.attributes && - lhs.smallTensorBase == rhs.smallTensorBase; + lhs.isSmallTensor == rhs.isSmallTensor && + lhs.attributes == rhs.attributes; } friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { return !(lhs == rhs); } + static FatPtrAttrs merge(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { + FatPtrAttrs result; + result.canNarrow = lhs.canNarrow && rhs.canNarrow; + result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor; + for (const auto &attr : lhs.attributes) { + auto it = rhs.attributes.find(attr.first); + if (it != rhs.attributes.end() && it->second == attr.second) + result.attributes[attr.first] = attr.second; + } + return result; + } + llvm::DenseMap attributes; - // If the fat-pointer points to somewhere in a small-tensor, keep track the - // base of the tensor. - Value smallTensorBase; + bool isSmallTensor = false; bool canNarrow = false; }; @@ -745,7 +755,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern { RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(addPtrOp); - if (fatPtrs.at({fatPtrBase, fatPtrOffset}).smallTensorBase) + if (fatPtrs.at({fatPtrBase, fatPtrOffset}).isSmallTensor) return rewriteSmallTensorPtr(addPtrOp, adaptor, rewriter); // Query all discardable attributes that we want to preserve @@ -861,7 +871,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern { const auto &oldAttr = fatPtrs.at({fatPtrBase, fatPtrOffset}); LDBG("smal-tensor addPtr: " << addPtrOp); - LDBG(" - with tensor-base: " << oldAttr.smallTensorBase); + LDBG(" - isSmallTensor: " << oldAttr.isSmallTensor); LDBG(" - with originl offset: " << origOffset); LDBG(" - fatPtr base: " << fatPtrBase); LDBG(" - fatPtr offst: " << fatPtrOffset); @@ -1362,17 +1372,6 @@ class ConvertArithSelectOp // select of base and offset ValueRange fatPtrFalse = adaptor.getFalseValue(); ValueRange fatPtrTrue = adaptor.getTrueValue(); - // Simple case of a scalar select: update the base pointer - if (!isa(selectOp.getType())) { - auto newSelectOp = arith::SelectOp::create( - rewriter, selectOp.getLoc(), selectOp.getType(), - selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue()); - rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}}); - fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}] = - fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}); - return success(); - } - // Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF) auto newBase = arith::SelectOp::create(rewriter, selectOp.getLoc(), selectOp.getCondition(), @@ -1381,12 +1380,10 @@ class ConvertArithSelectOp selectOp.getCondition(), fatPtrTrue[1], fatPtrFalse[1]); - assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}) == - fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})) && - "expected can narrow to be the same for both fatPtrT and fatPtrF"); - rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); - fatPtrs[{newBase, newOffset}] = fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}); + fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::merge( + fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}), + fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})); return success(); } @@ -1434,14 +1431,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { assert(i < ifOp.thenYield().getNumOperands() && i + 1 < ifOp.thenYield().getNumOperands() && "expected idx to be within bounds of IfOp's results"); - Value thenFatPtrBase = ifOp.thenYield().getOperand(i); - Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1); - Value elseFatPtrBase = ifOp.elseYield().getOperand(i); - Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1); - assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}) == - fatPtrs.at({elseFatPtrBase, elseFatPtrOffset})) && - "expected then fat ptr canNarrow and else fat ptr canNarrow " - "to be equal"); } } } @@ -1467,8 +1456,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { for (int64_t idx : yieldPtrOffsets) { Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx); Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1); - fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = - fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}); + const auto &thenAttrs = fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}); + if (withElseRegion) { + Value elseFatPtrBase = newIfOp.elseYield().getOperand(idx); + Value elseFatPtrOffset = newIfOp.elseYield().getOperand(idx + 1); + const auto &elseAttrs = fatPtrs.at({elseFatPtrBase, elseFatPtrOffset}); + fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = + FatPointers::FatPtrAttrs::merge(thenAttrs, elseAttrs); + } else { + fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = + thenAttrs; + } } ResultRange results = newIfOp.getResults(); @@ -1708,7 +1706,7 @@ struct InitFuncPtrArgs : OpRewritePattern { rewriter.replaceAllUsesExcept(arg, dummyCast.getResult(0), dummyCast); fatPtrs[{arg, zeroOffset}].canNarrow = true; if (bitness != 64) - fatPtrs[{arg, zeroOffset}].smallTensorBase = arg; + fatPtrs[{arg, zeroOffset}].isSmallTensor = true; } newOp->setDiscardableAttr(kInitFuncArgsRewritten, rewriter.getUnitAttr());