Skip to content

Commit bd58199

Browse files
[TritonGPU] fix access to InsertSliceAsyncOp's mask
1 parent 70031c1 commit bd58199

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeline.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,17 @@ void LoopPipeliner::emitPrologue() {
301301
}
302302

303303
// If this is a load/async_copy, we need to update the mask
304-
if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
305-
Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
306-
: newOp->getOperand(3);
304+
if (Value mask = [&]() {
305+
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
306+
return loadOp.mask();
307+
} else if (auto insertSliceAsyncOp =
308+
llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
309+
newOp)) {
310+
return insertSliceAsyncOp.mask();
311+
} else {
312+
return mlir::Value();
313+
}
314+
}()) {
307315
// assert(I1 or TensorOf<[I1]>);
308316
OpBuilder::InsertionGuard g(builder);
309317
// TODO: move this out of the loop

0 commit comments

Comments
 (0)