Skip to content

Commit a48e358

Browse files
[CONSAN] BufferRegion Analysis (#8837)
This change factors out the responsibility for determining what buffer regions can be independently accessed in the program to a dedicated analysis. Region is a pair (offset, length), and single shmem/tmem allocation can be partitioned in many ways using indexing and slicing operations. Before this change ConSan only did a very superficial IR tracking to check if an allocation is used in a multi-buferred way. Generating full list of the possible buffer regions is a pre-requisite for supporting aliased shmem/tmem in ConSan. In the next step I will add creating an aliasing matrix and proper checks to ConSan. This change also introduces some smaller cleanups and refactoring of FunctionBuilder, like moving the checks for existing auxData to the `create` methods, and ditching error-prone `operator[]`.
1 parent cede64c commit a48e358

File tree

14 files changed

+1265
-429
lines changed

14 files changed

+1265
-429
lines changed

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ void registerTestAliasPass();
5454
void registerTestAlignmentPass();
5555
void registerAMDTestAlignmentPass();
5656
void registerTestAllocationPass();
57+
void registerTestBufferRegionPass();
5758
void registerTestMembarPass();
5859
void registerTestAMDGPUMembarPass();
5960
void registerTestTritonAMDGPURangeAnalysis();
@@ -75,6 +76,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
7576
mlir::test::registerTestAlignmentPass();
7677
mlir::test::registerAMDTestAlignmentPass();
7778
mlir::test::registerTestAllocationPass();
79+
mlir::test::registerTestBufferRegionPass();
7880
mlir::test::registerTestMembarPass();
7981
mlir::test::registerTestLoopPeelingPass();
8082
mlir::test::registerTestAMDGPUMembarPass();
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#ifndef TRITON_ANALYSIS_BUFFER_REGION_H
2+
#define TRITON_ANALYSIS_BUFFER_REGION_H
3+
4+
#include <limits>
5+
#include <set>
6+
7+
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
8+
#include "mlir/IR/Value.h"
9+
10+
namespace mlir::triton {
11+
12+
//===----------------------------------------------------------------------===//
13+
// BufferRegion: a single logical region derived from an alloc
14+
//===----------------------------------------------------------------------===//
15+
struct BufferRegion {
16+
uint32_t baseOffset;
17+
uint32_t length;
18+
19+
bool operator==(const BufferRegion &other) const {
20+
return baseOffset == other.baseOffset && length == other.length;
21+
}
22+
23+
bool operator<(const BufferRegion &other) const {
24+
if (baseOffset != other.baseOffset)
25+
return baseOffset < other.baseOffset;
26+
return length < other.length;
27+
}
28+
29+
template <typename T> void print(T &os) const {
30+
os << "[" << baseOffset << ", " << length << "]";
31+
}
32+
};
33+
34+
} // namespace mlir::triton
35+
36+
namespace llvm {
37+
38+
using namespace mlir::triton;
39+
40+
template <> struct DenseMapInfo<BufferRegion> {
41+
static BufferRegion getEmptyKey() {
42+
constexpr uint32_t empty = std::numeric_limits<uint32_t>::max();
43+
return BufferRegion{empty, empty};
44+
}
45+
static BufferRegion getTombstoneKey() {
46+
constexpr uint32_t tombstone = std::numeric_limits<uint32_t>::max() - 1;
47+
return BufferRegion{tombstone, tombstone};
48+
}
49+
static unsigned getHashValue(const BufferRegion &r) {
50+
return llvm::hash_combine(r.baseOffset, r.length);
51+
}
52+
static bool isEqual(const BufferRegion &a, const BufferRegion &b) {
53+
return a == b;
54+
}
55+
};
56+
57+
} // namespace llvm
58+
59+
namespace mlir::triton {
60+
61+
//===----------------------------------------------------------------------===//
62+
// RegionInfo lattice
63+
//===----------------------------------------------------------------------===//
64+
//
65+
// This wraps a set of BufferRegions and provides lattice semantics
66+
//
67+
struct RegionInfo {
68+
using RegionList = llvm::DenseSet<BufferRegion>;
69+
RegionList regions;
70+
71+
RegionInfo() = default;
72+
RegionInfo(const RegionList &r) : regions(r) {}
73+
74+
// Lattice join: union of regions
75+
static RegionInfo join(const RegionInfo &lhs, const RegionInfo &rhs) {
76+
RegionInfo result = lhs;
77+
for (const auto &reg : rhs.regions)
78+
if (llvm::find(result.regions, reg) == result.regions.end())
79+
result.regions.insert(reg);
80+
return result;
81+
}
82+
83+
bool operator==(const RegionInfo &other) const {
84+
if (regions.size() != other.regions.size())
85+
return false;
86+
for (auto &r : regions)
87+
if (llvm::find(other.regions, r) == other.regions.end())
88+
return false;
89+
return true;
90+
}
91+
92+
template <typename T> void print(T &os) const {
93+
llvm::SmallVector<BufferRegion> sortedRegions(regions.begin(),
94+
regions.end());
95+
llvm::sort(sortedRegions, [](const BufferRegion &a, const BufferRegion &b) {
96+
return a < b;
97+
});
98+
llvm::interleaveComma(sortedRegions, os,
99+
[&](const BufferRegion &r) { r.print(os); });
100+
}
101+
102+
static RegionInfo getPessimisticValueState(MLIRContext *context = nullptr) {
103+
return RegionInfo(); // means "unknown / empty"
104+
}
105+
static RegionInfo getPessimisticValueState(Value) { return RegionInfo(); }
106+
};
107+
108+
//===----------------------------------------------------------------------===//
109+
// BufferRegionAnalysis (Sparse Forward Dataflow)
110+
//===----------------------------------------------------------------------===//
111+
//
112+
// Produces a RegionInfo lattice for each MemDesc/ptr-like SSA value,
113+
// and also collects a global list of all discovered BufferRegions.
114+
//
115+
class BufferRegionAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
116+
dataflow::Lattice<RegionInfo>> {
117+
118+
public:
119+
using Base =
120+
dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<RegionInfo>>;
121+
using Base::getLatticeElement;
122+
using Base::SparseForwardDataFlowAnalysis;
123+
124+
enum RegionType { SHARED_MEMORY, TENSOR_MEMORY, BARRIER, NUM_REGION_TYPES };
125+
126+
static bool isMemoryAccessOperation(Operation *op);
127+
128+
// ------------------------------
129+
// Public API for ConSan
130+
// ------------------------------
131+
132+
/// Return the list of all unique (alloc,offset,len) buffer regions
133+
/// discovered by the analysis.
134+
llvm::SmallVector<BufferRegion>
135+
getAllUsedBufferRegions(RegionType type) const {
136+
return llvm::to_vector(usedBufferRegions[type]);
137+
}
138+
139+
void calculateUsedBufferRegions(Operation *op);
140+
141+
// ------------------------------
142+
// Required overrides
143+
// ------------------------------
144+
145+
void setToEntryState(dataflow::Lattice<RegionInfo> *lat) override {
146+
propagateIfChanged(
147+
lat, lat->join(RegionInfo::getPessimisticValueState(lat->getAnchor())));
148+
}
149+
150+
LogicalResult visitOperation(
151+
Operation *op,
152+
llvm::ArrayRef<const dataflow::Lattice<RegionInfo> *> operands,
153+
llvm::ArrayRef<dataflow::Lattice<RegionInfo> *> results) override;
154+
155+
void visitNonControlFlowArguments(
156+
Operation *op, const RegionSuccessor &successor,
157+
llvm::ArrayRef<dataflow::Lattice<RegionInfo> *> argLattices,
158+
unsigned firstIndex) override;
159+
160+
LogicalResult initialize(Operation *top) override;
161+
162+
private:
163+
// Global registry of all regions
164+
std::set<BufferRegion> usedBufferRegions[NUM_REGION_TYPES];
165+
};
166+
167+
} // namespace mlir::triton
168+
169+
#endif // TRITON_ANALYSIS_BUFFER_REGION_H

include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,35 +182,35 @@ class FunctionBuilder {
182182
// stageAccessForCommit: mark the buffer as staged (value -1) in the
183183
// outstanding commit table for this thread.
184184
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
185-
int thread, Value pred, ValueType buffers,
186-
ValueType outstandingCommits,
185+
int thread, Value pred, MemType memType,
186+
CommitKind::Kind commitKind,
187187
Operation *insertPoint);
188188
// commitAccesses: convert staged entries to 1 and increment outstanding
189189
// commits greater than zero for the committing thread.
190190
void createCommitAccessesCall(ImplicitLocOpBuilder &b, int thread, Value pred,
191-
ValueType outstandingCommits,
191+
CommitKind::Kind commitKind,
192192
Operation *insertPoint);
193193
// clearOutstandingCommitsTransferWrites: clear entries farther than
194194
// outstandingNum from the thread and set write visibility for threads in
195195
// transferThreadMask.
196196
void createClearOutstandingCommitsTransferWritesCall(
197197
ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
198-
int outstandingNum, Value pred, ValueType outstandingCommits,
199-
ValueType writeVisibility, Operation *insertPoint);
198+
int outstandingNum, Value pred, CommitKind::Kind commitKind,
199+
MemType memType, Operation *insertPoint);
200200
// clearOutstandingCommitsTransferReads: clear entries farther than
201201
// outstandingNum from the thread and set read visibility for threads in
202202
// transferThreadMask.
203203
void createClearOutstandingCommitsTransferReadsCall(
204204
ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
205-
int outstandingNum, Value pred, ValueType outstandingCommits,
206-
ValueType readVisibility, Operation *insertPoint);
205+
int outstandingNum, Value pred, CommitKind::Kind commitKind,
206+
MemType memType, Operation *insertPoint);
207207
// checkOutstandingCommits: assert that the outstanding commit row for the
208208
// buffer is zero before the access described by pendingAccessType.
209209
void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
210210
int thread,
211211
StringRef pendingAccessType,
212-
Value pred, ValueType buffers,
213-
ValueType outstandingCommits,
212+
Value pred, MemType memType,
213+
CommitKind::Kind commitKind,
214214
Operation *insertPoint);
215215

216216
private:

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,16 @@ struct ValueType {
4747
struct AuxDataMap {
4848
struct RegionToValueMap {
4949
DenseMap<Region *, ValueType> values;
50-
ValueType &operator[](Region *region) { return values[region]; }
51-
ValueType &operator[](Operation *op) {
52-
return values[getEnclosingParitionOrFunctionRegion(op)];
50+
ValueType at(Region *region) {
51+
if (values.find(region) == values.end()) {
52+
assert(false && "Region not found in AuxDataMap");
53+
}
54+
return values[region];
5355
}
56+
ValueType at(Operation *op) {
57+
return at(getEnclosingParitionOrFunctionRegion(op));
58+
}
59+
void insert(Region *region, ValueType value) { values[region] = value; }
5460
bool empty() const { return values.empty(); }
5561

5662
private:
@@ -75,8 +81,8 @@ struct AuxDataMap {
7581

7682
private:
7783
void getBuffersAndBarriers(ModuleOp module,
78-
SmallVector<SmallVector<int32_t>, 2> &bufValues,
79-
SmallVector<int32_t> &barrierValues);
84+
SmallVector<SmallVector<uint32_t>, 2> &bufValues,
85+
SmallVector<uint32_t> &barrierValues);
8086
void passToWarpSpecialize(triton::FuncOp func, ValueType value,
8187
RegionToValueMap &map);
8288
void createInWarpSpecialize(

0 commit comments

Comments
 (0)