Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/triton/Analysis/BufferRegion.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class BufferRegionAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
private:
// Global registry of all regions
std::set<BufferRegion> usedBufferRegions[NUM_REGION_TYPES];

static void verifyOpIsSupported(Operation *op);
};

} // namespace mlir::triton
Expand Down
37 changes: 20 additions & 17 deletions include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,25 +117,27 @@ class FunctionBuilder {
// from the visibility bitmask. We know this is safe because there cannot be
// outstanding writes to this buffer at this point.
void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint64_t threadMask, Value pred,
MemType memType, Operation *insertPoint);
uint32_t length, uint64_t threadMask,
Value pred, MemType memType,
Operation *insertPoint);
// setReadVisibility: add the threads set in threadMask to the buffer's read
// visibility bitmask.
void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint64_t threadMask, Value pred,
MemType memType, Operation *insertPoint);
uint32_t length, uint64_t threadMask,
Value pred, MemType memType,
Operation *insertPoint);
// clearWriteTracking: clear all the information about threads writing to a
// buffer.
void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, Value pred,
MemType memType, Operation *insertPoint);
// clearReadVisibility: clear the read visibility for a buffer.
void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, Value pred,
MemType memType, Operation *insertPoint);
// clearReadTracking: clear the read tracking for a buffer.
void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
Value pred, MemType memType,
uint32_t length, Value pred, MemType memType,
Operation *insertPoint);
// trackVisibleWrites: snapshot buffers currently visible to the thread into
// the tracking table for a barrier.
Expand All @@ -160,15 +162,15 @@ class FunctionBuilder {
// verifyWriteVisibility: ensure the thread either sees the latest write or no
// other thread is writing the buffer.
void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
int thread, StringRef operandName,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint);
// verifyReadVisibility: ensure all reads from the buffer are visible to the
// thread.
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
int thread, StringRef operandName,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint);
// copyWriteVisibility: replicate the write visibility bit of sourceThread to
// every destination thread in destMask.
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
Expand All @@ -182,7 +184,8 @@ class FunctionBuilder {
// stageAccessForCommit: mark the buffer as staged (value -1) in the
// outstanding commit table for this thread.
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
int thread, Value pred, MemType memType,
uint32_t length, int thread, Value pred,
MemType memType,
CommitKind::Kind commitKind,
Operation *insertPoint);
// commitAccesses: convert staged entries to 1 and increment outstanding
Expand All @@ -207,7 +210,7 @@ class FunctionBuilder {
// checkOutstandingCommits: assert that the outstanding commit row for the
// buffer is zero before the access described by pendingAccessType.
void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
int thread,
uint32_t length, int thread,
StringRef pendingAccessType,
Value pred, MemType memType,
CommitKind::Kind commitKind,
Expand Down
23 changes: 13 additions & 10 deletions include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,33 @@ def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [
}


def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [Pure]> {
let summary = "definte an array of pointers to shared memory buffers";
def TTI_ExperimentalBufferDescriptorsOp
: TTI_Op<"experimental_buffer_descriptors", [Pure]> {
let summary = "define an array of buffer descriptors";
let description = [{
Create a tensor of pointers to shared memory buffers.
Create a tensor of buffer descriptors packing 32-bit pointer offsets and
32-bit lengths into 64-bit elements.
}];
let arguments = (ins DenseI32ArrayAttr:$offsets, TT_MemTypeAttr:$memType);
let arguments = (ins DenseI32ArrayAttr:$offsets, DenseI32ArrayAttr:$lengths,
TT_MemTypeAttr:$memType);
let results = (outs TT_Tensor:$result);
let assemblyFormat = [{
$offsets `,` $memType attr-dict `:` type($result)
$offsets `,` $lengths `,` $memType attr-dict `:` type($result)
}];
}

def TTI_ExperimentalMemDescToI64Op : TTI_Op<"experimental_memdesc_to_i64", [Pure]> {
let summary = "Convert a memdesc into its base pointer as i64";
def TTI_ExperimentalMemDescToI32Op : TTI_Op<"experimental_memdesc_to_i32", [Pure]> {
let summary = "Convert a memdesc into its base pointer as i32";
let description = [{
Extract the base pointer from the given memdesc and return it as a 64-bit
Extract the base pointer from the given memdesc and return it as a 32-bit
integer. This can be used to compare the memdesc against tensors of barrier
pointers maintained by the concurrency sanitizer.
}];
let arguments = (ins TTG_MemDescType:$memdesc);
let results = (outs I64:$result);
let results = (outs I32:$result);
let builders = [
OpBuilder<(ins "Value":$memdesc), [{
build($_builder, $_state, $_builder.getI64Type(), memdesc);
build($_builder, $_state, $_builder.getI32Type(), memdesc);
}]>
];
let assemblyFormat = "$memdesc attr-dict `:` type($memdesc)";
Expand Down
9 changes: 6 additions & 3 deletions include/triton/Dialect/TritonInstrument/IR/Utility.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TRITONINSTRUMENT_UTILITY_H
#define TRITONINSTRUMENT_UTILITY_H

#include "triton/Analysis/BufferRegion.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
Expand Down Expand Up @@ -74,15 +75,17 @@ struct AuxDataMap {
RegionToValueMap readVisibility[numMemTypes];
RegionToValueMap readTracking[numMemTypes];
RegionToValueMap commits[CommitKind::NumCommitKinds];
RegionToValueMap aliasMatrices[numMemTypes];
RegionToValueMap lock;
RegionToValueMap waiting;

void populateAndPassToWarpSpecialize(ModuleOp module);

private:
void getBuffersAndBarriers(ModuleOp module,
SmallVector<SmallVector<uint32_t>, 2> &bufValues,
SmallVector<uint32_t> &barrierValues);
void getBuffersAndBarriers(
ModuleOp module,
SmallVector<SmallVector<triton::BufferRegion>, 2> &bufRegions,
SmallVector<triton::BufferRegion> &barrierRegions);
void passToWarpSpecialize(triton::FuncOp func, ValueType value,
RegionToValueMap &map);
void createInWarpSpecialize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
InterfaceMethod<"Return the A operand.",
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
"getA">,
InterfaceMethod<"Return the B operand.",
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
"getB">,
InterfaceMethod<"Return the accumulator init flag.",
"::mlir::Value",
"useAccumulator">,
InterfaceMethod<"Set the accumulator init flag.",
"void",
"setUseAccumulator",
(ins "::mlir::Value":$flag)>,
InterfaceMethod<"Return the completion barriers of this MMAv5 op.",
"::mlir::ValueRange",
"getCompletionBarriers">,
InterfaceMethod<"Return the completion barrier predicates of this MMAv5 op.",
"::mlir::ValueRange",
"getCompletionBarrierPreds">,
InterfaceMethod<"Associate a new completion barrier to this MMAv5 op.",
"void",
"addCompletionBarrier",
Expand Down
Loading