Skip to content

Commit 9073370

Browse files
[CONSAN] Adding support for aliasing (#8939)
This change adds support for buffer aliasing in ConSan. This is achieved by using BufferRegion analysis to create an AliasingMatrix that marks which buffers may overlap with each other. This information is then used during the visibility checks: instead of verifying that the buffer being accessed is visible, we load a row from the aliasing matrix describing all the buffers that may alias the buffer being accessed (including the buffer in question). This set of buffers is then checked for having visibility into potential outstanding reads/writes. There is still a follow-up change coming: moving BufferRegion analysis to TritonInstrument. It was not done in this PR to make the review easier.
1 parent f90bf63 commit 9073370

File tree

14 files changed

+1250
-486
lines changed

14 files changed

+1250
-486
lines changed

include/triton/Analysis/BufferRegion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ class BufferRegionAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
162162
private:
163163
// Global registry of all regions
164164
std::set<BufferRegion> usedBufferRegions[NUM_REGION_TYPES];
165+
166+
static void verifyOpIsSupported(Operation *op);
165167
};
166168

167169
} // namespace mlir::triton

include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,27 @@ class FunctionBuilder {
117117
// from the visibility bitmask. We know this is safe because there cannot be
118118
// outstanding writes to this buffer at this point.
119119
void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
120-
uint64_t threadMask, Value pred,
121-
MemType memType, Operation *insertPoint);
120+
uint32_t length, uint64_t threadMask,
121+
Value pred, MemType memType,
122+
Operation *insertPoint);
122123
// setReadVisibility: add the threads set in threadMask to the buffer's read
123124
// visibility bitmask.
124125
void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
125-
uint64_t threadMask, Value pred,
126-
MemType memType, Operation *insertPoint);
126+
uint32_t length, uint64_t threadMask,
127+
Value pred, MemType memType,
128+
Operation *insertPoint);
127129
// clearWriteTracking: clear all the information about threads writing to a
128130
// buffer.
129131
void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
130-
Value pred, MemType memType,
131-
Operation *insertPoint);
132+
uint32_t length, Value pred,
133+
MemType memType, Operation *insertPoint);
132134
// clearReadVisibility: clear the read visibility for a buffer.
133135
void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
134-
Value pred, MemType memType,
135-
Operation *insertPoint);
136+
uint32_t length, Value pred,
137+
MemType memType, Operation *insertPoint);
136138
// clearReadTracking: clear the read tracking for a buffer.
137139
void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
138-
Value pred, MemType memType,
140+
uint32_t length, Value pred, MemType memType,
139141
Operation *insertPoint);
140142
// trackVisibleWrites: snapshot buffers currently visible to the thread into
141143
// the tracking table for a barrier.
@@ -160,15 +162,15 @@ class FunctionBuilder {
160162
// verifyWriteVisibility: ensure the thread either sees the latest write or no
161163
// other thread is writing the buffer.
162164
void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
163-
int thread, StringRef operandName,
164-
Value pred, MemType memType,
165-
Operation *insertPoint);
165+
uint32_t length, int thread,
166+
StringRef operandName, Value pred,
167+
MemType memType, Operation *insertPoint);
166168
// verifyReadVisibility: ensure all reads from the buffer are visible to the
167169
// thread.
168170
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
169-
int thread, StringRef operandName,
170-
Value pred, MemType memType,
171-
Operation *insertPoint);
171+
uint32_t length, int thread,
172+
StringRef operandName, Value pred,
173+
MemType memType, Operation *insertPoint);
172174
// copyWriteVisibility: replicate the write visibility bit of sourceThread to
173175
// every destination thread in destMask.
174176
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
@@ -182,7 +184,8 @@ class FunctionBuilder {
182184
// stageAccessForCommit: mark the buffer as staged (value -1) in the
183185
// outstanding commit table for this thread.
184186
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
185-
int thread, Value pred, MemType memType,
187+
uint32_t length, int thread, Value pred,
188+
MemType memType,
186189
CommitKind::Kind commitKind,
187190
Operation *insertPoint);
188191
// commitAccesses: convert staged entries to 1 and increment outstanding
@@ -207,7 +210,7 @@ class FunctionBuilder {
207210
// checkOutstandingCommits: assert that the outstanding commit row for the
208211
// buffer is zero before the access described by pendingAccessType.
209212
void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
210-
int thread,
213+
uint32_t length, int thread,
211214
StringRef pendingAccessType,
212215
Value pred, MemType memType,
213216
CommitKind::Kind commitKind,

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,33 @@ def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [
3434
}
3535

3636

37-
def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [Pure]> {
38-
let summary = "definte an array of pointers to shared memory buffers";
37+
def TTI_ExperimentalBufferDescriptorsOp
38+
: TTI_Op<"experimental_buffer_descriptors", [Pure]> {
39+
let summary = "define an array of buffer descriptors";
3940
let description = [{
40-
Create a tensor of pointers to shared memory buffers.
41+
Create a tensor of buffer descriptors packing 32-bit pointer offsets and
42+
32-bit lengths into 64-bit elements.
4143
}];
42-
let arguments = (ins DenseI32ArrayAttr:$offsets, TT_MemTypeAttr:$memType);
44+
let arguments = (ins DenseI32ArrayAttr:$offsets, DenseI32ArrayAttr:$lengths,
45+
TT_MemTypeAttr:$memType);
4346
let results = (outs TT_Tensor:$result);
4447
let assemblyFormat = [{
45-
$offsets `,` $memType attr-dict `:` type($result)
48+
$offsets `,` $lengths `,` $memType attr-dict `:` type($result)
4649
}];
4750
}
4851

49-
def TTI_ExperimentalMemDescToI64Op : TTI_Op<"experimental_memdesc_to_i64", [Pure]> {
50-
let summary = "Convert a memdesc into its base pointer as i64";
52+
def TTI_ExperimentalMemDescToI32Op : TTI_Op<"experimental_memdesc_to_i32", [Pure]> {
53+
let summary = "Convert a memdesc into its base pointer as i32";
5154
let description = [{
52-
Extract the base pointer from the given memdesc and return it as a 64-bit
55+
Extract the base pointer from the given memdesc and return it as a 32-bit
5356
integer. This can be used to compare the memdesc against tensors of barrier
5457
pointers maintained by the concurrency sanitizer.
5558
}];
5659
let arguments = (ins TTG_MemDescType:$memdesc);
57-
let results = (outs I64:$result);
60+
let results = (outs I32:$result);
5861
let builders = [
5962
OpBuilder<(ins "Value":$memdesc), [{
60-
build($_builder, $_state, $_builder.getI64Type(), memdesc);
63+
build($_builder, $_state, $_builder.getI32Type(), memdesc);
6164
}]>
6265
];
6366
let assemblyFormat = "$memdesc attr-dict `:` type($memdesc)";

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef TRITONINSTRUMENT_UTILITY_H
22
#define TRITONINSTRUMENT_UTILITY_H
33

4+
#include "triton/Analysis/BufferRegion.h"
45
#include "triton/Dialect/Triton/IR/Utility.h"
56
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
67
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
@@ -74,15 +75,17 @@ struct AuxDataMap {
7475
RegionToValueMap readVisibility[numMemTypes];
7576
RegionToValueMap readTracking[numMemTypes];
7677
RegionToValueMap commits[CommitKind::NumCommitKinds];
78+
RegionToValueMap aliasMatrices[numMemTypes];
7779
RegionToValueMap lock;
7880
RegionToValueMap waiting;
7981

8082
void populateAndPassToWarpSpecialize(ModuleOp module);
8183

8284
private:
83-
void getBuffersAndBarriers(ModuleOp module,
84-
SmallVector<SmallVector<uint32_t>, 2> &bufValues,
85-
SmallVector<uint32_t> &barrierValues);
85+
void getBuffersAndBarriers(
86+
ModuleOp module,
87+
SmallVector<SmallVector<triton::BufferRegion>, 2> &bufRegions,
88+
SmallVector<triton::BufferRegion> &barrierRegions);
8689
void passToWarpSpecialize(triton::FuncOp func, ValueType value,
8790
RegionToValueMap &map);
8891
void createInWarpSpecialize(

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,22 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
1515
InterfaceMethod<"Return the A operand.",
1616
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
1717
"getA">,
18+
InterfaceMethod<"Return the B operand.",
19+
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
20+
"getB">,
1821
InterfaceMethod<"Return the accumulator init flag.",
1922
"::mlir::Value",
2023
"useAccumulator">,
2124
InterfaceMethod<"Set the accumulator init flag.",
2225
"void",
2326
"setUseAccumulator",
2427
(ins "::mlir::Value":$flag)>,
28+
InterfaceMethod<"Return the completion barriers of this MMAv5 op.",
29+
"::mlir::ValueRange",
30+
"getCompletionBarriers">,
31+
InterfaceMethod<"Return the completion barrier predicates of this MMAv5 op.",
32+
"::mlir::ValueRange",
33+
"getCompletionBarrierPreds">,
2534
InterfaceMethod<"Associate a new completion barrier to this MMAv5 op.",
2635
"void",
2736
"addCompletionBarrier",

0 commit comments

Comments
 (0)