@@ -163,7 +163,7 @@ Value createInitializedScratchMemory(ImplicitLocOpBuilder &b,
163163 int64_t sizeInBytes = numEls * elSize;
164164 Type ptrType = triton::getPointerType (elType);
165165 auto alloc =
166- GlobalScratchAllocOp::create (b, ptrType, sizeInBytes, elSize, UnitAttr () );
166+ createThirdPartyScratchAlloc (b, b. getLoc (), ptrType, sizeInBytes, elSize);
167167 createStoreScratchMemory (b, b.getLoc (), alloc, tensor, tensor.getType ());
168168 return alloc;
169169}
@@ -182,8 +182,8 @@ Value createZeroInitStateTensor(ImplicitLocOpBuilder &b, int m, int n,
182182 Type ptrType = triton::getPointerType (elType);
183183 // Allocate scratch buffers with 16-byte alignment so global loads and stores
184184 // can be vectorized if possible.
185- auto alloc = GlobalScratchAllocOp::create (b , ptrType, sizeInBytes,
186- /* alignment=*/ 16 , UnitAttr () );
185+ auto alloc = createThirdPartyScratchAlloc (b, b. getLoc () , ptrType, sizeInBytes,
186+ /* alignment=*/ 16 );
187187 Value cstZero = arith::ConstantIntOp::create (b, 0 , bitWidth);
188188 funcBuilder.createFillGlobalTensorCall (b, alloc, type, cstZero);
189189 return alloc;
@@ -245,7 +245,7 @@ bool hasTMAStore(ModuleOp module) {
245245
246246Value createLockVariable (ImplicitLocOpBuilder &b) {
247247 Type ptrType = triton::getPointerType (b.getI32Type ());
248- auto alloc = GlobalScratchAllocOp::create (b, ptrType, 4 , 4 , UnitAttr () );
248+ auto alloc = createThirdPartyScratchAlloc (b, b. getLoc (), ptrType , 4 , 4 );
249249 Value zero = arith::ConstantOp::create (b, b.getLoc (), b.getI32Type (),
250250 b.getI32IntegerAttr (0 ));
251251 triton::AtomicRMWOp::create (b, b.getI32Type (), RMWOp::XCHG, alloc, zero,
@@ -258,6 +258,13 @@ Value createLockVariable(ImplicitLocOpBuilder &b) {
258258
259259namespace mlir ::triton::instrument {
260260
261+ gpu::GlobalScratchAllocOp
262+ createThirdPartyScratchAlloc (OpBuilder &b, Location loc, Type ptrType,
263+ int64_t sizeInBytes, int64_t alignment) {
264+ return gpu::GlobalScratchAllocOp::create (b, loc, ptrType, sizeInBytes,
265+ alignment, b.getUnitAttr ());
266+ }
267+
261268void createAssertInThread (ImplicitLocOpBuilder &b, Value condition,
262269 StringRef message) {
263270 if (auto tensorTy = dyn_cast<RankedTensorType>(condition.getType ())) {
0 commit comments