Skip to content

Commit 737d8c2

Browse files
committed
[MLIR][mlir-link] Refactor out SymbolAttrLinkerInterface
1 parent d5f5275 commit 737d8c2

File tree

3 files changed

+42
-15
lines changed

3 files changed

+42
-15
lines changed

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
131131
unsigned flags = LinkerFlags::None;
132132
};
133133

134+
//===----------------------------------------------------------------------===//
135+
// SymbolAttrLinkerInterface
136+
//===----------------------------------------------------------------------===//
137+
138+
class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
139+
public:
140+
using SymbolLinkerInterface::SymbolLinkerInterface;
141+
142+
/// Returns the symbol for the given operation.
143+
StringRef getSymbol(Operation *op) const override;
144+
145+
/// Checks if an operation conflicts with existing linked operations.
146+
Conflict findConflict(Operation *src) const override;
147+
148+
protected:
149+
llvm::StringMap<Operation *> summary;
150+
};
151+
134152
//===----------------------------------------------------------------------===//
135153
// SymbolLinkerInterfaceCollection
136154
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -212,26 +212,14 @@ unsigned getBitWidth(LLVM::GlobalOp op) {
212212
// LLVMSymbolLinkerInterface
213213
//===----------------------------------------------------------------------===//
214214

215-
class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
215+
class LLVMSymbolLinkerInterface : public SymbolAttrLinkerInterface {
216216
public:
217-
using SymbolLinkerInterface::SymbolLinkerInterface;
217+
using SymbolAttrLinkerInterface::SymbolAttrLinkerInterface;
218218

219219
bool canBeLinked(Operation *op) const override {
220220
return isa<LLVM::GlobalOp>(op) || isa<LLVM::LLVMFuncOp>(op);
221221
}
222222

223-
StringRef getSymbol(Operation *op) const override { return symbol(op); }
224-
225-
Conflict findConflict(Operation *src) const override {
226-
assert(canBeLinked(src) && "expected linkable operation");
227-
228-
if (auto it = summary.find(getSymbol(src)); it != summary.end()) {
229-
return {it->second, src};
230-
}
231-
232-
return Conflict::noConflict(src);
233-
}
234-
235223
bool isLinkNeeded(Conflict pair, bool forDependency) const override {
236224
assert(canBeLinked(pair.src) && "expected linkable operation");
237225
if (pair.src == pair.dst)
@@ -466,7 +454,6 @@ class LLVMSymbolLinkerInterface : public SymbolLinkerInterface {
466454
}
467455

468456
SetVector<Operation *> uniqued;
469-
llvm::StringMap<Operation *> summary;
470457
};
471458

472459
//===----------------------------------------------------------------------===//

mlir/lib/Linker/LinkerInterface.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
using namespace mlir;
1414
using namespace mlir::link;
1515

16+
//===----------------------------------------------------------------------===//
17+
// LinkState
18+
//===----------------------------------------------------------------------===//
19+
1620
template <typename CloneFunc>
1721
Operation *cloneImpl(Operation *src, IRMapping &mapping, CloneFunc cloneFunc) {
1822
assert(!mapping.contains(src));
@@ -40,3 +44,21 @@ Operation *LinkState::getDestinationOp() const {
4044
Operation *LinkState::remapped(Operation *src) const {
4145
return mapping.lookupOrNull(src);
4246
}
47+
48+
//===----------------------------------------------------------------------===//
49+
// SymbolAttrLinkerInterface
50+
//===----------------------------------------------------------------------===//
51+
52+
StringRef SymbolAttrLinkerInterface::getSymbol(Operation *op) const {
53+
return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
54+
.getValue();
55+
}
56+
57+
Conflict SymbolAttrLinkerInterface::findConflict(Operation *src) const {
58+
assert(canBeLinked(src) && "expected linkable operation");
59+
StringRef symbol = getSymbol(src);
60+
auto it = summary.find(symbol);
61+
if (it == summary.end())
62+
return Conflict::noConflict(src);
63+
return {it->second, src};
64+
}

0 commit comments

Comments
 (0)