Skip to content

Commit 855ca6c

Browse files
yangshuxinShuxin YangShuxin Yang
authored
[AMD] enhance range analysis for buffer ops (#8372)
This change fix bugs in range-analysis, and let buffer-ops use the range-analysis result to decide if it's legal to convert memory-op to buffer-ops. The highlight are following: * Range Analysis - fix the way to use `tl.assume`. Previously, it does not consider the control flow relationship between, say `tl.assume x > 0` and the location of occurrence of x. - correct the value range of `make_range(begin, end)`, previous vr is [begin, end], now is [begin, end-1]. Small change in concept incur huge change the regression test. * Buffer-ops - for large tensor (>2G), remove the ad-hoc, and mistaken range-analysis in the pass. It only relies on the result of the range-analysis pass. - previous, buffer-ops pass only check element-index > 0. The right condition is byte-offset in [0, 2G-element-size]. - Previous there is a similar work here #7908, contributed by @njriasan . My change to this part is similar but fix some bugs in PR7908 (.e.g. lattice could be nullptr), and update large number of testings. That being said, now that @njriasan made the first change, credit for the part belong to him. --------- Co-authored-by: Shuxin Yang <[email protected]> Co-authored-by: Shuxin Yang <[email protected]>
1 parent df9fe1e commit 855ca6c

File tree

8 files changed

+1007
-430
lines changed

8 files changed

+1007
-430
lines changed

test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
3636
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
3737
// COMMON: buffer_load %arg0[%[[offset]]]
3838
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
39-
// COMMON: buffer_load %arg1[%[[offset]]]
39+
// Note: offset = pid * 256 + arange(0, 256); byte-offset="offset * sizeof(i32)" may not fall into the 2G range.
40+
// COMMON-NOT: buffer_load %arg1[%[[offset]]]
4041
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
4142
// COMMON: %[[data:.*]] = arith.addf
4243
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
4344
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
4445
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
45-
// COMMON: buffer_store %[[data]], %arg2[%[[offset]]]
46+
// Note: see the explanation above
47+
// COMMON-NOT: buffer_store %[[data]], %arg2[%[[offset]]]
4648
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
4749
tt.return
4850
}
@@ -70,7 +72,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
7072
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
7173
%8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
7274
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
73-
// COMMON: buffer_load %[[scalar_ptr]][%[[offset]]]
75+
// Note: the base "scalar_ptr" points to arg0 which is a large-tensor.
76+
// the offset="%sub + arange(0,1024)" where "%sub=pid*1024-128",
77+
// We can prove "offset > 0", but cannot prove byte-offset < 2G.
78+
// COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]]
7479
%10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
7580
tt.return %10 : tensor<1024xf32, #blocked>
7681
}
@@ -122,7 +127,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
122127
// COMMON: %[[offset_32_bit:.*]] = arith.trunci
123128
%narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked>
124129
%9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
125-
// COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
130+
// Note: base is arg0 which is a large tensor; the offset=int(long(pid*1024) * long(arange(0, 1024)))
131+
// offset is in [0, i32-max].
132+
// COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
126133
%10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
127134
tt.return %10 : tensor<1024xf32, #blocked>
128135
}
@@ -555,7 +562,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
555562
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
556563
%6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
557564
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
558-
// COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
565+
// Note: the large tensor is accessed, offset is in the range of [0, smax].
566+
// without tl.assume the range would be [-128, smax]
567+
// COMMON-NOT: amdgpu.buffer_atomic_rmw
559568
%8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
560569
tt.return %8 : tensor<1024xf32, #blocked>
561570
}

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 174 additions & 23 deletions
Large diffs are not rendered by default.

test/TritonGPU/amd/amd-range-analysis.mlir

Lines changed: 353 additions & 102 deletions
Large diffs are not rendered by default.

third_party/amd/include/Analysis/RangeAnalysis.h

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
55
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
66
#include "mlir/Dialect/Arith/IR/Arith.h"
7+
#include "mlir/IR/Dominance.h"
78
#include "mlir/Interfaces/LoopLikeInterface.h"
89

910
namespace mlir::triton {
@@ -32,15 +33,20 @@ namespace mlir::triton::AMD {
3233
/// See visitRegionSuccessors.
3334
struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
3435
using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis;
36+
using Base = dataflow::IntegerRangeAnalysis;
3537
TritonIntegerRangeAnalysis(
3638
DataFlowSolver &solver,
37-
const DenseMap<Value, SetVector<Operation *>> &assumptions)
38-
: dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {}
39+
const DenseMap<Value, SetVector<Operation *>> &assumptions,
40+
DominanceInfo *dominanceInfo, bool assumeNoArithOverflow_ = false)
41+
: dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions),
42+
domInfo(dominanceInfo), assumeNoArithOverflow(assumeNoArithOverflow_) {}
3943

4044
void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override;
4145

4246
void initializeFuncOp(triton::FuncOp funcOp);
4347

48+
LogicalResult initialize(Operation *top) override;
49+
4450
LogicalResult visitOperation(
4551
Operation *op,
4652
ArrayRef<const dataflow::IntegerValueRangeLattice *> operands,
@@ -95,7 +101,8 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
95101
/// llvm.intr.assume %assumesltlhs : i1
96102
/// for %K, will produce a final range
97103
/// [0, 2147483647] ∩ [-2147483648, 128] = [0, 128]
98-
std::optional<ConstantIntRanges> maybeGetAssumedRange(Value anchor) const;
104+
std::optional<ConstantIntRanges> maybeGetAssumedRange(Value anchor,
105+
Block *useBlock) const;
99106

100107
int64_t getTotalLoopTripCount(LoopLikeOpInterface loop);
101108

@@ -125,6 +132,32 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
125132
/// If one uses collectAssumptions below then `assumptions` will look like
126133
/// %K -> {arith.cmpi slt..., arith.cmpi sge}.
127134
llvm::DenseMap<Value, SetVector<Operation *>> assumptions;
135+
136+
/// The defaultTransferFunc is the default transfer function for this dataflow
137+
/// problem.
138+
/// @param[in] op: the Operation in question
139+
/// @param[in] result: a particular value defined by this op. Note that op
140+
/// may define multiple values.
141+
/// @param[in] srcLattices: lattices of all source operands
142+
/// @param[in] destLattices: lattices of all result values
143+
/// @param[in] incomingRange: the value-range inferred for result
144+
void defaultTransferFunc(
145+
Operation *op, Value result,
146+
ArrayRef<const dataflow::IntegerValueRangeLattice *> srcLattices,
147+
ArrayRef<dataflow::IntegerValueRangeLattice *> destLattices,
148+
const IntegerValueRange &incomingRange);
149+
150+
private:
151+
void visitYieldHelper(Operation *yieldOp, Value value);
152+
LogicalResult visitOperationHelper(
153+
Operation *op,
154+
ArrayRef<const dataflow::IntegerValueRangeLattice *> operands,
155+
ArrayRef<dataflow::IntegerValueRangeLattice *> resultsLattices);
156+
157+
DenseSet<Value> signedIntValues;
158+
llvm::SmallMapVector<Value, ConstantIntRanges, 2> opResultAssumption;
159+
DominanceInfo *domInfo = nullptr;
160+
bool assumeNoArithOverflow = false;
128161
};
129162

130163
std::optional<SmallVector<std::optional<ConstantIntRanges>>>

0 commit comments

Comments
 (0)