Skip to content

Commit 0585186

Browse files
Add driver-provided default allocator for profile scratch (#9596)
Make profile global scratch be allocated by default by driver-provided allocator. Make fpsan and consan use third_party_allocation for global scratch. Remove custom allocator requirement for sanitizers.
1 parent 1954693 commit 0585186

File tree

11 files changed

+143
-141
lines changed

11 files changed

+143
-141
lines changed

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
#include <array>
1010

11+
namespace mlir::triton::gpu {
12+
class GlobalScratchAllocOp;
13+
}
14+
1115
namespace mlir::triton::instrument {
1216
class FunctionBuilder;
1317

@@ -29,6 +33,9 @@ Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc,
2933
Value tensor, RankedTensorType tensorType);
3034
Value createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc,
3135
RankedTensorType tensorType);
36+
gpu::GlobalScratchAllocOp
37+
createThirdPartyScratchAlloc(OpBuilder &b, Location loc, Type ptrType,
38+
int64_t sizeInBytes, int64_t alignment);
3239
Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor);
3340
RankedTensorType getIntTensorType(Region *region, ArrayRef<int64_t> shape,
3441
unsigned bitWidth);

lib/Dialect/TritonInstrument/IR/Utility.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ Value createInitializedScratchMemory(ImplicitLocOpBuilder &b,
163163
int64_t sizeInBytes = numEls * elSize;
164164
Type ptrType = triton::getPointerType(elType);
165165
auto alloc =
166-
GlobalScratchAllocOp::create(b, ptrType, sizeInBytes, elSize, UnitAttr());
166+
createThirdPartyScratchAlloc(b, b.getLoc(), ptrType, sizeInBytes, elSize);
167167
createStoreScratchMemory(b, b.getLoc(), alloc, tensor, tensor.getType());
168168
return alloc;
169169
}
@@ -182,8 +182,8 @@ Value createZeroInitStateTensor(ImplicitLocOpBuilder &b, int m, int n,
182182
Type ptrType = triton::getPointerType(elType);
183183
// Allocate scratch buffers with 16-byte alignment so global loads and stores
184184
// can be vectorized if possible.
185-
auto alloc = GlobalScratchAllocOp::create(b, ptrType, sizeInBytes,
186-
/*alignment=*/16, UnitAttr());
185+
auto alloc = createThirdPartyScratchAlloc(b, b.getLoc(), ptrType, sizeInBytes,
186+
/*alignment=*/16);
187187
Value cstZero = arith::ConstantIntOp::create(b, 0, bitWidth);
188188
funcBuilder.createFillGlobalTensorCall(b, alloc, type, cstZero);
189189
return alloc;
@@ -245,7 +245,7 @@ bool hasTMAStore(ModuleOp module) {
245245

246246
Value createLockVariable(ImplicitLocOpBuilder &b) {
247247
Type ptrType = triton::getPointerType(b.getI32Type());
248-
auto alloc = GlobalScratchAllocOp::create(b, ptrType, 4, 4, UnitAttr());
248+
auto alloc = createThirdPartyScratchAlloc(b, b.getLoc(), ptrType, 4, 4);
249249
Value zero = arith::ConstantOp::create(b, b.getLoc(), b.getI32Type(),
250250
b.getI32IntegerAttr(0));
251251
triton::AtomicRMWOp::create(b, b.getI32Type(), RMWOp::XCHG, alloc, zero,
@@ -258,6 +258,13 @@ Value createLockVariable(ImplicitLocOpBuilder &b) {
258258

259259
namespace mlir::triton::instrument {
260260

261+
gpu::GlobalScratchAllocOp
262+
createThirdPartyScratchAlloc(OpBuilder &b, Location loc, Type ptrType,
263+
int64_t sizeInBytes, int64_t alignment) {
264+
return gpu::GlobalScratchAllocOp::create(b, loc, ptrType, sizeInBytes,
265+
alignment, b.getUnitAttr());
266+
}
267+
261268
void createAssertInThread(ImplicitLocOpBuilder &b, Value condition,
262269
StringRef message) {
263270
if (auto tensorTy = dyn_cast<RankedTensorType>(condition.getType())) {

lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ class TmemScratchManager {
164164
int64_t alignment = std::max<int64_t>(elSize, 16);
165165
int64_t sizeInBytes = product(memTy.getShape()) * elSize;
166166
auto ptrTy = triton::getPointerType(memTy.getElementType());
167-
auto allocOp = ttg::GlobalScratchAllocOp::create(
168-
rewriter, loc, ptrTy, sizeInBytes, alignment, UnitAttr());
167+
auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy,
168+
sizeInBytes, alignment);
169169
allocOp->setDiscardableAttr("tt.divisibility",
170170
rewriter.getI64IntegerAttr(alignment));
171171
Value ptr = allocOp.getResult();
@@ -312,8 +312,8 @@ Value createScratchAndStore(PatternRewriter &rewriter, Location loc, Value val,
312312
int64_t alignment = std::max<int64_t>(elSize, 16);
313313
int64_t sizeInBytes = product(tensorTy.getShape()) * elSize;
314314
auto ptrTy = triton::getPointerType(tensorTy.getElementType());
315-
auto allocOp = ttg::GlobalScratchAllocOp::create(
316-
rewriter, loc, ptrTy, sizeInBytes, alignment, UnitAttr());
315+
auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy, sizeInBytes,
316+
alignment);
317317
allocOp->setDiscardableAttr("tt.divisibility",
318318
rewriter.getI64IntegerAttr(alignment));
319319
createStoreScratchMemory(rewriter, loc, allocOp.getResult(), val, tensorTy);
@@ -482,8 +482,8 @@ createOperandScratch(PatternRewriter &rewriter, Location loc,
482482
int64_t alignment = std::max<int64_t>(elSize, 16);
483483
int64_t sizeInBytes = product(memTy.getShape()) * elSize;
484484
auto ptrTy = triton::getPointerType(memTy.getElementType());
485-
auto allocOp = ttg::GlobalScratchAllocOp::create(
486-
rewriter, loc, ptrTy, sizeInBytes, alignment, UnitAttr());
485+
auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy, sizeInBytes,
486+
alignment);
487487
allocOp->setDiscardableAttr("tt.divisibility",
488488
rewriter.getI64IntegerAttr(alignment));
489489
Value ptr = allocOp.getResult();

0 commit comments

Comments
 (0)