Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/triton/Analysis/BufferRegion.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class BufferRegionAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
private:
// Global registry of all regions
std::set<BufferRegion> usedBufferRegions[NUM_REGION_TYPES];

static void verifyOpIsSupported(Operation *op);
};

} // namespace mlir::triton
Expand Down
37 changes: 20 additions & 17 deletions include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,25 +117,27 @@ class FunctionBuilder {
// from the visibility bitmask. We know this is safe because there cannot be
// outstanding writes to this buffer at this point.
void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint64_t threadMask, Value pred,
MemType memType, Operation *insertPoint);
uint32_t length, uint64_t threadMask,
Value pred, MemType memType,
Operation *insertPoint);
// setReadVisibility: add the threads set in threadMask to the buffer's read
// visibility bitmask.
void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint64_t threadMask, Value pred,
MemType memType, Operation *insertPoint);
uint32_t length, uint64_t threadMask,
Value pred, MemType memType,
Operation *insertPoint);
// clearWriteTracking: clear all the information about threads writing to a
// buffer.
void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, Value pred,
MemType memType, Operation *insertPoint);
// clearReadVisibility: clear the read visibility for a buffer.
void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, Value pred,
MemType memType, Operation *insertPoint);
// clearReadTracking: clear the read tracking for a buffer.
void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
Value pred, MemType memType,
uint32_t length, Value pred, MemType memType,
Operation *insertPoint);
// trackVisibleWrites: snapshot buffers currently visible to the thread into
// the tracking table for a barrier.
Expand All @@ -160,15 +162,15 @@ class FunctionBuilder {
// verifyWriteVisibility: ensure the thread either sees the latest write or no
// other thread is writing the buffer.
void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
int thread, StringRef operandName,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint);
// verifyReadVisibility: ensure all reads from the buffer are visible to the
// thread.
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
int thread, StringRef operandName,
Value pred, MemType memType,
Operation *insertPoint);
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint);
// copyWriteVisibility: replicate the write visibility bit of sourceThread to
// every destination thread in destMask.
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
Expand All @@ -182,7 +184,8 @@ class FunctionBuilder {
// stageAccessForCommit: mark the buffer as staged (value -1) in the
// outstanding commit table for this thread.
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
int thread, Value pred, MemType memType,
uint32_t length, int thread, Value pred,
MemType memType,
CommitKind::Kind commitKind,
Operation *insertPoint);
// commitAccesses: convert staged entries to 1 and increment outstanding
Expand All @@ -207,7 +210,7 @@ class FunctionBuilder {
// checkOutstandingCommits: assert that the outstanding commit row for the
// buffer is zero before the access described by pendingAccessType.
void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
int thread,
uint32_t length, int thread,
StringRef pendingAccessType,
Value pred, MemType memType,
CommitKind::Kind commitKind,
Expand Down
23 changes: 13 additions & 10 deletions include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,33 @@ def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [
}


def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [Pure]> {
let summary = "definte an array of pointers to shared memory buffers";
def TTI_ExperimentalBufferDescriptorsOp
: TTI_Op<"experimental_buffer_descriptors", [Pure]> {
let summary = "define an array of buffer descriptors";
let description = [{
Create a tensor of pointers to shared memory buffers.
Create a tensor of buffer descriptors packing 32-bit pointer offsets and
32-bit lengths into 64-bit elements.
}];
let arguments = (ins DenseI32ArrayAttr:$offsets, TT_MemTypeAttr:$memType);
let arguments = (ins DenseI32ArrayAttr:$offsets, DenseI32ArrayAttr:$lengths,
TT_MemTypeAttr:$memType);
let results = (outs TT_Tensor:$result);
let assemblyFormat = [{
$offsets `,` $memType attr-dict `:` type($result)
$offsets `,` $lengths `,` $memType attr-dict `:` type($result)
}];
}

def TTI_ExperimentalMemDescToI64Op : TTI_Op<"experimental_memdesc_to_i64", [Pure]> {
let summary = "Convert a memdesc into its base pointer as i64";
def TTI_ExperimentalMemDescToI32Op : TTI_Op<"experimental_memdesc_to_i32", [Pure]> {
let summary = "Convert a memdesc into its base pointer as i32";
let description = [{
Extract the base pointer from the given memdesc and return it as a 64-bit
Extract the base pointer from the given memdesc and return it as a 32-bit
integer. This can be used to compare the memdesc against tensors of barrier
pointers maintained by the concurrency sanitizer.
}];
let arguments = (ins TTG_MemDescType:$memdesc);
let results = (outs I64:$result);
let results = (outs I32:$result);
let builders = [
OpBuilder<(ins "Value":$memdesc), [{
build($_builder, $_state, $_builder.getI64Type(), memdesc);
build($_builder, $_state, $_builder.getI32Type(), memdesc);
}]>
];
let assemblyFormat = "$memdesc attr-dict `:` type($memdesc)";
Expand Down
9 changes: 6 additions & 3 deletions include/triton/Dialect/TritonInstrument/IR/Utility.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TRITONINSTRUMENT_UTILITY_H
#define TRITONINSTRUMENT_UTILITY_H

#include "triton/Analysis/BufferRegion.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
Expand Down Expand Up @@ -74,15 +75,17 @@ struct AuxDataMap {
RegionToValueMap readVisibility[numMemTypes];
RegionToValueMap readTracking[numMemTypes];
RegionToValueMap commits[CommitKind::NumCommitKinds];
RegionToValueMap aliasMatrices[numMemTypes];
RegionToValueMap lock;
RegionToValueMap waiting;

void populateAndPassToWarpSpecialize(ModuleOp module);

private:
void getBuffersAndBarriers(ModuleOp module,
SmallVector<SmallVector<uint32_t>, 2> &bufValues,
SmallVector<uint32_t> &barrierValues);
void getBuffersAndBarriers(
ModuleOp module,
SmallVector<SmallVector<triton::BufferRegion>, 2> &bufRegions,
SmallVector<triton::BufferRegion> &barrierRegions);
void passToWarpSpecialize(triton::FuncOp func, ValueType value,
RegionToValueMap &map);
void createInWarpSpecialize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
InterfaceMethod<"Return the A operand.",
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
"getA">,
InterfaceMethod<"Return the B operand.",
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
"getB">,
InterfaceMethod<"Return the accumulator init flag.",
"::mlir::Value",
"useAccumulator">,
InterfaceMethod<"Set the accumulator init flag.",
"void",
"setUseAccumulator",
(ins "::mlir::Value":$flag)>,
InterfaceMethod<"Return the completion barriers of this MMAv5 op.",
"::mlir::ValueRange",
"getCompletionBarriers">,
InterfaceMethod<"Return the completion barrier predicates of this MMAv5 op.",
"::mlir::ValueRange",
"getCompletionBarrierPreds">,
InterfaceMethod<"Associate a new completion barrier to this MMAv5 op.",
"void",
"addCompletionBarrier",
Expand Down
Loading
Loading