Skip to content

Commit d3c03e6

Browse files
committed
[MLIR][mlir-link] Link required symbols even if they are only declarations
1 parent d52166a commit d3c03e6

File tree

3 files changed

+69
-30
lines changed

3 files changed

+69
-30
lines changed

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
7575
virtual StringRef getSymbol(Operation *op) const = 0;
7676

7777
/// Determines if an operation should be linked into the destination module.
78-
virtual bool isLinkNeeded(ConflictPair pair) const = 0;
78+
virtual bool isLinkNeeded(ConflictPair pair, bool forDependency) const = 0;
7979

8080
/// Checks if an operation conflicts with existing linked operations.
8181
virtual ConflictPair findConflict(Operation *src) const = 0;
@@ -89,6 +89,9 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
8989
/// Link the operations in the source module into the destination module.
9090
virtual Operation *materialize(ConflictPair pair, ModuleOp dst) const = 0;
9191

92+
/// Dependencies of the given operation required to be linked.
93+
virtual SmallVector<Operation *> dependencies(Operation *op) const = 0;
94+
9295
void setFlags(unsigned flags) { this->flags = flags; }
9396

9497
bool shouldLinkOnlyNeeded() const {

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
184184
return ConflictPair::noConflict(src);
185185
}
186186

187-
bool isLinkNeeded(ConflictPair pair) const override {
187+
bool isLinkNeeded(ConflictPair pair, bool forDependency) const override {
188188
assert(canBeLinked(pair.src) && "expected linkable operation");
189189

190190
LLVM::Linkage srcLinkage = getLinkage(pair.src);
@@ -194,19 +194,28 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
194194
return true;
195195
}
196196

197-
if (shouldLinkOnlyNeeded()) {
198-
// Don't import globals that are already defined
199-
if (!pair.dst || !isDeclaration(pair.dst))
200-
return false;
201-
}
197+
bool alreadyDeclared = pair.dst && isDeclaration(pair.dst);
198+
199+
// Don't import globals that are already defined
200+
if (shouldLinkOnlyNeeded() && !alreadyDeclared)
201+
return false;
202+
203+
// Always import dependencies that are not yet defined or declared
204+
if (forDependency && !pair.dst)
205+
return true;
206+
202207
if (isDeclaration(pair.src))
203208
return false;
204209

205-
bool keepOnlyInSource = isLocalLinkage(srcLinkage) ||
206-
isLinkOnceLinkage(srcLinkage) ||
207-
isAvailableExternallyLinkage(srcLinkage);
210+
if (pair.hasConflict())
211+
return true;
208212

209-
return pair.dst || shouldOverrideFromSrc() || !keepOnlyInSource;
213+
if (shouldOverrideFromSrc())
214+
return true;
215+
216+
// linkage specifies to keep operation only in source
217+
return !(isLocalLinkage(srcLinkage) || isLinkOnceLinkage(srcLinkage) ||
218+
isAvailableExternallyLinkage(srcLinkage));
210219
}
211220

212221
FailureOr<bool> shouldLinkFromSource(ConflictPair pair) const {
@@ -385,6 +394,22 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
385394
return pair.dst;
386395
}
387396

397+
SmallVector<Operation *> dependencies(Operation *op) const override {
398+
SmallVector<Operation *> result;
399+
400+
Operation *symbolTableOp = symbolTable->getOp();
401+
op->walk([&](SymbolUserOpInterface user) {
402+
if (user.getOperation() == op)
403+
return;
404+
405+
if (SymbolRefAttr symbol = user.getUserSymbol())
406+
if (Operation *dep = symbolTable->lookupSymbolIn(symbolTableOp, symbol))
407+
result.push_back(dep);
408+
});
409+
410+
return result;
411+
}
412+
388413
private:
389414
llvm::StringMap<Operation *> summary;
390415

mlir/lib/IR/BuiltinLinkerInterface.cpp

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,44 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface {
3535
if (op == src)
3636
return WalkResult::advance();
3737

38-
auto linker = dyn_cast<SymbolLinkerInterface>(op->getDialect());
39-
if (!linker)
40-
return WalkResult::advance();
38+
if (summarize(op, flags, /*forDependency=*/false).failed())
39+
return WalkResult::interrupt();
40+
return WalkResult::advance();
41+
});
4142

42-
// TODO do this in init
43-
linker->setFlags(flags);
43+
return failure(result.wasInterrupted());
44+
}
4445

45-
if (!linker->canBeLinked(op))
46-
return WalkResult::advance();
46+
LogicalResult summarize(Operation *op, unsigned flags, bool forDependency) {
47+
auto linker = dyn_cast<SymbolLinkerInterface>(op->getDialect());
48+
if (!linker)
49+
return success();
4750

48-
ConflictPair conflict = linker->findConflict(op);
49-
if (!linker->isLinkNeeded(conflict))
50-
return WalkResult::advance();
51+
linker->setFlags(flags);
52+
53+
if (!linker->canBeLinked(op))
54+
return success();
5155

52-
if (conflict.hasConflict())
53-
return failed(linker->resolveConflict(conflict))
54-
? WalkResult::interrupt()
55-
: WalkResult::advance();
56+
ConflictPair conflict = linker->findConflict(op);
57+
if (!linker->isLinkNeeded(conflict, forDependency))
58+
return success();
5659

57-
// TODO rename: registerForLink
60+
if (conflict.hasConflict()) {
61+
if (linker->resolveConflict(conflict).failed())
62+
return failure();
63+
} else {
5864
linker->registerForLink(op);
59-
return WalkResult::advance();
60-
});
65+
}
6166

62-
// TODO deal with references
67+
if (forDependency)
68+
return success();
6369

64-
return failure(result.wasInterrupted());
70+
for (Operation *dep : linker->dependencies(op)) {
71+
if (summarize(dep, flags, /*forDependency=*/true).failed())
72+
return failure();
73+
}
74+
75+
return success();
6576
}
6677

6778
LogicalResult link(ModuleOp dst) const override {

0 commit comments

Comments
 (0)