Skip to content

Commit d703656

Browse files
[GLUON] Set proper location on restoring the insert point in gluon (#8531)
`warp_specialize` ops currently have unknown location set in the TTGIR due to a quirk in the code emission in `_semantic.py`: for `warp_specialize` we need save and then restore insert point. Location is being inferred from the insert point, however if insert point happens to be in a place that doesn't have location assigned (end of a block), we set unknown loc. This change is a minimal fix that adds a helper that gets the location from block's parent in such a case. Alternatively we could also save location along with insert point, and then restore it accordingly. This approach is simpler and should help for most cases I could have think of however. This change is important for consan changes I am working on, as it breaks the LLVM backend if we create instrumentation function calls with unknown location inferred from warp_specialize op.
1 parent 869733f commit d703656

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

python/src/ir.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ class TritonOpBuilder {
3535
if (!block.empty())
3636
setLastLoc(block.begin()->getLoc());
3737
else
38-
setLastLoc(builder->getUnknownLoc());
38+
setLastLoc(getLocForBlock(&block));
3939
builder->setInsertionPointToStart(&block);
4040
}
4141

4242
void setInsertionPointToEnd(mlir::Block &block) {
4343
if (!block.empty())
4444
setLastLoc(block.back().getLoc());
4545
else
46-
setLastLoc(builder->getUnknownLoc());
46+
setLastLoc(getLocForBlock(&block));
4747
builder->setInsertionPointToEnd(&block);
4848
}
4949

@@ -53,10 +53,14 @@ class TritonOpBuilder {
5353
}
5454

5555
void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) {
56-
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
57-
setLastLoc(pt.getPoint()->getLoc());
58-
else
59-
setLastLoc(builder->getUnknownLoc());
56+
setLastLoc(builder->getUnknownLoc());
57+
if (pt.isSet()) {
58+
if (pt.getPoint() != pt.getBlock()->end())
59+
setLastLoc(pt.getPoint()->getLoc());
60+
else
61+
setLastLoc(getLocForBlock(pt.getBlock()));
62+
}
63+
6064
builder->restoreInsertionPoint(pt);
6165
}
6266

@@ -87,4 +91,10 @@ class TritonOpBuilder {
8791
std::unique_ptr<mlir::Location> lastLoc;
8892
bool lineInfoEnabled =
8993
!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
94+
95+
mlir::Location getLocForBlock(mlir::Block *block) {
96+
if (auto parentOp = block->getParentOp())
97+
return parentOp->getLoc();
98+
return builder->getUnknownLoc();
99+
}
90100
};

0 commit comments

Comments
 (0)